From 5373fd2d665f71f1094baa77caeb25e7676f5698 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 24 Nov 2025 23:25:45 -0800 Subject: [PATCH] add user device (#13447) * add user device * add device_sort_fn (#13448) Co-authored-by: qazal * linter * order by dname --------- Co-authored-by: qazal --- tinygrad/tensor.py | 3 ++- tinygrad/viz/serve.py | 9 +++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index d0ffe374af..7bf7637e34 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -4203,7 +4203,8 @@ def _metadata_wrapper(fn: Callable[P, T]) -> Callable[P, T]: else: caller = "" token = _METADATA.set(Metadata(name=fn.__name__, caller=caller)) - ret = fn(*args, **kwargs) + with cpu_profile(TracingKey(fn.__name__), "USER"): + ret = fn(*args, **kwargs) _METADATA.set(token) return ret return _wrapper diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index aeef076aed..4f33502921 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -254,7 +254,12 @@ def load_sqtt(profile:list[ProfileEvent]) -> None: for k in sorted(wave_insts, key=row_tuple): steps.append(create_step(k, ("/sqtt-insts", len(ctxs), len(steps)), wave_insts[k], depth=2)) ctxs.append({"name":"Counters", "steps":steps}) -def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]|None=None) -> bytes|None: +def device_sort_fn(k): + order = {"USER": 0, "TINY": 1} + dname, *rest = k.split() + return order.get(dname, len(order)+len(rest)) + +def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_sort_fn) -> bytes|None: # start by getting the time diffs for ev in profile: if isinstance(ev,ProfileDeviceEvent): device_ts_diffs[ev.device] = (ev.comp_tdiff, ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff) @@ -280,7 +285,7 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]|None=No scache:dict[str, int] = {} peaks:list[int] = [] dtype_size:dict[str, int] = {} - for k in sorted(dev_events, key=sort_fn) if sort_fn else dev_events: + for k in sorted(dev_events, key=sort_fn): (v:=dev_events[k]).sort(key=lambda e:e[0]) layout[k] = timeline_layout(v, start_ts, scache) layout[f"{k} Memory"] = mem_layout(v, start_ts, unwrap(end_ts), peaks, dtype_size, scache)