mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
pass step count and step index to diffusion step func (#2342)
This commit is contained in:
@@ -391,7 +391,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
batched_t.fill_(t)
|
||||
step_output = self.step(batched_t, latents, conditioning_data,
|
||||
i, additional_guidance=additional_guidance)
|
||||
step_index=i,
|
||||
total_step_count=len(timesteps),
|
||||
additional_guidance=additional_guidance)
|
||||
latents = step_output.prev_sample
|
||||
predicted_original = getattr(step_output, 'pred_original_sample', None)
|
||||
|
||||
@@ -410,7 +412,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
@torch.inference_mode()
|
||||
def step(self, t: torch.Tensor, latents: torch.Tensor,
|
||||
conditioning_data: ConditioningData,
|
||||
step_index:int | None = None, additional_guidance: List[Callable] = None):
|
||||
step_index:int, total_step_count:int,
|
||||
additional_guidance: List[Callable] = None):
|
||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||
timestep = t[0]
|
||||
|
||||
@@ -427,6 +430,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
conditioning_data.unconditioned_embeddings, conditioning_data.text_embeddings,
|
||||
conditioning_data.guidance_scale,
|
||||
step_index=step_index,
|
||||
total_step_count=total_step_count,
|
||||
threshold=conditioning_data.threshold
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user