From 2d6f48821da91a47f322d30765bfead56d76d677 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Wed, 6 Dec 2023 12:25:06 -0600 Subject: [PATCH] Fix SharkEulerDiscrete (#2022) --- .../src/schedulers/shark_eulerdiscrete.py | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py b/apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py index 3c25dc40..5e9040c5 100644 --- a/apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py +++ b/apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py @@ -48,7 +48,7 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler): steps_offset, ) # TODO: make it dynamic so we dont have to worry about batch size - self.batch_size = None + self.batch_size = 1 def compile(self, batch_size=1): SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers" @@ -171,8 +171,9 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler): _import(self) def scale_model_input(self, sample, timestep): - step_index = (self.timesteps == timestep).nonzero().item() - sigma = self.sigmas[step_index] + if self.step_index is None: + self._init_step_index(timestep) + sigma = self.sigmas[self.step_index] return self.scaling_model( "forward", ( @@ -213,21 +214,25 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler): else noise_pred ) - if gamma > 0: - noise = randn_tensor( - torch.Size(noise_pred.shape), - dtype=torch.float16, - device="cpu", - generator=generator, - ) + noise = randn_tensor( + torch.Size(noise_pred.shape), + dtype=torch.float16, + device="cpu", + generator=generator, + ) - eps = noise * s_noise + eps = noise * s_noise + + if gamma > 0: latent = latent + eps * (sigma_hat**2 - sigma**2) ** 0.5 if self.config.prediction_type == "v_prediction": sigma_hat = sigma dt = self.sigmas[self.step_index + 1] - sigma_hat + + self._step_index += 1 + return self.step_model( "forward", (