viz: sort tracks in timeline (#14591)

* viz: sort devices in timeline

* fix

* rev

* upd

* skip
This commit is contained in:
nimlgen
2026-02-12 10:51:41 +03:00
committed by GitHub
parent 025049c521
commit 14a1991da6
2 changed files with 55 additions and 12 deletions

View File

@@ -1,4 +1,4 @@
import unittest, decimal, json, struct
import unittest, decimal, json, struct, sys
from dataclasses import dataclass
from typing import Generator
@@ -6,7 +6,7 @@ from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, TrackedPatternMatch
from tinygrad.uop.symbolic import sym
from tinygrad.dtype import dtypes
from tinygrad.helpers import PROFILE, colored, ansistrip, flatten, TracingKey, ProfileRangeEvent, ProfileEvent, Context, cpu_events, profile_marker
from tinygrad.helpers import VIZ, cpu_profile
from tinygrad.helpers import VIZ, cpu_profile, ProfilePointEvent
from tinygrad.device import Buffer
@track_rewrites(name=True)
@@ -424,6 +424,51 @@ class TestVizProfiler(BaseTestViz):
self.assertEqual(graph_events[0]['st'], nv_events[0]['st'])
self.assertEqual(graph_events[0]['st']+graph_events[0]['dur'], sdma_events[0]['st']+sdma_events[0]['dur'])
def test_block_ordering(self):
  # every range event and the graph span the same [1000, 1010] window
  start, end = decimal.Decimal(1000), decimal.Decimal(1010)
  # (device, clock offset, name of the single range event recorded on it)
  tracks = [('NV', -1000, 'E_2'), ('NV:1', -500, 'E_3'), ('NV:SDMA:0', -100, 'COPY')]
  prof = [ProfileDeviceEvent(device=dev, tdiff=decimal.Decimal(off)) for dev, off, _ in tracks]
  prof += [ProfileRangeEvent(device=dev, name=ev, st=start, en=end) for dev, _, ev in tracks]
  # one graph with a single entry that runs on the first device
  entry = ProfileGraphEntry(device='NV', name='E_2', st_id=0, en_id=1)
  prof.append(ProfileGraphEvent(ents=[entry], deps=[[]], sigs=[start, end]))
  layout = list(load_profile(prof)['layout'])
  # the graph track sits right after its device, then that device's engine, then the next device
  self.assertListEqual(layout, ['NV', 'NV Graph', 'NV:SDMA:0', 'NV:1'])
@unittest.skipIf(sys.platform == 'win32', "TODO: ops_amd import fails on windows")
def test_multi_sdma_ordering(self):
  props = {"gfx_target_version": 0}
  D = decimal.Decimal
  St, En = D(1000), D(1010)
  # 2 AMD GPUs with 2 SDMA engines each: (track name, clock offset, name of its range event)
  tracks = [('AMD', -1000, 'E_1'), ('AMD:1', -900, 'E_2'),
            ('AMD:SDMA:0', -100, 'COPY0'), ('AMD:SDMA:1', -80, 'COPY1'),
            ('AMD:1:SDMA:0', -60, 'COPY2'), ('AMD:1:SDMA:1', -40, 'COPY3')]
  prof = [ProfileDeviceEvent(device=dev, tdiff=D(off), props=props) for dev, off, _ in tracks]
  # one compute or copy range per track, all over the same window
  prof += [ProfileRangeEvent(device=dev, name=ev, st=St, en=En) for dev, _, ev in tracks]
  # a graph covering compute and copy on GPU 0; the copy entry depends on the compute entry
  graph_ents = [ProfileGraphEntry(device='AMD', name='E_1', st_id=0, en_id=1),
                ProfileGraphEntry(device='AMD:SDMA:0', name='COPY0', st_id=2, en_id=3)]
  prof.append(ProfileGraphEvent(ents=graph_ents, deps=[[], [0]], sigs=[St, En, St, En]))
  # one allocation on each GPU
  for key, (dev, nbytes) in enumerate([('AMD', 1024), ('AMD:1', 512)]):
    prof.append(ProfilePointEvent(device=dev, name='alloc', key=key, arg={"sz":nbytes, "dtype":dtypes.float}, ts=St))
  layout = list(load_profile(prof)['layout'])
  # each GPU is followed by its graph and SDMA tracks; the memory tracks come last
  expected = ['AMD', 'AMD Graph', 'AMD:SDMA:0', 'AMD:SDMA:1',
              'AMD:1', 'AMD:1:SDMA:0', 'AMD:1:SDMA:1',
              'AMD Memory', 'AMD:1 Memory']
  self.assertListEqual(layout, expected)
def test_bytes_per_kernel(self):
step = 10
n_events = 1_000
@@ -463,11 +508,11 @@ class TestVizProfiler(BaseTestViz):
# Records a no-op profile under several device names and checks the order of the resulting layout.
def test_layout_order(self):
def fn(): return
# NOTE(review): the two loop headers below look like the before/after versions of one line in this diff view — confirm which is live
for dname in ["TINY", "USER", "TEST:1 N1", "TEST:2 N1", "TEST:1 N2", "TEST:1:ENGINE:0", "TEST:1"]:
for dname in ["TINY", "USER", "TEST:1 N1", "TEST:2 N1", "TEST:1 N2", "TEST:1:ENGINE:0", "TEST:1:ENGINE:0 N1", "TEST:1"]:
with cpu_profile("fn", dname): fn()
layout = list(load_profile(cpu_events)["layout"])
# USER and TINY are expected first, in that order, although they were recorded as TINY then USER
self.assertListEqual(layout[:2], ["USER","TINY"])
# NOTE(review): the two assertions below also look like a before/after pair; the second expects each track's suffixed variants right after it
self.assertListEqual(layout[2:], ["TEST:1", "TEST:1:ENGINE:0", "TEST:1 N1","TEST:1 N2", "TEST:2 N1"])
self.assertListEqual(layout[2:], ["TEST:1", "TEST:1 N1", "TEST:1 N2", "TEST:1:ENGINE:0", "TEST:1:ENGINE:0 N1", "TEST:2 N1"])
def _alloc(b:int):
a = Tensor.empty(b, device="NULL", dtype=dtypes.char)

View File

@@ -368,14 +368,12 @@ def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent) -> tuple[
events.append(ProfileRangeEvent(f"SIMD:{occ.simd}", f"OCC WAVE:{occ.wave_id} N:{next(units[u])}", Decimal(wave_start.pop(u)),Decimal(occ.time)))
return cu_events, list(units), wave_insts
def device_sort_fn(k:str) -> tuple[int, str, int]:
  """Build the (device class, device index, engine rank) sort key for track name `k`.

  Only the first whitespace-separated token of `k` is treated as the device name;
  any further tokens mark the track as a sub-track of that device.
  """
  class_rank = {"GC": 0, "USER": 1, "TINY": 2, "DISK": 999}
  dname, *suffix = k.split()
  # first prefix that matches wins; unknown devices rank right after the listed low ranks
  dev_rank = len(class_rank)
  for prefix, rank in class_rank.items():
    if dname.startswith(prefix):
      dev_rank = rank
      break
  fields = dname.split(":")
  # a device name without a numeric second field is treated as index "0"
  has_index = len(fields) >= 2 and fields[1].isdigit()
  if not has_index: fields.insert(1, "0")
  # plain device -> 0, engine of a device -> 1, track with extra tokens -> 2
  if suffix: eng_rank = 2
  elif len(fields) > 2: eng_rank = 1
  else: eng_rank = 0
  return (dev_rank, fields[1], eng_rank)
def device_sort_fn(k:str) -> tuple:
  """Return the sort key used to order track `k` in the timeline layout.

  The key compares, in order:
    1. whether the name ends with " Memory" (those tracks sort after all others),
    2. the rank of the device class (first ":"-separated field of the first word),
    3. the device class name,
    4. the device index as an integer (0 when the second field is not a number),
    5. the full track name, which orders a device's own tracks among themselves.

  The index is compared numerically so that e.g. "NV:10" sorts after "NV:2";
  comparing it as part of a string would place "NV:10" before "NV:2".
  """
  special = {"GC": 0, "USER": 1, "TINY": 2, "ALLDEVS":100, "DISK": 999}
  is_memory = k.endswith(" Memory")
  # only the first word names the device, anything after a space is a per-device suffix
  p = k.split(" ")[0].split(":")
  # "DEV:<n>..." carries an explicit index; "DEV" and "DEV:ENGINE:..." belong to device 0
  dev_idx = int(p[1]) if len(p) >= 2 and p[1].isdecimal() else 0
  return (is_memory, special.get(p[0], special['ALLDEVS']), p[0], dev_idx, k)
def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_sort_fn) -> bytes|None:
# start by getting the time diffs