mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
stable diffusion profiling (#13441)
* stable diffusion profiling
Signed-off-by: George Hotz <geohot@gmail.com>
* profile_marker
* profile per step
* fix slow Context
* profile that
---------
Signed-off-by: George Hotz <geohot@gmail.com>
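The change instruments examples/stable_diffusion.py with profile_marker calls so each phase of the pipeline (model creation, weight loading, CLIP, each denoising step, decode, save, validation) shows up as a named region in the profile. A minimal sketch of the pattern, assuming (as the import and call sites below suggest) that profile_marker takes a single string label; the generate and run_step names here are hypothetical stand-ins, not part of the script:

# sketch only: bracket each pipeline phase with a named profile marker
from tinygrad.helpers import profile_marker

def generate(model, timesteps, run_step):
  profile_marker("create model")       # phase: model construction
  # ... build the model ...
  profile_marker("load in weights")    # phase: checkpoint loading
  # ... load and realize the weights ...
  for i, timestep in enumerate(reversed(timesteps)):
    profile_marker(f"step {i}")        # one marker per denoising step
    run_step(model, timestep)
  profile_marker("run decoder")        # phase: decode latents to an image
  # ... model.decode(...) and save ...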
@@ -9,7 +9,7 @@ from typing import Dict, Any
from PIL import Image
import numpy as np
from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit
from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm, flatten
from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm, flatten, profile_marker
from tinygrad.nn import Conv2d, GroupNorm
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from extra.models.clip import Closed, Tokenizer, FrozenOpenClipEmbedder
@@ -266,9 +266,10 @@ if __name__ == "__main__":
parser.add_argument('--fakeweights', action='store_true', help="Skip loading checkpoints and use fake weights")
args = parser.parse_args()

profile_marker("create model")
model = StableDiffusion()

# load in weights
profile_marker("load in weights")
with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
if not args.fakeweights:
model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
@@ -281,12 +282,13 @@ if __name__ == "__main__":

Tensor.realize(*get_state_dict(model).values())

# run through CLIP to get context
profile_marker("run clip (conditional)")
tokenizer = Tokenizer.ClipTokenizer()
prompt = Tensor([tokenizer.encode(args.prompt)])
context = model.cond_stage_model.transformer.text_model(prompt).realize()
print("got CLIP context", context.shape)

profile_marker("run clip (unconditional)")
prompt = Tensor([tokenizer.encode("")])
unconditional_context = model.cond_stage_model.transformer.text_model(prompt).realize()
print("got unconditional CLIP context", unconditional_context.shape)
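For context: the conditional and unconditional CLIP contexts computed above are what the per-step run call later combines under args.guidance. The usual classifier-free guidance blend looks like the sketch below; this is the standard formula, not code taken from this diff, and the function name is hypothetical:

from tinygrad import Tensor

def guided_pred(uncond: Tensor, cond: Tensor, guidance: float) -> Tensor:
  # classifier-free guidance: push the conditional prediction further away
  # from the unconditional one by the guidance scale
  return uncond + guidance * (cond - uncond)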
@@ -310,6 +312,7 @@ if __name__ == "__main__":
step_times = []
with Context(BEAM=getenv("LATEBEAM")):
for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
profile_marker(f"step {len(timesteps)-index-1}")
GlobalCounters.reset()
st = time.perf_counter_ns()
t.set_description("%3d %3d" % (index, timestep))
@@ -319,24 +322,26 @@ if __name__ == "__main__":
latent = run(model, unconditional_context, context, latent, Tensor([timestep]), alphas[tid], alphas_prev[tid], Tensor([args.guidance]))
if args.timing: Device[Device.DEFAULT].synchronize()
step_times.append((time.perf_counter_ns() - st)*1e-6)
# done with diffusion model
del run
del model.model

if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")):
min_time = min(step_times)
assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms"
# upsample latent space to image with autoencoder
x = model.decode(latent)
profile_marker("run decoder") # upsample latent space to image with autoencoder
x = model.decode(latent).realize()
print(x.shape)

# save image
profile_marker("save image")
im = Image.fromarray(x.numpy())
print(f"saving {args.out}")
im.save(args.out)
# Open image.
if not args.noshow: im.show()

# validation!
if args.prompt == default_prompt and args.steps == 6 and args.seed == 0 and args.guidance == 7.5:
profile_marker("validate")
ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "stable_diffusion_seed0.png")))
distance = (((x.cast(dtypes.float) - ref_image.cast(dtypes.float)) / ref_image.max())**2).mean().item()
assert distance < 3e-3, colored(f"validation failed with {distance=}", "red") # higher distance with WINO
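For reference, the per-step timing and the ASSERT_MIN_STEP_TIME regression gate used above, pulled out as a standalone sketch; do_step is a hypothetical stand-in for the denoising call, while step_times and the assertion mirror the diff:

import time
from tinygrad.helpers import getenv

def time_steps(do_step, timesteps):
  step_times = []
  for t in timesteps:
    st = time.perf_counter_ns()
    do_step(t)
    step_times.append((time.perf_counter_ns() - st) * 1e-6)  # ns -> ms
  # fail if even the fastest step exceeds the budget (in ms)
  if (assert_time := getenv("ASSERT_MIN_STEP_TIME")):
    min_time = min(step_times)
    assert min_time < assert_time, f"expected min step time < {assert_time} ms, got {min_time} ms"
  return step_times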