diff --git a/test/amd/test_sqttmap.py b/test/amd/test_sqttmap.py index eec71b8329..909159ca4c 100644 --- a/test/amd/test_sqttmap.py +++ b/test/amd/test_sqttmap.py @@ -78,8 +78,9 @@ class TestSQTTMapBase(unittest.TestCase): for event in events: if (p:=kern_events.get(event.kern)) is None: continue with self.subTest(example=name, kern=event.kern): + # skip if there's no SQTT frequency data if not (timeline:=sqtt_timeline(event.blob, p.lib, target)): continue - frequency = [e.key for e in timeline if type(e).__name__ == "ProfilePointEvent" and e.name == "freq_hz"] + if not (frequency:=[e.key for e in timeline if type(e).__name__ == "ProfilePointEvent" and e.name == "freq_hz"]): continue mean = sum(frequency) / len(frequency) variance = sum((v - mean) ** 2 for v in frequency) / len(frequency) self.assertGreater(mean, 0) @@ -110,5 +111,9 @@ class TestSQTTMapRDNA3(TestSQTTMapBase): target = "gfx1100" class TestSQTTMapRDNA4(TestSQTTMapBase): target = "gfx1200" +class TestSQTTMapCDNA(TestSQTTMapBase): + target = "gfx950" + def test_rocprof_inst_traces_match(self): self.skipTest("requires timestamp patching to match rocprof, currently it's off by a few cycles") + if __name__ == "__main__": unittest.main() diff --git a/test/backend/test_asm_gemm.py b/test/backend/test_asm_gemm.py index f3b2066574..d013ea9804 100644 --- a/test/backend/test_asm_gemm.py +++ b/test/backend/test_asm_gemm.py @@ -87,6 +87,9 @@ class TestGemmLarge(unittest.TestCase): if not is_cdna4(): self.skipTest("very slow on non mi350x") + @Context(ASM_GEMM=1) + def test_empty(self): (Tensor.empty(N:=getenv("N", 4096), N, dtype=dtypes.half)@Tensor.empty(N, N, dtype=dtypes.half)).realize() + def test_tiny(self): verify_asm_gemm(1, 256, 256, 64) def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 4096), N, N, dtype=dtypes.half) def test_gemm(self): verify_asm_gemm(1, 8192, 4096, 14336) diff --git a/tinygrad/renderer/amd/sqtt.py b/tinygrad/renderer/amd/sqtt.py index 43a1ce62ac..67d9b1b715 100644 --- a/tinygrad/renderer/amd/sqtt.py +++ b/tinygrad/renderer/amd/sqtt.py @@ -139,6 +139,24 @@ class InstOpRDNA4(Enum): OTHER_VMEM = 0xbd OTHER_VMEM_5 = 0xc1 +class InstOpCDNA(Enum): + SMEM_RD = 0 + SALU_32 = 1 + VMEM_RD = 2 + VMEM_WR = 3 + FLAT_WR = 4 + VALU_32 = 5 + LDS = 6 + PC = 7 + JUMP = 12 + NEXT = 13 + FLAT_RD = 14 + OTHER_MSG = 15 + SMEM_WR = 16 + SALU_64 = 17 + VALU_64 = 18 + VALU_MAI = 28 + # ═══════════════════════════════════════════════════════════════════════════════ # PACKET TYPE BASE CLASS # ═══════════════════════════════════════════════════════════════════════════════ @@ -448,7 +466,7 @@ class CDNA_INST(PacketType): encoding = bits[3:0] == 10 wave = bits[8:5] simd = bits[10:9] - inst_type = bits[15:11] + op = bits[15:11].enum(InstOpCDNA) class CDNA_INST_PC(PacketType): """pkt_fmt=11: 64-bit (MsgInstPc)""" @@ -612,7 +630,7 @@ def map_insts(data:bytes, lib:bytes, target:str) -> Iterator[tuple[PacketType, I def simd_select(p) -> bool: return getattr(p, "cu", 0) == 0 and getattr(p, "simd", 0) == 0 for p in decode(data): if not simd_select(p): continue - if isinstance(p, (WAVESTART, WAVESTART_RDNA4)): + if isinstance(p, (WAVESTART, WAVESTART_RDNA4, CDNA_WAVESTART)): assert p.wave not in wave_pc, "only one inflight wave per unit" wave_pc[p.wave] = next(iter(pc_map)) elif isinstance(p, WAVEEND): diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 4797e61958..97ef8849dc 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -252,7 +252,7 @@ function selectShape(key) { const timelineScale = () => d3.scaleLinear().domain([data.first, data.dur]).range([0, document.getElementById("timeline").clientWidth]) function timeAtCycle(clk) { - if (clk < data.instSt || clk > data.instEt) return "-"; + if (clk < data.instSt || clk > data.instEt || data.tracks.get("Shader Clock") == null) return "-"; let cur = data.instSt, ns = 0, freq = null; // walk through all frequency changes and accumulate time in nanoseconds for (const [s, v] of data.tracks.get("Shader Clock").valueMap) { diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 1b7f9f246f..0020bf4588 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -339,18 +339,20 @@ def load_amd_counters(ctxs:list[dict], profile:list[ProfileEvent]) -> None: def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]: from tinygrad.renderer.amd.sqtt import map_insts, InstructionInfo, PacketType, INST, InstOp, VALUINST, IMMEDIATE, IMMEDIATE_MASK, VMEMEXEC, ALUEXEC - from tinygrad.renderer.amd.sqtt import INST_RDNA4, InstOpRDNA4, TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4 + from tinygrad.renderer.amd.sqtt import INST_RDNA4, InstOpRDNA4, TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4, CDNA_INST, InstOpCDNA ret:list[ProfileEvent] = [] row_ends:dict[str, Decimal] = {} curr_barrier:dict[str, ProfileRangeEvent] = {} NS_PER_TICK = 10 # 100MHz prev_pair:tuple[int, int]|None = None # (shader, realtime) + is_cdna = target.startswith("gfx9") def add(name:str, p:PacketType, op:str|None=None, wave:int|None=None, info:InstructionInfo|None=None) -> None: row = f"WAVE:{wave}" if (wave:=getattr(p, "wave", wave)) is not None else f"{p.__class__.__name__}:0 {name}" # barrier on this row extends to fill the time our wave was waiting if (barrier:=curr_barrier.pop(row, None)) is not None: barrier.en = Decimal(p._time) ret.append(e:=ProfileRangeEvent(row, TracingKey(op or name, ret=f"PC:{info.pc}" if info else None), Decimal(p._time), Decimal(p._time+1))) - if (et:=row_ends.get(row)) is not None and e.st < et: raise RuntimeError(f"packet {p} overlaps another packet in {row}.") + # allow CDNA packets to overlap, NOT allowed on RDNA. + if (et:=row_ends.get(row)) is not None and e.st < et and not is_cdna: raise RuntimeError(f"packet {p} overlaps another packet in {row}.") row_ends[row] = unwrap(e.en) if name == "BARRIER": curr_barrier[row] = e for p, info in map_insts(data, lib, target): @@ -363,8 +365,8 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]: freq_hz = (s1 - s0) * 1_000_000_000 // ((r1 - r0) * NS_PER_TICK) ret.append(ProfilePointEvent("LINE:Shader Clock", "freq_hz", freq_hz, ts=Decimal(p._time))) prev_pair = pair - if isinstance(p, (INST, INST_RDNA4)): - name = p.op.name if isinstance(p.op, (InstOp, InstOpRDNA4)) else f"0x{p.op:02x}" + if isinstance(p, (INST, INST_RDNA4, CDNA_INST)): + name = p.op.name if isinstance(p.op, (InstOp, InstOpRDNA4, InstOpCDNA)) else f"0x{p.op:02x}" add(name, p, info=info) if isinstance(p, (VALUINST, IMMEDIATE)): add(p.__class__.__name__, p, info=info) if isinstance(p, IMMEDIATE_MASK): add("IMMEDIATE", p, wave=unwrap(info).wave, info=info)