stable diffusion --fakeweights (#12810)

This commit is contained in:
Sieds Lykles
2025-10-20 12:41:06 +02:00
committed by GitHub
parent b5e36e3c6c
commit 1e93d19ee3
2 changed files with 16 additions and 10 deletions

View File

@@ -99,6 +99,7 @@ if __name__ == "__main__":
parser.add_argument('--timing', action='store_true', help="Print timing per step")
parser.add_argument('--noshow', action='store_true', help="Don't show the image")
parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
parser.add_argument('--fakeweights', action='store_true', help="Skip loading checkpoints and use fake weights")
args = parser.parse_args()
N = 1
@@ -112,19 +113,22 @@ if __name__ == "__main__":
model = StableDiffusionV2(**params)
default_weights_url = 'https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors'
weights_fn = args.weights_fn
if not weights_fn:
weights_url = args.weights_url if args.weights_url else default_weights_url
weights_fn = fetch(weights_url, os.path.basename(str(weights_url)))
with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
load_state_dict(model, safe_load(weights_fn), strict=False)
if not args.fakeweights:
default_weights_url = 'https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors'
weights_fn = args.weights_fn
if not weights_fn:
weights_url = args.weights_url if args.weights_url else default_weights_url
weights_fn = fetch(weights_url, os.path.basename(str(weights_url)))
load_state_dict(model, safe_load(weights_fn), strict=False)
if args.fp16:
for k,v in get_state_dict(model).items():
if k.startswith("model"):
v.replace(v.cast(dtypes.float16).realize())
v.replace(v.cast(dtypes.float16))
Tensor.realize(*get_state_dict(model).values())
c = { "crossattn": model.cond_stage_model(args.prompt) }
uc = { "crossattn": model.cond_stage_model("") }

View File

@@ -263,14 +263,16 @@ if __name__ == "__main__":
parser.add_argument('--timing', action='store_true', help="Print timing per step")
parser.add_argument('--seed', type=int, help="Set the random latent seed")
parser.add_argument('--guidance', type=float, default=7.5, help="Prompt strength")
parser.add_argument('--fakeweights', action='store_true', help="Skip loading checkpoints and use fake weights")
args = parser.parse_args()
model = StableDiffusion()
# load in weights
with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
load_state_dict(model, torch_load(model_bin)['state_dict'], verbose=False, strict=False, realize=False)
if not args.fakeweights:
model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
load_state_dict(model, torch_load(model_bin)['state_dict'], verbose=False, strict=False, realize=False)
if args.fp16:
for k,v in get_state_dict(model).items():