From e27fedfc7b54a8b5e6d324660a1323bc9860869a Mon Sep 17 00:00:00 2001 From: Ahmed Harmouche Date: Tue, 10 Oct 2023 15:40:51 +0200 Subject: [PATCH] Fix stable diffusion output error on WebGPU (#2032) * Fix stable diffusion on WebGPU * Remove hack, numpy cast only on webgpu * No-copy numpy cast --- examples/stable_diffusion.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index 0c982bce89..38e04ab0f6 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -651,12 +651,14 @@ if __name__ == "__main__": # make image correct size and scale x = (x + 1.0) / 2.0 - x = (x.reshape(3,512,512).permute(1,2,0).clip(0,1)*255).cast(dtypes.uint8) + x = (x.reshape(3,512,512).permute(1,2,0).clip(0,1)*255) + if Device.DEFAULT != "WEBGPU": x = x.cast(dtypes.uint8) print(x.shape) # save image from PIL import Image - im = Image.fromarray(x.numpy()) + import numpy as np + im = Image.fromarray(x.numpy().astype(np.uint8, copy=False)) print(f"saving {args.out}") im.save(args.out) # Open image.