Remove wgpu specific checks from stable diffusion example (#7991)

This commit is contained in:
Ahmed Harmouche
2024-12-02 11:31:14 +01:00
committed by GitHub
parent e2916ff210
commit 8909dbd82c

View File

@@ -189,7 +189,7 @@ class StableDiffusion:
# make image correct size and scale
x = (x + 1.0) / 2.0
x = x.reshape(3,512,512).permute(1,2,0).clip(0,1)*255
return x.cast(dtypes.uint8) if Device.DEFAULT != "WEBGPU" else x
return x.cast(dtypes.uint8)
def __call__(self, unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance):
e_t = self.get_model_output(unconditional_context, context, latent, timestep, guidance)
@@ -280,7 +280,7 @@ if __name__ == "__main__":
print(x.shape)
# save image
im = Image.fromarray(x.numpy().astype(np.uint8, copy=False))
im = Image.fromarray(x.numpy())
print(f"saving {args.out}")
im.save(args.out)
# Open image.