mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
stable diffusion --fakeweights (#12810)
This commit is contained in:
@@ -99,6 +99,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument('--timing', action='store_true', help="Print timing per step")
|
||||
parser.add_argument('--noshow', action='store_true', help="Don't show the image")
|
||||
parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
|
||||
parser.add_argument('--fakeweights', action='store_true', help="Skip loading checkpoints and use fake weights")
|
||||
args = parser.parse_args()
|
||||
|
||||
N = 1
|
||||
@@ -112,19 +113,22 @@ if __name__ == "__main__":
|
||||
|
||||
model = StableDiffusionV2(**params)
|
||||
|
||||
default_weights_url = 'https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors'
|
||||
weights_fn = args.weights_fn
|
||||
if not weights_fn:
|
||||
weights_url = args.weights_url if args.weights_url else default_weights_url
|
||||
weights_fn = fetch(weights_url, os.path.basename(str(weights_url)))
|
||||
|
||||
with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
|
||||
load_state_dict(model, safe_load(weights_fn), strict=False)
|
||||
if not args.fakeweights:
|
||||
default_weights_url = 'https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors'
|
||||
weights_fn = args.weights_fn
|
||||
if not weights_fn:
|
||||
weights_url = args.weights_url if args.weights_url else default_weights_url
|
||||
weights_fn = fetch(weights_url, os.path.basename(str(weights_url)))
|
||||
|
||||
load_state_dict(model, safe_load(weights_fn), strict=False)
|
||||
|
||||
if args.fp16:
|
||||
for k,v in get_state_dict(model).items():
|
||||
if k.startswith("model"):
|
||||
v.replace(v.cast(dtypes.float16).realize())
|
||||
v.replace(v.cast(dtypes.float16))
|
||||
|
||||
Tensor.realize(*get_state_dict(model).values())
|
||||
|
||||
c = { "crossattn": model.cond_stage_model(args.prompt) }
|
||||
uc = { "crossattn": model.cond_stage_model("") }
|
||||
|
||||
@@ -263,14 +263,16 @@ if __name__ == "__main__":
|
||||
parser.add_argument('--timing', action='store_true', help="Print timing per step")
|
||||
parser.add_argument('--seed', type=int, help="Set the random latent seed")
|
||||
parser.add_argument('--guidance', type=float, default=7.5, help="Prompt strength")
|
||||
parser.add_argument('--fakeweights', action='store_true', help="Skip loading checkpoints and use fake weights")
|
||||
args = parser.parse_args()
|
||||
|
||||
model = StableDiffusion()
|
||||
|
||||
# load in weights
|
||||
with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
|
||||
model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
|
||||
load_state_dict(model, torch_load(model_bin)['state_dict'], verbose=False, strict=False, realize=False)
|
||||
if not args.fakeweights:
|
||||
model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
|
||||
load_state_dict(model, torch_load(model_bin)['state_dict'], verbose=False, strict=False, realize=False)
|
||||
|
||||
if args.fp16:
|
||||
for k,v in get_state_dict(model).items():
|
||||
|
||||
Reference in New Issue
Block a user