diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index 4650b7e1d9..73a40c5604 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -9,7 +9,7 @@ from typing import Dict, Any from PIL import Image import numpy as np from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit -from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm, flatten +from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm, flatten, profile_marker from tinygrad.nn import Conv2d, GroupNorm from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict from extra.models.clip import Closed, Tokenizer, FrozenOpenClipEmbedder @@ -266,9 +266,10 @@ if __name__ == "__main__": parser.add_argument('--fakeweights', action='store_true', help="Skip loading checkpoints and use fake weights") args = parser.parse_args() + profile_marker("create model") model = StableDiffusion() - # load in weights + profile_marker("load in weights") with WallTimeEvent(BenchEvent.LOAD_WEIGHTS): if not args.fakeweights: model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt') @@ -281,12 +282,13 @@ if __name__ == "__main__": Tensor.realize(*get_state_dict(model).values()) - # run through CLIP to get context + profile_marker("run clip (conditional)") tokenizer = Tokenizer.ClipTokenizer() prompt = Tensor([tokenizer.encode(args.prompt)]) context = model.cond_stage_model.transformer.text_model(prompt).realize() print("got CLIP context", context.shape) + profile_marker("run clip (unconditional)") prompt = Tensor([tokenizer.encode("")]) unconditional_context = model.cond_stage_model.transformer.text_model(prompt).realize() print("got unconditional CLIP context", unconditional_context.shape) @@ -310,6 +312,7 @@ if __name__ == "__main__": step_times = [] with Context(BEAM=getenv("LATEBEAM")): for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])): + profile_marker(f"step {len(timesteps)-index-1}") GlobalCounters.reset() st = time.perf_counter_ns() t.set_description("%3d %3d" % (index, timestep)) @@ -319,24 +322,26 @@ if __name__ == "__main__": latent = run(model, unconditional_context, context, latent, Tensor([timestep]), alphas[tid], alphas_prev[tid], Tensor([args.guidance])) if args.timing: Device[Device.DEFAULT].synchronize() step_times.append((time.perf_counter_ns() - st)*1e-6) + # done with diffusion model del run + del model.model if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")): min_time = min(step_times) assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms" - # upsample latent space to image with autoencoder - x = model.decode(latent) + profile_marker("run decoder") # upsample latent space to image with autoencoder + x = model.decode(latent).realize() print(x.shape) - # save image + profile_marker("save image") im = Image.fromarray(x.numpy()) print(f"saving {args.out}") im.save(args.out) # Open image. if not args.noshow: im.show() - # validation! if args.prompt == default_prompt and args.steps == 6 and args.seed == 0 and args.guidance == 7.5: + profile_marker("validate") ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "stable_diffusion_seed0.png"))) distance = (((x.cast(dtypes.float) - ref_image.cast(dtypes.float)) / ref_image.max())**2).mean().item() assert distance < 3e-3, colored(f"validation failed with {distance=}", "red") # higher distance with WINO diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 87c67dd47b..0730395c7c 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -147,8 +147,10 @@ def temp(x:str, append_user:bool=False) -> str: class Context(contextlib.ContextDecorator): def __init__(self, **kwargs): self.kwargs = kwargs def __enter__(self): - self.old_context:dict[str, int] = {k:v.value for k,v in ContextVar._cache.items()} - for k,v in self.kwargs.items(): ContextVar._cache[k].value = v + self.old_context:dict[str, int] = {} + for k,v in self.kwargs.items(): + self.old_context[k] = ContextVar._cache[k].value + ContextVar._cache[k].value = v def __exit__(self, *args): for k,v in self.old_context.items(): ContextVar._cache[k].value = v @@ -279,7 +281,7 @@ class ProfilePointEvent(ProfileEvent): cpu_events:list[ProfileEvent] = [] @contextlib.contextmanager -def cpu_profile(name:str|TracingKey, device="CPU", is_copy=False, display=True) -> Generator[ProfileRangeEvent, None, None]: +def cpu_profile(name:str|TracingKey, device="TINY", is_copy=False, display=True) -> Generator[ProfileRangeEvent, None, None]: res = ProfileRangeEvent(device, name, perf_counter_us(), is_copy=is_copy) try: yield res finally: @@ -382,7 +384,11 @@ def fetch(url:str, name:pathlib.Path|str|None=None, subdir:str|None=None, gunzip # *** Exec helpers -def system(cmd, **kwargs): return subprocess.check_output(cmd.split(), **kwargs).decode().strip() +def system(cmd:str, **kwargs) -> str: + st = time.perf_counter() + ret = subprocess.check_output(cmd.split(), **kwargs).decode().strip() + if DEBUG >= 1: print(f"system: '{cmd}' returned {len(ret)} bytes in {(time.perf_counter() - st)*1e3:.2f} ms") + return ret def cpu_objdump(lib, objdump_tool='objdump'): with tempfile.NamedTemporaryFile(delete=True) as f: diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 1c1b74e9ee..9d355b38b5 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -260,8 +260,8 @@ class Tensor(OpMixin): _apply_map_to_tensors(remove_assign_map, name="Remove After") # create the schedule - schedule, var_vals = create_schedule_with_vars(sink) - schedule = memory_planner(schedule) + with cpu_profile(TracingKey("toposort schedule")): schedule, var_vals = create_schedule_with_vars(sink) + with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule) if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3: print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms") return schedule, var_vals diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 2f47aff20b..2c684b9b57 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -275,7 +275,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): return tuple(1 if i in axis_arg else s for i,s in enumerate(ps)) # elementwise ops keep the shape the same. all inputs with shape must match - if self.op in (GroupOp.Elementwise-{Ops.BITCAST}).union({Ops.COPY, Ops.ASSIGN, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}): + if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.ASSIGN, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}): # TODO: remove this hack for 3 op assign input_shapes = [x._shape for x in (self.src[:2] if self.op is Ops.ASSIGN else self.src) if x._shape is not None] if len(input_shapes) == 0: return None