add user device (#13447)

* add user device

* add device_sort_fn (#13448)

Co-authored-by: qazal <qazal.software@gmail.com>

* linter

* order by dname

---------

Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
George Hotz
2025-11-24 23:25:45 -08:00
committed by GitHub
parent 241e533451
commit 5373fd2d66
2 changed files with 9 additions and 3 deletions

View File

@@ -4203,7 +4203,8 @@ def _metadata_wrapper(fn: Callable[P, T]) -> Callable[P, T]:
else: caller = ""
token = _METADATA.set(Metadata(name=fn.__name__, caller=caller))
ret = fn(*args, **kwargs)
with cpu_profile(TracingKey(fn.__name__), "USER"):
ret = fn(*args, **kwargs)
_METADATA.set(token)
return ret
return _wrapper

View File

@@ -254,7 +254,12 @@ def load_sqtt(profile:list[ProfileEvent]) -> None:
for k in sorted(wave_insts, key=row_tuple): steps.append(create_step(k, ("/sqtt-insts", len(ctxs), len(steps)), wave_insts[k], depth=2))
ctxs.append({"name":"Counters", "steps":steps})
def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]|None=None) -> bytes|None:
def device_sort_fn(k):
order = {"USER": 0, "TINY": 1}
dname, *rest = k.split()
return order.get(dname, len(order)+len(rest))
def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_sort_fn) -> bytes|None:
# start by getting the time diffs
for ev in profile:
if isinstance(ev,ProfileDeviceEvent): device_ts_diffs[ev.device] = (ev.comp_tdiff, ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff)
@@ -280,7 +285,7 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]|None=No
scache:dict[str, int] = {}
peaks:list[int] = []
dtype_size:dict[str, int] = {}
for k in sorted(dev_events, key=sort_fn) if sort_fn else dev_events:
for k in sorted(dev_events, key=sort_fn):
(v:=dev_events[k]).sort(key=lambda e:e[0])
layout[k] = timeline_layout(v, start_ts, scache)
layout[f"{k} Memory"] = mem_layout(v, start_ts, unwrap(end_ts), peaks, dtype_size, scache)