mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
fix stable_diffusion with fp16 (#5239)
This commit is contained in:
3
.github/workflows/benchmark.yml
vendored
3
.github/workflows/benchmark.yml
vendored
@@ -39,6 +39,8 @@ jobs:
|
||||
# run: echo "RUN_PROCESS_REPLAY=1" >> $GITHUB_ENV
|
||||
- name: Run Stable Diffusion
|
||||
run: JIT=2 THREEFRY=1 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
|
||||
- name: Run Stable Diffusion with fp16
|
||||
run: JIT=2 THREEFRY=1 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd_fp16.txt
|
||||
- name: Run SDXL
|
||||
run: JIT=2 THREEFRY=1 python3 examples/sdxl.py --seed 0 --noshow | tee sdxl.txt
|
||||
- name: Run model inference benchmark
|
||||
@@ -105,6 +107,7 @@ jobs:
|
||||
matmul.txt
|
||||
matmul_half.txt
|
||||
sd.txt
|
||||
sd_fp16.txt
|
||||
sdxl.txt
|
||||
beautiful_mnist.txt
|
||||
train_cifar.txt
|
||||
|
||||
@@ -592,8 +592,9 @@ if __name__ == "__main__":
|
||||
load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)
|
||||
|
||||
if args.fp16:
|
||||
for l in get_state_dict(model).values():
|
||||
l.replace(l.cast(dtypes.float16).realize())
|
||||
for k,v in get_state_dict(model).items():
|
||||
if k.startswith("model"):
|
||||
v.replace(v.cast(dtypes.float16).realize())
|
||||
|
||||
# run through CLIP to get context
|
||||
tokenizer = ClipTokenizer()
|
||||
|
||||
Reference in New Issue
Block a user