viz: better ordering of device engines in profiler (#14590)

This commit is contained in:
qazal
2026-02-06 09:08:09 -05:00
committed by GitHub
parent b7e3fbe07e
commit a80fb4e641
2 changed files with 7 additions and 4 deletions

View File

@@ -463,11 +463,11 @@ class TestVizProfiler(BaseTestViz):
def test_layout_order(self):
def fn(): return
for dname in ["TINY", "USER", "TEST:1 N1", "TEST:2 N1", "TEST:1 N2"]:
for dname in ["TINY", "USER", "TEST:1 N1", "TEST:2 N1", "TEST:1 N2", "TEST:1:ENGINE:0", "TEST:1"]:
with cpu_profile("fn", dname): fn()
layout = list(load_profile(cpu_events)["layout"])
self.assertListEqual(layout[:2], ["USER","TINY"])
self.assertListEqual(layout[2:], ["TEST:1 N1","TEST:1 N2", "TEST:2 N1"])
self.assertListEqual(layout[2:], ["TEST:1", "TEST:1:ENGINE:0", "TEST:1 N1","TEST:1 N2", "TEST:2 N1"])
def _alloc(b:int):
a = Tensor.empty(b, device="NULL", dtype=dtypes.char)

View File

@@ -375,9 +375,12 @@ def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent) -> tuple[
def device_sort_fn(k:str) -> tuple[int, str, int]:
order = {"GC": 0, "USER": 1, "TINY": 2, "DISK": 999}
dname = k.split()[0]
dname, *rest = k.split()
dev_rank = next((v for k,v in order.items() if dname.startswith(k)), len(order))
return (dev_rank, dname, len(k))
if len(parts:=dname.split(":")) < 2 or not parts[1].isdigit(): parts.insert(1, "0")
eng_rank = 2 if rest else 1 if len(parts) > 2 else 0
# 3 levels of hierarchy: device class, index in multi device, engine within device
return (dev_rank, parts[1], eng_rank)
def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_sort_fn) -> bytes|None:
# start by getting the time diffs