mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
clean up stable diffusion weight loading (#12452)
This commit is contained in:
@@ -269,12 +269,14 @@ if __name__ == "__main__":
|
||||
|
||||
# load in weights
|
||||
with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
|
||||
load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)
|
||||
load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], verbose=False, strict=False, realize=False)
|
||||
|
||||
if args.fp16:
|
||||
for k,v in get_state_dict(model).items():
|
||||
if k.startswith("model"):
|
||||
v.replace(v.cast(dtypes.float16).realize())
|
||||
v.replace(v.cast(dtypes.float16))
|
||||
|
||||
Tensor.realize(*get_state_dict(model).values())
|
||||
|
||||
# run through CLIP to get context
|
||||
tokenizer = Tokenizer.ClipTokenizer()
|
||||
|
||||
Reference in New Issue
Block a user