qol improvements to sqtt decoder and timing tests (#13125)

This commit is contained in:
qazal
2025-11-06 20:51:30 +08:00
committed by GitHub
parent dafdb4bfb1
commit 88245d6579
3 changed files with 59 additions and 38 deletions

View File

@@ -48,13 +48,21 @@ class InstExec:
dur:int
time:int
@dataclasses.dataclass(frozen=True)
class PrgExec:
  """Key identifying one decoded wave of a program: (name, wave, cu, simd).

  The dataclass is frozen, so instances are immutable and hashable and can be
  used directly as dictionary keys.
  """
  name:str  # program name
  wave:int  # wave id
  cu:int
  simd:int
  def __str__(self):
    # Render as "name,wave,cu,simd" (comma-separated, no spaces).
    parts = (self.name, self.wave, self.cu, self.simd)
    return ",".join(str(part) for part in parts)
class _ROCParseCtx:
def __init__(self, dev_evs:dict[str, ProfileDeviceEvent], sqtt_evs:list[ProfileSQTTEvent], prog_evs:list[ProfileProgramEvent]):
self.dev_evs, self.sqtt_evs, self.prog_evs = dev_evs, iter(sqtt_evs), prog_evs
self.wave_events:dict[tuple[str, int, int, int], dict[int, InstInfo]] = {}
self.wave_events:dict[PrgExec, dict[int, InstInfo]] = {}
self.disasms:dict[int, tuple[str, int]] = {}
self.addr2prg:dict[int, ProfileProgramEvent] = {}
self.inst_execs:dict[tuple[str, int, int, int], list[InstExec]] = {}
self.inst_execs:dict[PrgExec, list[InstExec]] = {}
for prog in prog_evs:
arch = "gfx%d%x%x" % ((trgt:=unwrap(dev_evs[prog.device].props)['gfx_target_version']) // 10000, (trgt // 100) % 100, trgt % 100)
@@ -86,7 +94,7 @@ class _ROCParseCtx:
inst_execs.append(InstExec(inst_typ, inst_disasm, inst_ev.stall, inst_ev.duration, inst_ev.time))
if ev.instructions_size > 0:
self.wave_events[key:=(self.find_program(ev.instructions_array[0].pc.address).name, ev.wave_id, ev.cu, ev.simd)] = asm
self.wave_events[key:=PrgExec(self.find_program(ev.instructions_array[0].pc.address).name, ev.wave_id, ev.cu, ev.simd)] = asm
self.inst_execs[key] = inst_execs
def decode(profile:list[ProfileEvent]) -> _ROCParseCtx:

View File

@@ -6,7 +6,7 @@ os.environ["VIZ"] = "1"
os.environ["AMD_LLVM"] = "0"
import unittest
import sys
import sys, contextlib
from tinygrad import Tensor
from tinygrad.dtype import dtypes
from tinygrad.renderer import ProgramSpec
@@ -14,60 +14,73 @@ from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.engine.realize import CompiledRunner
from tinygrad.device import Device, ProfileDeviceEvent
from extra.sqtt.roc import decode, InstExec
from extra.sqtt.roc import decode, InstExec, PrgExec
dev = Device["AMD"]
def get_sqtt(asm:list[str], l:int=1, g:int=1) -> list[InstExec]:
# clear the old traces
dev.profile_events.clear()
# setup custom_kernel
def asm_kernel(instrs:list[str], l:int=1, g:int=1) -> Tensor:
name = sys._getframe(1).f_code.co_name
def fxn(_):
L = UOp.special(l, "lidx0")
G = UOp.special(g, "gidx0")
ops:list[str] = [UOp(Ops.CUSTOM, arg="asm volatile (")]
for inst in asm: ops.append(UOp(Ops.CUSTOM, src=(ops[-1],), arg=f' "{inst}\\n\\t"'))
for inst in instrs: ops.append(UOp(Ops.CUSTOM, src=(ops[-1],), arg=f' "{inst}\\n\\t"'))
ops.append(UOp(Ops.CUSTOM, src=(ops[-1],), arg=");"))
return UOp.sink(*ops, L, G, arg=KernelInfo(name=name))
k = Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0]
# exec and decode sqtt
k.realize()
return k
@contextlib.contextmanager
def save_sqtt():
# clear the old traces
dev.profile_events.clear()
sqtt:dict[PrgExec, list[InstExec]] = {}
yield sqtt
# decode sqtt
rctx = decode(dev.profile_events+[ProfileDeviceEvent("AMD", props=dev.device_props())])
assert len(rctx.inst_execs) > 0, "empty sqtt output"
return list(rctx.inst_execs.values())[0][:-1]
sqtt.update(rctx.inst_execs)
class TestTiming(unittest.TestCase):
def test_v_add(self):
sqtt = get_sqtt([f"v_add_f32 v{10+i} v{10+i+1} {10+i}" for i in range(3)])
assert all(s.dur == 1 for s in sqtt)
assert all(s.stall == 0 for s in sqtt)
with save_sqtt() as sqtt:
asm_kernel([f"v_add_f32 v{10+i} v{10+i+1} {10+i}" for i in range(3)]).realize()
wave = list(sqtt.values())[0][:-1]
assert all(s.dur == 1 for s in wave)
assert all(s.stall == 0 for s in wave)
def test_chain_v_add_1l(self):
sqtt = get_sqtt([
"v_add_f32_e32 v1 v0 v0",
"v_add_f32_e32 v2 v1 v1",
])
assert all(s.dur == 1 for s in sqtt)
assert all(s.stall == 0 for s in sqtt)
with save_sqtt() as sqtt:
asm_kernel([
"v_add_f32_e32 v1 v0 v0",
"v_add_f32_e32 v2 v1 v1",
]).realize()
wave = list(sqtt.values())[0][:-1]
assert all(s.dur == 1 for s in wave)
assert all(s.stall == 0 for s in wave)
def test_multi_cycle_inst(self):
sqtt = get_sqtt([
"v_mov_b32_e32 v4 0x3f800000",
"v_rcp_f32_e32 v5 v4",
"v_mul_f32_e32 v6 v5 v4",
])
rcp, mul = sqtt[1], sqtt[2]
with save_sqtt() as sqtt:
asm_kernel([
"v_mov_b32_e32 v4 0x3f800000",
"v_rcp_f32_e32 v5 v4",
"v_mul_f32_e32 v6 v5 v4",
]).realize()
w = list(sqtt.values())[0]
rcp, mul = w[1], w[2]
self.assertGreater(rcp.dur, 1) # 4 cycles on gfx11
self.assertEqual(mul.dur, 1)
# mul depends on v5, how can it run before rcp is done?
self.assertGreaterEqual(mul.time, rcp.time+rcp.dur)
def test_wmma(self):
sqtt = get_sqtt([
"v_wmma_f32_16x16x16_f16 v[16:23], v[0:7], v[8:15], v[16:23]",
"v_add_f32_e32 v0 v16 v0",
], 32*4)
wmma = sqtt[0]
with save_sqtt() as sqtt:
asm_kernel([
"v_wmma_f32_16x16x16_f16 v[16:23], v[0:7], v[8:15], v[16:23]",
"v_add_f32_e32 v0 v16 v0",
], l=32*4).realize()
assert len(sqtt) == 2, f"expected two waves, got {len(sqtt)} {list(sqtt.keys())}"
wmma = list(sqtt.values())[0][0]
self.assertGreater(wmma.dur, 1) # rgp says 32 clocks
def test_sleep(self):
@@ -82,9 +95,9 @@ class TestTiming(unittest.TestCase):
return UOp.sink(data0, *ops, arg=KernelInfo(name=f"sleep_{n}"))
diff_hw_reg = Tensor.empty(1, dtype=dtypes.ulong)
diff_hw_reg = Tensor.custom_kernel(diff_hw_reg, fxn=sleep_kernel)[0]
diff_hw_reg.realize()
rctx = decode(dev.profile_events+[ProfileDeviceEvent("AMD", props=dev.device_props())])
diff_sqtt = list(rctx.inst_execs.values())[0][2]
with save_sqtt() as sqtt:
diff_hw_reg.realize()
diff_sqtt = list(sqtt.values())[0][2]
self.assertEqual(diff_sqtt.dur, diff_hw_reg.item()-1) # 1 cycle for reading the counter register
if __name__ == "__main__":

View File

@@ -203,8 +203,8 @@ def load_sqtt(profile:list[ProfileEvent]) -> None:
except Exception: return err("DECODER IMPORT ISSUE")
try:
rctx = decode(profile)
steps = [{"name":x[0], "depth":0, "data":{"rows":[(e.inst, e.time, e.time-x[1][i-1].time if i else 0, e.dur, e.stall, str(e.typ).split("_")[-1])
for i,e in enumerate(x[1])],
steps = [{"name":str(x[0]), "depth":0, "data":{"rows":[(e.inst, e.time, e.time-x[1][i-1].time if i else 0, e.dur, e.stall,
str(e.typ).split("_")[-1]) for i,e in enumerate(x[1])],
"cols":["Instruction", "Clk", "Wait", "Duration", "Stall", "Type"], "summary":[]},
"query":f"/render?ctx={len(ctxs)}&step={i}&fmt=counters"} for i,x in enumerate(rctx.inst_execs.items())]
if not steps: return err("EMPTY SQTT OUTPUT", f"{len(sqtt_events)} SQTT events recorded, none got decoded")