Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 89 additions & 14 deletions src/diffusers/schedulers/scheduling_dpm_cogvideox.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ def rescale_zero_terminal_snr(alphas_cumprod):
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)


Args:
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.
Expand Down Expand Up @@ -175,11 +174,14 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
timestep_spacing (`str`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. Choose from
`leading`, `linspace` or `trailing`.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
snr_shift_scale (`float`, defaults to 3.0):
Shift scale for SNR.
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
Expand All @@ -191,15 +193,15 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.0120,
beta_schedule: str = "scaled_linear",
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "scaled_linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
timestep_spacing: str = "leading",
timestep_spacing: Literal["leading", "linspace", "trailing"] = "leading",
rescale_betas_zero_snr: bool = False,
snr_shift_scale: float = 3.0,
):
Expand All @@ -209,7 +211,15 @@ def __init__(
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2
self.betas = (
torch.linspace(
beta_start**0.5,
beta_end**0.5,
num_train_timesteps,
dtype=torch.float64,
)
** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
Expand Down Expand Up @@ -266,13 +276,20 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None
"""
return sample

def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(
self,
num_inference_steps: int,
device: Optional[Union[str, torch.device]] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).

Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None` (the default), the timesteps are not
moved.
"""

if num_inference_steps > self.config.num_train_timesteps:
Expand Down Expand Up @@ -311,7 +328,27 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic

self.timesteps = torch.from_numpy(timesteps).to(device)

def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None):
def get_variables(
self,
alpha_prod_t: torch.Tensor,
alpha_prod_t_prev: torch.Tensor,
alpha_prod_t_back: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
"""
Compute the variables used for DPM-Solver++ (2M) referencing the original implementation.

Args:
alpha_prod_t (`torch.Tensor`):
The cumulative product of alphas at the current timestep.
alpha_prod_t_prev (`torch.Tensor`):
The cumulative product of alphas at the previous timestep.
alpha_prod_t_back (`torch.Tensor`, *optional*):
The cumulative product of alphas at the timestep before the previous timestep.

Returns:
`tuple`:
A tuple containing the variables `h`, `r`, `lamb`, `lamb_next`.
"""
lamb = ((alpha_prod_t / (1 - alpha_prod_t)) ** 0.5).log()
lamb_next = ((alpha_prod_t_prev / (1 - alpha_prod_t_prev)) ** 0.5).log()
h = lamb_next - lamb
Expand All @@ -324,7 +361,36 @@ def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None)
else:
return h, None, lamb, lamb_next

def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back):
def get_mult(
self,
h: torch.Tensor,
r: Optional[torch.Tensor],
alpha_prod_t: torch.Tensor,
alpha_prod_t_prev: torch.Tensor,
alpha_prod_t_back: Optional[torch.Tensor] = None,
) -> Union[
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
]:
"""
Compute the multipliers for the previous sample and the predicted original sample.

Args:
h (`torch.Tensor`):
The log-SNR difference.
r (`torch.Tensor`):
The ratio of log-SNR differences.
alpha_prod_t (`torch.Tensor`):
The cumulative product of alphas at the current timestep.
alpha_prod_t_prev (`torch.Tensor`):
The cumulative product of alphas at the previous timestep.
alpha_prod_t_back (`torch.Tensor`, *optional*):
The cumulative product of alphas at the timestep before the previous timestep.

Returns:
`tuple`:
A tuple containing the multipliers.
"""
mult1 = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 * (-h).exp()
mult2 = (-2 * h).expm1() * alpha_prod_t_prev**0.5

Expand All @@ -338,13 +404,13 @@ def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back):
def step(
self,
model_output: torch.Tensor,
old_pred_original_sample: torch.Tensor,
old_pred_original_sample: Optional[torch.Tensor],
timestep: int,
timestep_back: int,
sample: torch.Tensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
generator: Optional[torch.Generator] = None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = False,
) -> Union[DDIMSchedulerOutput, Tuple]:
Expand All @@ -355,8 +421,12 @@ def step(
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
old_pred_original_sample (`torch.Tensor`):
The predicted original sample from the previous timestep.
timestep (`int`):
The current discrete timestep in the diffusion chain.
timestep_back (`int`):
The timestep to look back to.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
eta (`float`):
Expand Down Expand Up @@ -436,7 +506,12 @@ def step(
return prev_sample, pred_original_sample
else:
denoised_d = mult[2] * pred_original_sample - mult[3] * old_pred_original_sample
noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
noise = randn_tensor(
sample.shape,
generator=generator,
device=sample.device,
dtype=sample.dtype,
)
x_advanced = mult[0] * sample - mult[1] * denoised_d + mult_noise * noise

prev_sample = x_advanced
Expand Down Expand Up @@ -524,5 +599,5 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity

def __len__(self):
def __len__(self) -> int:
    """Return the number of training timesteps this scheduler was configured with."""
    # NOTE(review): mirrors `num_train_timesteps` from the registered config
    # (set via `register_to_config` in `__init__`, not visible in this hunk).
    return self.config.num_train_timesteps
Loading