sqtt: add CDNA ops enum, show in viz (#15140)

This commit is contained in:
qazal
2026-03-17 02:38:42 +02:00
committed by GitHub
parent 3e2b7803e6
commit 33bd33e783
5 changed files with 36 additions and 8 deletions

View File

@@ -78,8 +78,9 @@ class TestSQTTMapBase(unittest.TestCase):
for event in events:
if (p:=kern_events.get(event.kern)) is None: continue
with self.subTest(example=name, kern=event.kern):
# skip if there's no SQTT frequency data
if not (timeline:=sqtt_timeline(event.blob, p.lib, target)): continue
frequency = [e.key for e in timeline if type(e).__name__ == "ProfilePointEvent" and e.name == "freq_hz"]
if not (frequency:=[e.key for e in timeline if type(e).__name__ == "ProfilePointEvent" and e.name == "freq_hz"]): continue
mean = sum(frequency) / len(frequency)
variance = sum((v - mean) ** 2 for v in frequency) / len(frequency)
self.assertGreater(mean, 0)
@@ -110,5 +111,9 @@ class TestSQTTMapRDNA3(TestSQTTMapBase): target = "gfx1100"
class TestSQTTMapRDNA4(TestSQTTMapBase): target = "gfx1200"
class TestSQTTMapCDNA(TestSQTTMapBase):
  """Runs the shared SQTT mapping tests with the gfx950 target."""
  target = "gfx950"

  def test_rocprof_inst_traces_match(self):
    # NOTE(review): presumably replaces a base-class test of the same name -- confirm against TestSQTTMapBase
    reason = "requires timestamp patching to match rocprof, currently it's off by a few cycles"
    self.skipTest(reason)
if __name__ == "__main__":
unittest.main()

View File

@@ -87,6 +87,9 @@ class TestGemmLarge(unittest.TestCase):
if not is_cdna4():
self.skipTest("very slow on non mi350x")
@Context(ASM_GEMM=1)
def test_empty(self): (Tensor.empty(N:=getenv("N", 4096), N, dtype=dtypes.half)@Tensor.empty(N, N, dtype=dtypes.half)).realize()
def test_tiny(self): verify_asm_gemm(1, 256, 256, 64)
def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 4096), N, N, dtype=dtypes.half)
def test_gemm(self): verify_asm_gemm(1, 8192, 4096, 14336)

View File

@@ -139,6 +139,24 @@ class InstOpRDNA4(Enum):
OTHER_VMEM = 0xbd
OTHER_VMEM_5 = 0xc1
class InstOpCDNA(Enum):
  """Op codes for the `op` field (bits[15:11]) of a CDNA_INST packet.

  The numbering is sparse: only the values listed below have a name.
  """
  SMEM_RD = 0x00
  SALU_32 = 0x01
  VMEM_RD = 0x02
  VMEM_WR = 0x03
  FLAT_WR = 0x04
  VALU_32 = 0x05
  LDS = 0x06
  PC = 0x07
  # 0x08-0x0b have no named member
  JUMP = 0x0c
  NEXT = 0x0d
  FLAT_RD = 0x0e
  OTHER_MSG = 0x0f
  SMEM_WR = 0x10
  SALU_64 = 0x11
  VALU_64 = 0x12
  # 0x13-0x1b have no named member
  VALU_MAI = 0x1c
# ═══════════════════════════════════════════════════════════════════════════════
# PACKET TYPE BASE CLASS
# ═══════════════════════════════════════════════════════════════════════════════
@@ -448,7 +466,7 @@ class CDNA_INST(PacketType):
encoding = bits[3:0] == 10
wave = bits[8:5]
simd = bits[10:9]
inst_type = bits[15:11]
op = bits[15:11].enum(InstOpCDNA)
class CDNA_INST_PC(PacketType):
"""pkt_fmt=11: 64-bit (MsgInstPc)"""
@@ -612,7 +630,7 @@ def map_insts(data:bytes, lib:bytes, target:str) -> Iterator[tuple[PacketType, I
def simd_select(p) -> bool: return getattr(p, "cu", 0) == 0 and getattr(p, "simd", 0) == 0
for p in decode(data):
if not simd_select(p): continue
if isinstance(p, (WAVESTART, WAVESTART_RDNA4)):
if isinstance(p, (WAVESTART, WAVESTART_RDNA4, CDNA_WAVESTART)):
assert p.wave not in wave_pc, "only one inflight wave per unit"
wave_pc[p.wave] = next(iter(pc_map))
elif isinstance(p, WAVEEND):

View File

@@ -252,7 +252,7 @@ function selectShape(key) {
const timelineScale = () => d3.scaleLinear().domain([data.first, data.dur]).range([0, document.getElementById("timeline").clientWidth])
function timeAtCycle(clk) {
if (clk < data.instSt || clk > data.instEt) return "-";
if (clk < data.instSt || clk > data.instEt || data.tracks.get("Shader Clock") == null) return "-";
let cur = data.instSt, ns = 0, freq = null;
// walk through all frequency changes and accumulate time in nanoseconds
for (const [s, v] of data.tracks.get("Shader Clock").valueMap) {

View File

@@ -339,18 +339,20 @@ def load_amd_counters(ctxs:list[dict], profile:list[ProfileEvent]) -> None:
def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
from tinygrad.renderer.amd.sqtt import map_insts, InstructionInfo, PacketType, INST, InstOp, VALUINST, IMMEDIATE, IMMEDIATE_MASK, VMEMEXEC, ALUEXEC
from tinygrad.renderer.amd.sqtt import INST_RDNA4, InstOpRDNA4, TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4
from tinygrad.renderer.amd.sqtt import INST_RDNA4, InstOpRDNA4, TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4, CDNA_INST, InstOpCDNA
ret:list[ProfileEvent] = []
row_ends:dict[str, Decimal] = {}
curr_barrier:dict[str, ProfileRangeEvent] = {}
NS_PER_TICK = 10 # 100MHz
prev_pair:tuple[int, int]|None = None # (shader, realtime)
is_cdna = target.startswith("gfx9")
def add(name:str, p:PacketType, op:str|None=None, wave:int|None=None, info:InstructionInfo|None=None) -> None:
row = f"WAVE:{wave}" if (wave:=getattr(p, "wave", wave)) is not None else f"{p.__class__.__name__}:0 {name}"
# barrier on this row extends to fill the time our wave was waiting
if (barrier:=curr_barrier.pop(row, None)) is not None: barrier.en = Decimal(p._time)
ret.append(e:=ProfileRangeEvent(row, TracingKey(op or name, ret=f"PC:{info.pc}" if info else None), Decimal(p._time), Decimal(p._time+1)))
if (et:=row_ends.get(row)) is not None and e.st < et: raise RuntimeError(f"packet {p} overlaps another packet in {row}.")
# allow CDNA packets to overlap, NOT allowed on RDNA.
if (et:=row_ends.get(row)) is not None and e.st < et and not is_cdna: raise RuntimeError(f"packet {p} overlaps another packet in {row}.")
row_ends[row] = unwrap(e.en)
if name == "BARRIER": curr_barrier[row] = e
for p, info in map_insts(data, lib, target):
@@ -363,8 +365,8 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
freq_hz = (s1 - s0) * 1_000_000_000 // ((r1 - r0) * NS_PER_TICK)
ret.append(ProfilePointEvent("LINE:Shader Clock", "freq_hz", freq_hz, ts=Decimal(p._time)))
prev_pair = pair
if isinstance(p, (INST, INST_RDNA4)):
name = p.op.name if isinstance(p.op, (InstOp, InstOpRDNA4)) else f"0x{p.op:02x}"
if isinstance(p, (INST, INST_RDNA4, CDNA_INST)):
name = p.op.name if isinstance(p.op, (InstOp, InstOpRDNA4, InstOpCDNA)) else f"0x{p.op:02x}"
add(name, p, info=info)
if isinstance(p, (VALUINST, IMMEDIATE)): add(p.__class__.__name__, p, info=info)
if isinstance(p, IMMEDIATE_MASK): add("IMMEDIATE", p, wave=unwrap(info).wave, info=info)