clean up stable diffusion weight loading (#12452)

This commit is contained in:
George Hotz
2025-10-09 11:13:11 +08:00
committed by GitHub
parent 20d98b19c3
commit 6e6059dde0

View File

@@ -269,12 +269,14 @@ if __name__ == "__main__":
# load in weights
with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)
load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], verbose=False, strict=False, realize=False)
if args.fp16:
for k,v in get_state_dict(model).items():
if k.startswith("model"):
v.replace(v.cast(dtypes.float16).realize())
v.replace(v.cast(dtypes.float16))
Tensor.realize(*get_state_dict(model).values())
# run through CLIP to get context
tokenizer = Tokenizer.ClipTokenizer()