From ec5b7a249e9409809e7076da5285f711027d31aa Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 26 Mar 2026 14:16:15 +0200 Subject: [PATCH] viz: refactor sqtt timeline builder (#15494) * viz: refactor sqtt timeline builder * barrier maps to waves * clean up cli --- extra/viz/cli.py | 2 +- tinygrad/viz/serve.py | 35 ++++++++++++++++++++--------------- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/extra/viz/cli.py b/extra/viz/cli.py index ad768c3a4b..980987b120 100755 --- a/extra/viz/cli.py +++ b/extra/viz/cli.py @@ -90,7 +90,7 @@ def main(): pkt_idxs:dict[str, itertools.count] = {} dispatch_to_pc:dict[str, int] = {} for e in viz.sqtt_timeline(*sqtt_data): - if isinstance(e, ProfilePointEvent) and e.key == 'pcMap': pc_map = e.arg; continue + if isinstance(e, ProfilePointEvent) and e.key == 'pcMap': pc_map = e.arg if not isinstance(e, ProfileRangeEvent): continue op_name, info = e.name.display_name, e.name.ret or "" color = next((c for p, c in WAVE_COLORS if any(x in op_name for x in p)), None) diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 8af1cfc3e6..1c56642ec9 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -344,29 +344,35 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> Generator[ProfileEvent, pc_map = {addr:str(inst) for addr,inst in amd_decode(lib, target).items()} row_ends:dict[str, Decimal] = {} row_counts:dict[str, itertools.count] = {} - curr_barrier:dict[str, ProfileRangeEvent] = {} + curr_barrier:dict[int, ProfileRangeEvent] = {} exec_pending:dict[str, list[str]] = {} - NS_PER_TICK = 10 # 100MHz - prev_pair:tuple[int, int]|None = None # (shader, realtime) is_cdna = target.startswith("gfx9") dispatch_to_exec = {"WMMA":"VALU", "VALU":"VALU", "VALU1":"VALU", "VALUT":"VALU", "VALUB":"VALU", "VALUINST":"VALU", "VINTERP":"VALU", "SGMEM":"VMEM", "FLAT":"VMEM", "LDS":"LDS", "SALU":"SALU", "SMEM":"SALU", "VMEM":"VMEM"} def add(name:str, p:PacketType, op:str|None=None, wave:int|None=None, info:InstructionInfo|None=None) -> Generator[ProfileEvent, None, None]: row = f"WAVE:{wave}" if (wave:=getattr(p, "wave", wave)) is not None else f"{p.__class__.__name__}:0 {name}" + # default length is 1 cycle + duration = 1 + # exec links to dispatch, dispatch links to PC + link = f"PC:{info.pc}" if info else None + if isinstance(p, (ALUEXEC, VMEMEXEC)) and "ALT" not in str(p.src): + link = f"LINK:{exec_pending[name].pop(0)}" + # queue inst dispatches + idx = next(row_counts.setdefault(row, itertools.count(0))) + if isinstance(p, (VALUINST, INST, INST_RDNA4)) and (exec_type:=dispatch_to_exec.get(name.split("_")[0])) is not None: + exec_pending.setdefault(exec_type, []).append(f"{row}-{idx}") + # construct and yield the event for this packet if row not in row_ends: yield ProfilePointEvent(row, "JSON", "pcMap", pc_map, ts=Decimal(0)) - # barrier on this row extends to fill the time our wave was waiting - if (barrier:=curr_barrier.pop(row, None)) is not None: barrier.en = Decimal(p._time) - e = ProfileRangeEvent(row, TracingKey(op or name, ret=f"PC:{info.pc}" if info else None), Decimal(p._time), Decimal(p._time+1)) + yield (e:=ProfileRangeEvent(row, TracingKey(op or name, ret=link), Decimal(p._time), Decimal(p._time+duration))) # allow CDNA packets to overlap, NOT allowed on RDNA. if (et:=row_ends.get(row)) is not None and e.st < et and not is_cdna: raise RuntimeError(f"packet {p} overlaps another packet in {row}.") row_ends[row] = unwrap(e.en) - idx = next(row_counts.setdefault(row, itertools.count(0))) - if name == "BARRIER": curr_barrier[row] = e - # queue for exec linking - if isinstance(p, (VALUINST, INST, INST_RDNA4)) and (exec_type:=dispatch_to_exec.get(name.split("_")[0])) is not None: - exec_pending.setdefault(exec_type, []).append(f"{row}-{idx}") - if isinstance(p, (ALUEXEC, VMEMEXEC)) and "ALT" not in str(p.src): e.name = TracingKey(op or name, ret=f"LINK:{exec_pending[name].pop(0)}") - yield e + # barrier on this wave extends to fill the time it was waiting + if wave is not None: + if (barrier:=curr_barrier.pop(wave, None)) is not None: barrier.en = Decimal(p._time) + if name == "BARRIER": curr_barrier[wave] = e + NS_PER_TICK = 10 # 100MHz + prev_pair:tuple[int, int]|None = None # (shader, realtime) for p, info in map_insts(data, lib, target): if isinstance(p, (TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4)) and p.is_marker: pair = (p._time, p.delta) @@ -384,8 +390,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> Generator[ProfileEvent, if isinstance(p, WAVERDY): for wave in range(16): if p.mask & (1 << wave): - row = f"WAVE:{wave}" - if row in curr_barrier: yield from add("WAVERDY", p, wave=wave) + if wave in curr_barrier: yield from add("WAVERDY", p, wave=wave) if isinstance(p, (VMEMEXEC, ALUEXEC)): name = str(p.src).split('.')[1] if name == "VALU_SALU":