Fix SharkEulerDiscrete (#2022)

This commit is contained in:
Ean Garvey
2023-12-06 12:25:06 -06:00
committed by GitHub
parent c74b55f24e
commit 2d6f48821d

View File

@@ -48,7 +48,7 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
steps_offset,
)
# TODO: make it dynamic so we dont have to worry about batch size
self.batch_size = None
self.batch_size = 1
def compile(self, batch_size=1):
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
@@ -171,8 +171,9 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
_import(self)
def scale_model_input(self, sample, timestep):
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
return self.scaling_model(
"forward",
(
@@ -213,21 +214,25 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
else noise_pred
)
if gamma > 0:
noise = randn_tensor(
torch.Size(noise_pred.shape),
dtype=torch.float16,
device="cpu",
generator=generator,
)
noise = randn_tensor(
torch.Size(noise_pred.shape),
dtype=torch.float16,
device="cpu",
generator=generator,
)
eps = noise * s_noise
eps = noise * s_noise
if gamma > 0:
latent = latent + eps * (sigma_hat**2 - sigma**2) ** 0.5
if self.config.prediction_type == "v_prediction":
sigma_hat = sigma
dt = self.sigmas[self.step_index + 1] - sigma_hat
self._step_index += 1
return self.step_model(
"forward",
(