Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-29 03:00:14 -04:00)
sqtt: add CDNA ops enum, show in viz (#15140)
This commit is contained in:
@@ -78,8 +78,9 @@ class TestSQTTMapBase(unittest.TestCase):
|
||||
for event in events:
|
||||
if (p:=kern_events.get(event.kern)) is None: continue
|
||||
with self.subTest(example=name, kern=event.kern):
|
||||
# skip if there's no SQTT frequency data
|
||||
if not (timeline:=sqtt_timeline(event.blob, p.lib, target)): continue
|
||||
frequency = [e.key for e in timeline if type(e).__name__ == "ProfilePointEvent" and e.name == "freq_hz"]
|
||||
if not (frequency:=[e.key for e in timeline if type(e).__name__ == "ProfilePointEvent" and e.name == "freq_hz"]): continue
|
||||
mean = sum(frequency) / len(frequency)
|
||||
variance = sum((v - mean) ** 2 for v in frequency) / len(frequency)
|
||||
self.assertGreater(mean, 0)
|
||||
@@ -110,5 +111,9 @@ class TestSQTTMapRDNA3(TestSQTTMapBase): target = "gfx1100"
|
||||
|
||||
class TestSQTTMapRDNA4(TestSQTTMapBase):
  # only the target string differs from the base class
  target = "gfx1200"
|
||||
|
||||
class TestSQTTMapCDNA(TestSQTTMapBase):
  # TestSQTTMapBase subclass for the gfx950 target
  target = "gfx950"

  # NOTE(review): presumably overrides a TestSQTTMapBase test of the same name - confirm against the base class
  def test_rocprof_inst_traces_match(self):
    # always skipped: see the message for the reason
    self.skipTest("requires timestamp patching to match rocprof, currently it's off by a few cycles")
|
||||
|
||||
# run the tests in this module when it is executed as a script
if __name__ == "__main__":
  unittest.main()
|
||||
|
||||
@@ -87,6 +87,9 @@ class TestGemmLarge(unittest.TestCase):
|
||||
if not is_cdna4():
|
||||
self.skipTest("very slow on non mi350x")
|
||||
|
||||
@Context(ASM_GEMM=1)
def test_empty(self):
  # realize an NxN @ NxN half-precision matmul on Tensor.empty inputs; N comes from getenv("N", 4096)
  N = getenv("N", 4096)
  lhs = Tensor.empty(N, N, dtype=dtypes.half)
  rhs = Tensor.empty(N, N, dtype=dtypes.half)
  (lhs @ rhs).realize()
|
||||
|
||||
def test_tiny(self):
  # smallest of the fixed argument sets passed to verify_asm_gemm in this hunk
  args = (1, 256, 256, 64)
  verify_asm_gemm(*args)
|
||||
def test_simple(self):
  # the same size is used for all three dimension arguments; it comes from getenv("N", 4096)
  N = getenv("N", 4096)
  verify_asm_gemm(1, N, N, N, dtype=dtypes.half)
|
||||
def test_gemm(self):
  # fixed, non-square argument set for verify_asm_gemm
  args = (1, 8192, 4096, 14336)
  verify_asm_gemm(*args)
|
||||
|
||||
@@ -139,6 +139,24 @@ class InstOpRDNA4(Enum):
|
||||
OTHER_VMEM = 0xbd
|
||||
OTHER_VMEM_5 = 0xc1
|
||||
|
||||
class InstOpCDNA(Enum):
  """Named op codes for the `op` field of CDNA_INST packets (decoded from bits[15:11])."""
  SMEM_RD = 0
  SALU_32 = 1
  VMEM_RD = 2
  VMEM_WR = 3
  FLAT_WR = 4
  VALU_32 = 5
  LDS = 6
  PC = 7
  # 8-11 have no name here
  JUMP = 12
  NEXT = 13
  FLAT_RD = 14
  OTHER_MSG = 15
  SMEM_WR = 16
  SALU_64 = 17
  VALU_64 = 18
  # 19-27 have no name here
  VALU_MAI = 28
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# PACKET TYPE BASE CLASS
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
@@ -448,7 +466,7 @@ class CDNA_INST(PacketType):
|
||||
encoding = bits[3:0] == 10
|
||||
wave = bits[8:5]
|
||||
simd = bits[10:9]
|
||||
inst_type = bits[15:11]
|
||||
op = bits[15:11].enum(InstOpCDNA)
|
||||
|
||||
class CDNA_INST_PC(PacketType):
|
||||
"""pkt_fmt=11: 64-bit (MsgInstPc)"""
|
||||
@@ -612,7 +630,7 @@ def map_insts(data:bytes, lib:bytes, target:str) -> Iterator[tuple[PacketType, I
|
||||
def simd_select(p) -> bool: return getattr(p, "cu", 0) == 0 and getattr(p, "simd", 0) == 0
|
||||
for p in decode(data):
|
||||
if not simd_select(p): continue
|
||||
if isinstance(p, (WAVESTART, WAVESTART_RDNA4)):
|
||||
if isinstance(p, (WAVESTART, WAVESTART_RDNA4, CDNA_WAVESTART)):
|
||||
assert p.wave not in wave_pc, "only one inflight wave per unit"
|
||||
wave_pc[p.wave] = next(iter(pc_map))
|
||||
elif isinstance(p, WAVEEND):
|
||||
|
||||
@@ -252,7 +252,7 @@ function selectShape(key) {
|
||||
const timelineScale = () => d3.scaleLinear().domain([data.first, data.dur]).range([0, document.getElementById("timeline").clientWidth])
|
||||
|
||||
function timeAtCycle(clk) {
|
||||
if (clk < data.instSt || clk > data.instEt) return "-";
|
||||
if (clk < data.instSt || clk > data.instEt || data.tracks.get("Shader Clock") == null) return "-";
|
||||
let cur = data.instSt, ns = 0, freq = null;
|
||||
// walk through all frequency changes and accumulate time in nanoseconds
|
||||
for (const [s, v] of data.tracks.get("Shader Clock").valueMap) {
|
||||
|
||||
@@ -339,18 +339,20 @@ def load_amd_counters(ctxs:list[dict], profile:list[ProfileEvent]) -> None:
|
||||
|
||||
def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
|
||||
from tinygrad.renderer.amd.sqtt import map_insts, InstructionInfo, PacketType, INST, InstOp, VALUINST, IMMEDIATE, IMMEDIATE_MASK, VMEMEXEC, ALUEXEC
|
||||
from tinygrad.renderer.amd.sqtt import INST_RDNA4, InstOpRDNA4, TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4
|
||||
from tinygrad.renderer.amd.sqtt import INST_RDNA4, InstOpRDNA4, TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4, CDNA_INST, InstOpCDNA
|
||||
ret:list[ProfileEvent] = []
|
||||
row_ends:dict[str, Decimal] = {}
|
||||
curr_barrier:dict[str, ProfileRangeEvent] = {}
|
||||
NS_PER_TICK = 10 # 100MHz
|
||||
prev_pair:tuple[int, int]|None = None # (shader, realtime)
|
||||
is_cdna = target.startswith("gfx9")
|
||||
def add(name:str, p:PacketType, op:str|None=None, wave:int|None=None, info:InstructionInfo|None=None) -> None:
|
||||
row = f"WAVE:{wave}" if (wave:=getattr(p, "wave", wave)) is not None else f"{p.__class__.__name__}:0 {name}"
|
||||
# barrier on this row extends to fill the time our wave was waiting
|
||||
if (barrier:=curr_barrier.pop(row, None)) is not None: barrier.en = Decimal(p._time)
|
||||
ret.append(e:=ProfileRangeEvent(row, TracingKey(op or name, ret=f"PC:{info.pc}" if info else None), Decimal(p._time), Decimal(p._time+1)))
|
||||
if (et:=row_ends.get(row)) is not None and e.st < et: raise RuntimeError(f"packet {p} overlaps another packet in {row}.")
|
||||
# allow CDNA packets to overlap, NOT allowed on RDNA.
|
||||
if (et:=row_ends.get(row)) is not None and e.st < et and not is_cdna: raise RuntimeError(f"packet {p} overlaps another packet in {row}.")
|
||||
row_ends[row] = unwrap(e.en)
|
||||
if name == "BARRIER": curr_barrier[row] = e
|
||||
for p, info in map_insts(data, lib, target):
|
||||
@@ -363,8 +365,8 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
|
||||
freq_hz = (s1 - s0) * 1_000_000_000 // ((r1 - r0) * NS_PER_TICK)
|
||||
ret.append(ProfilePointEvent("LINE:Shader Clock", "freq_hz", freq_hz, ts=Decimal(p._time)))
|
||||
prev_pair = pair
|
||||
if isinstance(p, (INST, INST_RDNA4)):
|
||||
name = p.op.name if isinstance(p.op, (InstOp, InstOpRDNA4)) else f"0x{p.op:02x}"
|
||||
if isinstance(p, (INST, INST_RDNA4, CDNA_INST)):
|
||||
name = p.op.name if isinstance(p.op, (InstOp, InstOpRDNA4, InstOpCDNA)) else f"0x{p.op:02x}"
|
||||
add(name, p, info=info)
|
||||
if isinstance(p, (VALUINST, IMMEDIATE)): add(p.__class__.__name__, p, info=info)
|
||||
if isinstance(p, IMMEDIATE_MASK): add("IMMEDIATE", p, wave=unwrap(info).wave, info=info)
|
||||
|
||||
Reference in New Issue
Block a user