fix(app): step callbacks for SD, FLUX, MultiDiffusion

Each of these was a bit off:
- The SD callback started at `-1` and ended at `i`. Combined w/ the weird math on the previous `calc_percentage` util, this caused the progress bar to never finish.
- The MultiDiffusion callback had the same problems as SD.
- The FLUX callback didn't emit a pre-denoising step 0 image. It also reported total_steps as 1 higher than the actual step count.

Each of these now emit the expected events to the frontend:
- The initial latents at 0%
- Progress at each step, ending at 100%
This commit is contained in:
psychedelicious
2024-09-21 19:17:09 +10:00
committed by Kent Keirsey
parent a6f93d3862
commit dc10197615
3 changed files with 17 additions and 6 deletions

View File

@@ -22,7 +22,18 @@ def denoise(
guidance: float,
traj_guidance_extension: TrajectoryGuidanceExtension | None, # noqa: F821
):
step = 0
# step 0 is the initial state
total_steps = len(timesteps) - 1
step_callback(
PipelineIntermediateState(
step=0,
order=1,
total_steps=total_steps,
timestep=int(timesteps[0]),
latents=img,
),
)
step = 1
# guidance_vec is ignored for schnell.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
@@ -49,7 +60,7 @@ def denoise(
PipelineIntermediateState(
step=step,
order=1,
total_steps=len(timesteps),
total_steps=total_steps,
timestep=int(t_curr),
latents=preview_img,
),

View File

@@ -366,7 +366,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
with attn_ctx:
callback(
PipelineIntermediateState(
step=-1,
step=0, # initial latents
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=self.scheduler.config.num_train_timesteps,
@@ -395,7 +395,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
callback(
PipelineIntermediateState(
step=i,
step=i + 1, # final latents
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=int(t),

View File

@@ -81,7 +81,7 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
callback(
PipelineIntermediateState(
step=-1,
step=0,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=self.scheduler.config.num_train_timesteps,
@@ -182,7 +182,7 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
callback(
PipelineIntermediateState(
step=i,
step=i + 1,
order=self.scheduler.order,
total_steps=len(timesteps),
timestep=int(t),