From 8909dbd82c811628e96e390ad38658148338cd96 Mon Sep 17 00:00:00 2001 From: Ahmed Harmouche Date: Mon, 2 Dec 2024 11:31:14 +0100 Subject: [PATCH] Remove wgpu specific checks from stable diffusion example (#7991) --- examples/stable_diffusion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index ef54f8a888..be4305771d 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -189,7 +189,7 @@ class StableDiffusion: # make image correct size and scale x = (x + 1.0) / 2.0 x = x.reshape(3,512,512).permute(1,2,0).clip(0,1)*255 - return x.cast(dtypes.uint8) if Device.DEFAULT != "WEBGPU" else x + return x.cast(dtypes.uint8) def __call__(self, unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance): e_t = self.get_model_output(unconditional_context, context, latent, timestep, guidance) @@ -280,7 +280,7 @@ if __name__ == "__main__": print(x.shape) # save image - im = Image.fromarray(x.numpy().astype(np.uint8, copy=False)) + im = Image.fromarray(x.numpy()) print(f"saving {args.out}") im.save(args.out) # Open image.