Switch to sequential CFG for CogView4 (for now, until I sort out the padding).

This commit is contained in:
Ryan Dick
2025-03-06 20:39:02 +00:00
committed by psychedelicious
parent 321c2d358c
commit 3166b5d2ea

View File

@@ -166,15 +166,13 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
dtype=inference_dtype,
device=device,
)
# TODO(ryand): Support both sequential and batched CFG inference.
prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
# Prepare misc. conditioning variables.
# TODO(ryand): We could expose these as params (like with SDXL). But, we should experiment to see if they are
# useful first.
original_size = torch.tensor([(self.height, self.width)], dtype=prompt_embeds.dtype, device=device)
target_size = torch.tensor([(self.height, self.width)], dtype=prompt_embeds.dtype, device=device)
crops_coords_top_left = torch.tensor([(0, 0)], dtype=prompt_embeds.dtype, device=device)
original_size = torch.tensor([(self.height, self.width)], dtype=pos_prompt_embeds.dtype, device=device)
target_size = torch.tensor([(self.height, self.width)], dtype=pos_prompt_embeds.dtype, device=device)
crops_coords_top_left = torch.tensor([(0, 0)], dtype=pos_prompt_embeds.dtype, device=device)
# Prepare the timestep schedule.
image_seq_len = ((self.height // LATENT_SCALE_FACTOR) * (self.width // LATENT_SCALE_FACTOR)) // (
@@ -218,15 +216,14 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# Denoising loop
for step_idx, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
# Expand the latents if we are doing CFG.
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# Expand the timestep to match the latent model input.
# Multiply by 1000 to match the default FlowMatchEulerDiscreteScheduler num_train_timesteps.
timestep = torch.tensor([t_curr * 1000], device=device).expand(latent_model_input.shape[0])
timestep = torch.tensor([t_curr * 1000], device=device).expand(latents.shape[0])
noise_pred = transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
# TODO(ryand): Support both sequential and batched CFG inference.
noise_pred_cond = transformer(
hidden_states=latents,
encoder_hidden_states=pos_prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
@@ -236,8 +233,19 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# Apply CFG.
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + cfg_scale[step_idx] * (noise_pred_cond - noise_pred_uncond)
noise_pred_uncond = transformer(
hidden_states=latents,
encoder_hidden_states=neg_prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
return_dict=False,
)[0]
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = noise_pred_cond
# Compute the previous noisy sample x_t -> x_t-1.
latents_dtype = latents.dtype