stable diffusion profiling (#13441)

* stable diffusion profiling

Signed-off-by: George Hotz <geohot@gmail.com>

* profile_marker

* profile per step

* fix slow Context

* profile that

---------

Signed-off-by: George Hotz <geohot@gmail.com>
This commit is contained in:
George Hotz
2025-11-24 15:25:45 -08:00
committed by GitHub
parent 18cfb54736
commit cc5e6323ac
4 changed files with 25 additions and 14 deletions

View File

@@ -9,7 +9,7 @@ from typing import Dict, Any
from PIL import Image
import numpy as np
from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit
from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm, flatten
from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm, flatten, profile_marker
from tinygrad.nn import Conv2d, GroupNorm
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from extra.models.clip import Closed, Tokenizer, FrozenOpenClipEmbedder
@@ -266,9 +266,10 @@ if __name__ == "__main__":
parser.add_argument('--fakeweights', action='store_true', help="Skip loading checkpoints and use fake weights")
args = parser.parse_args()
profile_marker("create model")
model = StableDiffusion()
# load in weights
profile_marker("load in weights")
with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
if not args.fakeweights:
model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
@@ -281,12 +282,13 @@ if __name__ == "__main__":
Tensor.realize(*get_state_dict(model).values())
# run through CLIP to get context
profile_marker("run clip (conditional)")
tokenizer = Tokenizer.ClipTokenizer()
prompt = Tensor([tokenizer.encode(args.prompt)])
context = model.cond_stage_model.transformer.text_model(prompt).realize()
print("got CLIP context", context.shape)
profile_marker("run clip (unconditional)")
prompt = Tensor([tokenizer.encode("")])
unconditional_context = model.cond_stage_model.transformer.text_model(prompt).realize()
print("got unconditional CLIP context", unconditional_context.shape)
@@ -310,6 +312,7 @@ if __name__ == "__main__":
step_times = []
with Context(BEAM=getenv("LATEBEAM")):
for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
profile_marker(f"step {len(timesteps)-index-1}")
GlobalCounters.reset()
st = time.perf_counter_ns()
t.set_description("%3d %3d" % (index, timestep))
@@ -319,24 +322,26 @@ if __name__ == "__main__":
latent = run(model, unconditional_context, context, latent, Tensor([timestep]), alphas[tid], alphas_prev[tid], Tensor([args.guidance]))
if args.timing: Device[Device.DEFAULT].synchronize()
step_times.append((time.perf_counter_ns() - st)*1e-6)
# done with diffusion model
del run
del model.model
if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")):
min_time = min(step_times)
assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms"
# upsample latent space to image with autoencoder
x = model.decode(latent)
profile_marker("run decoder") # upsample latent space to image with autoencoder
x = model.decode(latent).realize()
print(x.shape)
# save image
profile_marker("save image")
im = Image.fromarray(x.numpy())
print(f"saving {args.out}")
im.save(args.out)
# Open image.
if not args.noshow: im.show()
# validation!
if args.prompt == default_prompt and args.steps == 6 and args.seed == 0 and args.guidance == 7.5:
profile_marker("validate")
ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "stable_diffusion_seed0.png")))
distance = (((x.cast(dtypes.float) - ref_image.cast(dtypes.float)) / ref_image.max())**2).mean().item()
assert distance < 3e-3, colored(f"validation failed with {distance=}", "red") # higher distance with WINO

View File

@@ -147,8 +147,10 @@ def temp(x:str, append_user:bool=False) -> str:
class Context(contextlib.ContextDecorator):
def __init__(self, **kwargs): self.kwargs = kwargs
def __enter__(self):
self.old_context:dict[str, int] = {k:v.value for k,v in ContextVar._cache.items()}
for k,v in self.kwargs.items(): ContextVar._cache[k].value = v
self.old_context:dict[str, int] = {}
for k,v in self.kwargs.items():
self.old_context[k] = ContextVar._cache[k].value
ContextVar._cache[k].value = v
def __exit__(self, *args):
for k,v in self.old_context.items(): ContextVar._cache[k].value = v
@@ -279,7 +281,7 @@ class ProfilePointEvent(ProfileEvent):
cpu_events:list[ProfileEvent] = []
@contextlib.contextmanager
def cpu_profile(name:str|TracingKey, device="CPU", is_copy=False, display=True) -> Generator[ProfileRangeEvent, None, None]:
def cpu_profile(name:str|TracingKey, device="TINY", is_copy=False, display=True) -> Generator[ProfileRangeEvent, None, None]:
res = ProfileRangeEvent(device, name, perf_counter_us(), is_copy=is_copy)
try: yield res
finally:
@@ -382,7 +384,11 @@ def fetch(url:str, name:pathlib.Path|str|None=None, subdir:str|None=None, gunzip
# *** Exec helpers
def system(cmd, **kwargs): return subprocess.check_output(cmd.split(), **kwargs).decode().strip()
def system(cmd:str, **kwargs) -> str:
st = time.perf_counter()
ret = subprocess.check_output(cmd.split(), **kwargs).decode().strip()
if DEBUG >= 1: print(f"system: '{cmd}' returned {len(ret)} bytes in {(time.perf_counter() - st)*1e3:.2f} ms")
return ret
def cpu_objdump(lib, objdump_tool='objdump'):
with tempfile.NamedTemporaryFile(delete=True) as f:

View File

@@ -260,8 +260,8 @@ class Tensor(OpMixin):
_apply_map_to_tensors(remove_assign_map, name="Remove After")
# create the schedule
schedule, var_vals = create_schedule_with_vars(sink)
schedule = memory_planner(schedule)
with cpu_profile(TracingKey("toposort schedule")): schedule, var_vals = create_schedule_with_vars(sink)
with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3: print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms")
return schedule, var_vals

View File

@@ -275,7 +275,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return tuple(1 if i in axis_arg else s for i,s in enumerate(ps))
# elementwise ops keep the shape the same. all inputs with shape must match
if self.op in (GroupOp.Elementwise-{Ops.BITCAST}).union({Ops.COPY, Ops.ASSIGN, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}):
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.ASSIGN, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}):
# TODO: remove this hack for 3 op assign
input_shapes = [x._shape for x in (self.src[:2] if self.op is Ops.ASSIGN else self.src) if x._shape is not None]
if len(input_shapes) == 0: return None