diff --git a/examples/sdv2.py b/examples/sdv2.py index 29b1abb8fd..856cf239ad 100644 --- a/examples/sdv2.py +++ b/examples/sdv2.py @@ -99,6 +99,7 @@ if __name__ == "__main__": parser.add_argument('--timing', action='store_true', help="Print timing per step") parser.add_argument('--noshow', action='store_true', help="Don't show the image") parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16") + parser.add_argument('--fakeweights', action='store_true', help="Skip loading checkpoints and use fake weights") args = parser.parse_args() N = 1 @@ -112,19 +113,22 @@ if __name__ == "__main__": model = StableDiffusionV2(**params) - default_weights_url = 'https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors' - weights_fn = args.weights_fn - if not weights_fn: - weights_url = args.weights_url if args.weights_url else default_weights_url - weights_fn = fetch(weights_url, os.path.basename(str(weights_url))) - with WallTimeEvent(BenchEvent.LOAD_WEIGHTS): - load_state_dict(model, safe_load(weights_fn), strict=False) + if not args.fakeweights: + default_weights_url = 'https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors' + weights_fn = args.weights_fn + if not weights_fn: + weights_url = args.weights_url if args.weights_url else default_weights_url + weights_fn = fetch(weights_url, os.path.basename(str(weights_url))) + + load_state_dict(model, safe_load(weights_fn), strict=False) if args.fp16: for k,v in get_state_dict(model).items(): if k.startswith("model"): - v.replace(v.cast(dtypes.float16).realize()) + v.replace(v.cast(dtypes.float16)) + + Tensor.realize(*get_state_dict(model).values()) c = { "crossattn": model.cond_stage_model(args.prompt) } uc = { "crossattn": model.cond_stage_model("") } diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index 644c524476..4650b7e1d9 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -263,14 +263,16 @@ if __name__ == "__main__": parser.add_argument('--timing', action='store_true', help="Print timing per step") parser.add_argument('--seed', type=int, help="Set the random latent seed") parser.add_argument('--guidance', type=float, default=7.5, help="Prompt strength") + parser.add_argument('--fakeweights', action='store_true', help="Skip loading checkpoints and use fake weights") args = parser.parse_args() model = StableDiffusion() # load in weights with WallTimeEvent(BenchEvent.LOAD_WEIGHTS): - model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt') - load_state_dict(model, torch_load(model_bin)['state_dict'], verbose=False, strict=False, realize=False) + if not args.fakeweights: + model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt') + load_state_dict(model, torch_load(model_bin)['state_dict'], verbose=False, strict=False, realize=False) if args.fp16: for k,v in get_state_dict(model).items():