mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
stable diffusion --fakeweights (#12810)
This commit is contained in:
@@ -99,6 +99,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument('--timing', action='store_true', help="Print timing per step")
|
parser.add_argument('--timing', action='store_true', help="Print timing per step")
|
||||||
parser.add_argument('--noshow', action='store_true', help="Don't show the image")
|
parser.add_argument('--noshow', action='store_true', help="Don't show the image")
|
||||||
parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
|
parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
|
||||||
|
parser.add_argument('--fakeweights', action='store_true', help="Skip loading checkpoints and use fake weights")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
N = 1
|
N = 1
|
||||||
@@ -112,19 +113,22 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
model = StableDiffusionV2(**params)
|
model = StableDiffusionV2(**params)
|
||||||
|
|
||||||
default_weights_url = 'https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors'
|
|
||||||
weights_fn = args.weights_fn
|
|
||||||
if not weights_fn:
|
|
||||||
weights_url = args.weights_url if args.weights_url else default_weights_url
|
|
||||||
weights_fn = fetch(weights_url, os.path.basename(str(weights_url)))
|
|
||||||
|
|
||||||
with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
|
with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
|
||||||
load_state_dict(model, safe_load(weights_fn), strict=False)
|
if not args.fakeweights:
|
||||||
|
default_weights_url = 'https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors'
|
||||||
|
weights_fn = args.weights_fn
|
||||||
|
if not weights_fn:
|
||||||
|
weights_url = args.weights_url if args.weights_url else default_weights_url
|
||||||
|
weights_fn = fetch(weights_url, os.path.basename(str(weights_url)))
|
||||||
|
|
||||||
|
load_state_dict(model, safe_load(weights_fn), strict=False)
|
||||||
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
for k,v in get_state_dict(model).items():
|
for k,v in get_state_dict(model).items():
|
||||||
if k.startswith("model"):
|
if k.startswith("model"):
|
||||||
v.replace(v.cast(dtypes.float16).realize())
|
v.replace(v.cast(dtypes.float16))
|
||||||
|
|
||||||
|
Tensor.realize(*get_state_dict(model).values())
|
||||||
|
|
||||||
c = { "crossattn": model.cond_stage_model(args.prompt) }
|
c = { "crossattn": model.cond_stage_model(args.prompt) }
|
||||||
uc = { "crossattn": model.cond_stage_model("") }
|
uc = { "crossattn": model.cond_stage_model("") }
|
||||||
|
|||||||
@@ -263,14 +263,16 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument('--timing', action='store_true', help="Print timing per step")
|
parser.add_argument('--timing', action='store_true', help="Print timing per step")
|
||||||
parser.add_argument('--seed', type=int, help="Set the random latent seed")
|
parser.add_argument('--seed', type=int, help="Set the random latent seed")
|
||||||
parser.add_argument('--guidance', type=float, default=7.5, help="Prompt strength")
|
parser.add_argument('--guidance', type=float, default=7.5, help="Prompt strength")
|
||||||
|
parser.add_argument('--fakeweights', action='store_true', help="Skip loading checkpoints and use fake weights")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model = StableDiffusion()
|
model = StableDiffusion()
|
||||||
|
|
||||||
# load in weights
|
# load in weights
|
||||||
with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
|
with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
|
||||||
model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
|
if not args.fakeweights:
|
||||||
load_state_dict(model, torch_load(model_bin)['state_dict'], verbose=False, strict=False, realize=False)
|
model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
|
||||||
|
load_state_dict(model, torch_load(model_bin)['state_dict'], verbose=False, strict=False, realize=False)
|
||||||
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
for k,v in get_state_dict(model).items():
|
for k,v in get_state_dict(model).items():
|
||||||
|
|||||||
Reference in New Issue
Block a user