From 6e6059dde0fd77d00498e6d6a61666d664b60ff4 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Thu, 9 Oct 2025 11:13:11 +0800
Subject: [PATCH] clean up stable diffusion weight loading (#12452)

---
 examples/stable_diffusion.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py
index fe85aaffab..64a8921740 100644
--- a/examples/stable_diffusion.py
+++ b/examples/stable_diffusion.py
@@ -269,12 +269,14 @@ if __name__ == "__main__":
 
   # load in weights
   with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
-    load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)
+    load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], verbose=False, strict=False, realize=False)
 
     if args.fp16:
       for k,v in get_state_dict(model).items():
         if k.startswith("model"):
-          v.replace(v.cast(dtypes.float16).realize())
+          v.replace(v.cast(dtypes.float16))
+
+    Tensor.realize(*get_state_dict(model).values())
 
   # run through CLIP to get context
   tokenizer = Tokenizer.ClipTokenizer()
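
Note (not part of the patch): the change keeps weight loading lazy end to end. load_state_dict is called with realize=False, the fp16 cast uses replace() without a per-tensor realize(), and everything is realized once at the end with Tensor.realize(*get_state_dict(model).values()). Below is a minimal sketch of that pattern under assumptions: SmallNet and its layers are hypothetical stand-ins for the stable diffusion model, and a fresh state dict stands in for the downloaded checkpoint; the tinygrad helpers are the ones the patch itself uses.

# minimal sketch of the batched-realize pattern from this patch (assumed toy model)
from tinygrad import Tensor, dtypes, nn
from tinygrad.nn.state import get_state_dict, load_state_dict

class SmallNet:
  def __init__(self):
    self.l1 = nn.Linear(4, 4)
    self.l2 = nn.Linear(4, 2)

model = SmallNet()
state = get_state_dict(SmallNet())  # stand-in for the checkpoint's state_dict

# keep loading lazy: no per-tensor realize while copying weights in
load_state_dict(model, state, strict=False, verbose=False, realize=False)

# cast lazily too; replace() swaps the tensor contents without forcing a kernel yet
for k, v in get_state_dict(model).items():
  v.replace(v.cast(dtypes.float16))

# realize everything in one batch so the copies and casts can be scheduled together
Tensor.realize(*get_state_dict(model).values())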