Switch to sequential CFG for CogView4 (for now, until I sort out the padding).
commit 3166b5d2ea
parent 321c2d358c
committed by psychedelicious
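Background for the change: batched CFG produces the conditional and unconditional predictions in a single forward pass by concatenating the positive and negative prompt embeddings along the batch dimension, which only works once both sequences are padded to a common length. Sequential CFG runs two separate passes instead. A minimal sketch of the two strategies, assuming a generic transformer callable with the signature shown in the comment (the cfg_denoise_step helper is hypothetical, not part of this commit):

    import torch

    def cfg_denoise_step(transformer, latents, pos_prompt_embeds, neg_prompt_embeds,
                         timestep, cfg_scale, sequential=True):
        # `transformer` is a stand-in callable:
        # (hidden_states, encoder_hidden_states, timestep) -> noise prediction tensor.
        if sequential:
            # Sequential CFG: two forward passes over the same latents. The positive
            # and negative embeddings never share a batch, so their sequence lengths
            # may differ freely.
            t = timestep.expand(latents.shape[0])
            noise_pred_cond = transformer(latents, pos_prompt_embeds, t)
            noise_pred_uncond = transformer(latents, neg_prompt_embeds, t)
        else:
            # Batched CFG: one forward pass over a doubled batch. torch.cat(dim=0)
            # requires equal sequence lengths, so mismatched prompt embeddings must
            # be padded first -- the issue the commit message defers.
            latent_model_input = torch.cat([latents] * 2, dim=0)
            prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
            t = timestep.expand(latent_model_input.shape[0])
            noise_pred_uncond, noise_pred_cond = transformer(
                latent_model_input, prompt_embeds, t
            ).chunk(2)

        # Classifier-free guidance: extrapolate away from the unconditional
        # prediction toward the conditional one.
        return noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond)

Sequential CFG does the same work in two smaller passes: typically a bit slower wall-clock than one doubled batch, but it sidesteps the sequence-length mismatch entirely.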
@@ -166,15 +166,13 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             dtype=inference_dtype,
             device=device,
         )
-        # TODO(ryand): Support both sequential and batched CFG inference.
-        prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
 
         # Prepare misc. conditioning variables.
         # TODO(ryand): We could expose these as params (like with SDXL). But, we should experiment to see if they are
         # useful first.
-        original_size = torch.tensor([(self.height, self.width)], dtype=prompt_embeds.dtype, device=device)
-        target_size = torch.tensor([(self.height, self.width)], dtype=prompt_embeds.dtype, device=device)
-        crops_coords_top_left = torch.tensor([(0, 0)], dtype=prompt_embeds.dtype, device=device)
+        original_size = torch.tensor([(self.height, self.width)], dtype=pos_prompt_embeds.dtype, device=device)
+        target_size = torch.tensor([(self.height, self.width)], dtype=pos_prompt_embeds.dtype, device=device)
+        crops_coords_top_left = torch.tensor([(0, 0)], dtype=pos_prompt_embeds.dtype, device=device)
 
         # Prepare the timestep schedule.
         image_seq_len = ((self.height // LATENT_SCALE_FACTOR) * (self.width // LATENT_SCALE_FACTOR)) // (
@@ -218,15 +216,14 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
 
         # Denoising loop
         for step_idx, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
-            # Expand the latents if we are doing CFG.
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
             # Expand the timestep to match the latent model input.
             # Multiply by 1000 to match the default FlowMatchEulerDiscreteScheduler num_train_timesteps.
-            timestep = torch.tensor([t_curr * 1000], device=device).expand(latent_model_input.shape[0])
+            timestep = torch.tensor([t_curr * 1000], device=device).expand(latents.shape[0])
 
-            noise_pred = transformer(
-                hidden_states=latent_model_input,
-                encoder_hidden_states=prompt_embeds,
+            # TODO(ryand): Support both sequential and batched CFG inference.
+            noise_pred_cond = transformer(
+                hidden_states=latents,
+                encoder_hidden_states=pos_prompt_embeds,
                 timestep=timestep,
                 original_size=original_size,
                 target_size=target_size,
@@ -236,8 +233,19 @@ class CogView4DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
 
             # Apply CFG.
             if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + cfg_scale[step_idx] * (noise_pred_cond - noise_pred_uncond)
+                noise_pred_uncond = transformer(
+                    hidden_states=latents,
+                    encoder_hidden_states=neg_prompt_embeds,
+                    timestep=timestep,
+                    original_size=original_size,
+                    target_size=target_size,
+                    crop_coords=crops_coords_top_left,
+                    return_dict=False,
+                )[0]
+
+                noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_cond - noise_pred_uncond)
+            else:
+                noise_pred = noise_pred_cond
 
             # Compute the previous noisy sample x_t -> x_t-1.
             latents_dtype = latents.dtype
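For reference, restoring the batched path would mean padding the shorter embedding sequence before concatenation. A rough sketch of what that padding might look like, purely illustrative and not from this commit (a real implementation would likely also need an attention mask so the model ignores padded positions):

    import torch
    import torch.nn.functional as F

    def pad_and_batch(neg_prompt_embeds, pos_prompt_embeds):
        # Both tensors are (batch, seq_len, embed_dim); pad seq_len up to the longer one.
        max_len = max(neg_prompt_embeds.shape[1], pos_prompt_embeds.shape[1])
        # F.pad pads trailing dims first: (embed_left, embed_right, seq_left, seq_right).
        neg = F.pad(neg_prompt_embeds, (0, 0, 0, max_len - neg_prompt_embeds.shape[1]))
        pos = F.pad(pos_prompt_embeds, (0, 0, 0, max_len - pos_prompt_embeds.shape[1]))
        return torch.cat([neg, pos], dim=0)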