diff --git a/test/amd/test_sqttmap.py b/test/amd/test_sqttmap.py index 4d0a5fa892..a4a46217c9 100644 --- a/test/amd/test_sqttmap.py +++ b/test/amd/test_sqttmap.py @@ -94,6 +94,8 @@ class TestSQTTMapBase(unittest.TestCase): # sopk/immediates don't get ALU/MEM EXEC if e.name.display_name not in {"IMMEDIATE", "IMMEDIATE_MASK", "JUMP", "JUMP_NO", "MESSAGE", "BARRIER", "BARRIER_SIGNAL", "WAVEEND", "WAVERDY"}: insts += 1 + # OTHER_ is its own stream, it's the INST from other SIMDs that share the same EXEC. + elif e.device.startswith("OTHER"): continue else: raise Exception(f"timeline row must be INST or EXEC, got {e.device}") self.assertEqual(execs, insts) diff --git a/tinygrad/renderer/amd/sqtt.py b/tinygrad/renderer/amd/sqtt.py index 6b54ae60ee..832ffe393b 100644 --- a/tinygrad/renderer/amd/sqtt.py +++ b/tinygrad/renderer/amd/sqtt.py @@ -657,8 +657,6 @@ def map_insts(data:bytes, lib:bytes, target:str) -> Iterator[tuple[PacketType, I elif isinstance(p, WAVEEND): pc = wave_pc.pop(p.wave) yield (p, InstructionInfo(pc, p.wave, s_endpgm())) - # skip OTHER_ instructions, they don't belong to this unit - elif isinstance(p, (INST, INST_RDNA4)) and p.op.name.startswith("OTHER_"): pass elif isinstance(p, IMMEDIATE_MASK): # immediate mask may yield multiple times per packet for wave in range(16): @@ -668,7 +666,8 @@ def map_insts(data:bytes, lib:bytes, target:str) -> Iterator[tuple[PacketType, I assert type(inst).__name__ == "SOPP", f"IMMEDIATE_MASK packet must map to SOPP, got {inst}" wave_pc[wave] += inst.size() yield (p, InstructionInfo(pc, wave, inst)) - elif isinstance(p, (VALUINST, INST, INST_RDNA4, IMMEDIATE)): + # map INST events on this SIMD to the program counter, we know the waves + elif isinstance(p, (VALUINST, INST, INST_RDNA4, IMMEDIATE)) and not (isinstance(p, (INST, INST_RDNA4)) and p.op.name.startswith("OTHER_")): inst = pc_map[pc:=wave_pc[p.wave]] # s_delay_alu, s_wait_alu and s_barrier_wait instructions are skipped while (inst_op:=getattr(inst, 'op_name', '')) in {"S_DELAY_ALU", "S_WAIT_ALU", "S_BARRIER_WAIT"}: @@ -684,7 +683,7 @@ def map_insts(data:bytes, lib:bytes, target:str) -> Iterator[tuple[PacketType, I else: wave_pc[p.wave] += inst.size() yield (p, InstructionInfo(pc, p.wave, inst)) - # for all other packets (VMEMEXEC, ALUEXEC, etc.), yield with None + # for all other packets (VMEMEXEC, ALUEXEC, OTHER_ INST, etc.), yield with None else: yield (p, None) # ═══════════════════════════════════════════════════════════════════════════════ diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 1c56642ec9..bc32718e27 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -349,21 +349,23 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> Generator[ProfileEvent, is_cdna = target.startswith("gfx9") dispatch_to_exec = {"WMMA":"VALU", "VALU":"VALU", "VALU1":"VALU", "VALUT":"VALU", "VALUB":"VALU", "VALUINST":"VALU", "VINTERP":"VALU", "SGMEM":"VMEM", "FLAT":"VMEM", "LDS":"LDS", "SALU":"SALU", "SMEM":"SALU", "VMEM":"VMEM"} - def add(name:str, p:PacketType, op:str|None=None, wave:int|None=None, info:InstructionInfo|None=None) -> Generator[ProfileEvent, None, None]: - row = f"WAVE:{wave}" if (wave:=getattr(p, "wave", wave)) is not None else f"{p.__class__.__name__}:0 {name}" + def add(name:str, p:PacketType, wave:int|None=None, info:InstructionInfo|None=None) -> Generator[ProfileEvent, None, None]: + row = "OTHER" if name.startswith("OTHER_") else f"WAVE:{wave}" if (wave:=getattr(p, "wave", wave)) is not None \ + else f"{p.__class__.__name__}:0 {name.replace('_ALT', '')}" # default length is 1 cycle duration = 1 # exec links to dispatch, dispatch links to PC link = f"PC:{info.pc}" if info else None - if isinstance(p, (ALUEXEC, VMEMEXEC)) and "ALT" not in str(p.src): + if isinstance(p, (ALUEXEC, VMEMEXEC)): link = f"LINK:{exec_pending[name].pop(0)}" # queue inst dispatches idx = next(row_counts.setdefault(row, itertools.count(0))) - if isinstance(p, (VALUINST, INST, INST_RDNA4)) and (exec_type:=dispatch_to_exec.get(name.split("_")[0])) is not None: + if isinstance(p, (VALUINST, INST, INST_RDNA4)) and (exec_type:=dispatch_to_exec.get(name.replace("OTHER_", "").split("_")[0])) is not None: + if name.startswith("OTHER_"): exec_type = f"{exec_type}_ALT" exec_pending.setdefault(exec_type, []).append(f"{row}-{idx}") # construct and yield the event for this packet if row not in row_ends: yield ProfilePointEvent(row, "JSON", "pcMap", pc_map, ts=Decimal(0)) - yield (e:=ProfileRangeEvent(row, TracingKey(op or name, ret=link), Decimal(p._time), Decimal(p._time+duration))) + yield (e:=ProfileRangeEvent(row, TracingKey(name, ret=link), Decimal(p._time), Decimal(p._time+duration))) # allow CDNA packets to overlap, NOT allowed on RDNA. if (et:=row_ends.get(row)) is not None and e.st < et and not is_cdna: raise RuntimeError(f"packet {p} overlaps another packet in {row}.") row_ends[row] = unwrap(e.en) @@ -397,7 +399,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> Generator[ProfileEvent, yield from add("VALU", p) yield from add("SALU", p) else: - yield from add(name.replace("_ALT", ""), p, op=name) + yield from add(name, p) # ** SQTT OCC only unpacks wave start, end time and SIMD location