mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
add user device (#13447)
* add user device * add device_sort_fn (#13448) Co-authored-by: qazal <qazal.software@gmail.com> * linter * order by dname --------- Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
@@ -4203,7 +4203,8 @@ def _metadata_wrapper(fn: Callable[P, T]) -> Callable[P, T]:
|
||||
else: caller = ""
|
||||
|
||||
token = _METADATA.set(Metadata(name=fn.__name__, caller=caller))
|
||||
ret = fn(*args, **kwargs)
|
||||
with cpu_profile(TracingKey(fn.__name__), "USER"):
|
||||
ret = fn(*args, **kwargs)
|
||||
_METADATA.set(token)
|
||||
return ret
|
||||
return _wrapper
|
||||
|
||||
@@ -254,7 +254,12 @@ def load_sqtt(profile:list[ProfileEvent]) -> None:
|
||||
for k in sorted(wave_insts, key=row_tuple): steps.append(create_step(k, ("/sqtt-insts", len(ctxs), len(steps)), wave_insts[k], depth=2))
|
||||
ctxs.append({"name":"Counters", "steps":steps})
|
||||
|
||||
def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]|None=None) -> bytes|None:
|
||||
def device_sort_fn(k):
|
||||
order = {"USER": 0, "TINY": 1}
|
||||
dname, *rest = k.split()
|
||||
return order.get(dname, len(order)+len(rest))
|
||||
|
||||
def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_sort_fn) -> bytes|None:
|
||||
# start by getting the time diffs
|
||||
for ev in profile:
|
||||
if isinstance(ev,ProfileDeviceEvent): device_ts_diffs[ev.device] = (ev.comp_tdiff, ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff)
|
||||
@@ -280,7 +285,7 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]|None=No
|
||||
scache:dict[str, int] = {}
|
||||
peaks:list[int] = []
|
||||
dtype_size:dict[str, int] = {}
|
||||
for k in sorted(dev_events, key=sort_fn) if sort_fn else dev_events:
|
||||
for k in sorted(dev_events, key=sort_fn):
|
||||
(v:=dev_events[k]).sort(key=lambda e:e[0])
|
||||
layout[k] = timeline_layout(v, start_ts, scache)
|
||||
layout[f"{k} Memory"] = mem_layout(v, start_ts, unwrap(end_ts), peaks, dtype_size, scache)
|
||||
|
||||
Reference in New Issue
Block a user