mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
qol improvements to sqtt decoder and timing tests (#13125)
This commit is contained in:
@@ -48,13 +48,21 @@ class InstExec:
|
||||
dur:int
|
||||
time:int
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class PrgExec:
|
||||
name:str
|
||||
wave:int
|
||||
cu:int
|
||||
simd:int
|
||||
def __str__(self): return f"{self.name},{self.wave},{self.cu},{self.simd}"
|
||||
|
||||
class _ROCParseCtx:
|
||||
def __init__(self, dev_evs:dict[str, ProfileDeviceEvent], sqtt_evs:list[ProfileSQTTEvent], prog_evs:list[ProfileProgramEvent]):
|
||||
self.dev_evs, self.sqtt_evs, self.prog_evs = dev_evs, iter(sqtt_evs), prog_evs
|
||||
self.wave_events:dict[tuple[str, int, int, int], dict[int, InstInfo]] = {}
|
||||
self.wave_events:dict[PrgExec, dict[int, InstInfo]] = {}
|
||||
self.disasms:dict[int, tuple[str, int]] = {}
|
||||
self.addr2prg:dict[int, ProfileProgramEvent] = {}
|
||||
self.inst_execs:dict[tuple[str, int, int, int], list[InstExec]] = {}
|
||||
self.inst_execs:dict[PrgExec, list[InstExec]] = {}
|
||||
|
||||
for prog in prog_evs:
|
||||
arch = "gfx%d%x%x" % ((trgt:=unwrap(dev_evs[prog.device].props)['gfx_target_version']) // 10000, (trgt // 100) % 100, trgt % 100)
|
||||
@@ -86,7 +94,7 @@ class _ROCParseCtx:
|
||||
inst_execs.append(InstExec(inst_typ, inst_disasm, inst_ev.stall, inst_ev.duration, inst_ev.time))
|
||||
|
||||
if ev.instructions_size > 0:
|
||||
self.wave_events[key:=(self.find_program(ev.instructions_array[0].pc.address).name, ev.wave_id, ev.cu, ev.simd)] = asm
|
||||
self.wave_events[key:=PrgExec(self.find_program(ev.instructions_array[0].pc.address).name, ev.wave_id, ev.cu, ev.simd)] = asm
|
||||
self.inst_execs[key] = inst_execs
|
||||
|
||||
def decode(profile:list[ProfileEvent]) -> _ROCParseCtx:
|
||||
|
||||
@@ -6,7 +6,7 @@ os.environ["VIZ"] = "1"
|
||||
os.environ["AMD_LLVM"] = "0"
|
||||
|
||||
import unittest
|
||||
import sys
|
||||
import sys, contextlib
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
@@ -14,60 +14,73 @@ from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.device import Device, ProfileDeviceEvent
|
||||
|
||||
from extra.sqtt.roc import decode, InstExec
|
||||
from extra.sqtt.roc import decode, InstExec, PrgExec
|
||||
|
||||
dev = Device["AMD"]
|
||||
def get_sqtt(asm:list[str], l:int=1, g:int=1) -> list[InstExec]:
|
||||
# clear the old traces
|
||||
dev.profile_events.clear()
|
||||
# setup custom_kernel
|
||||
|
||||
def asm_kernel(instrs:list[str], l:int=1, g:int=1) -> Tensor:
|
||||
name = sys._getframe(1).f_code.co_name
|
||||
def fxn(_):
|
||||
L = UOp.special(l, "lidx0")
|
||||
G = UOp.special(g, "gidx0")
|
||||
ops:list[str] = [UOp(Ops.CUSTOM, arg="asm volatile (")]
|
||||
for inst in asm: ops.append(UOp(Ops.CUSTOM, src=(ops[-1],), arg=f' "{inst}\\n\\t"'))
|
||||
for inst in instrs: ops.append(UOp(Ops.CUSTOM, src=(ops[-1],), arg=f' "{inst}\\n\\t"'))
|
||||
ops.append(UOp(Ops.CUSTOM, src=(ops[-1],), arg=");"))
|
||||
return UOp.sink(*ops, L, G, arg=KernelInfo(name=name))
|
||||
k = Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0]
|
||||
# exec and decode sqtt
|
||||
k.realize()
|
||||
return k
|
||||
|
||||
@contextlib.contextmanager
|
||||
def save_sqtt():
|
||||
# clear the old traces
|
||||
dev.profile_events.clear()
|
||||
sqtt:dict[PrgExec, list[InstExec]] = {}
|
||||
yield sqtt
|
||||
# decode sqtt
|
||||
rctx = decode(dev.profile_events+[ProfileDeviceEvent("AMD", props=dev.device_props())])
|
||||
assert len(rctx.inst_execs) > 0, "empty sqtt output"
|
||||
return list(rctx.inst_execs.values())[0][:-1]
|
||||
sqtt.update(rctx.inst_execs)
|
||||
|
||||
class TestTiming(unittest.TestCase):
|
||||
def test_v_add(self):
|
||||
sqtt = get_sqtt([f"v_add_f32 v{10+i} v{10+i+1} {10+i}" for i in range(3)])
|
||||
assert all(s.dur == 1 for s in sqtt)
|
||||
assert all(s.stall == 0 for s in sqtt)
|
||||
with save_sqtt() as sqtt:
|
||||
asm_kernel([f"v_add_f32 v{10+i} v{10+i+1} {10+i}" for i in range(3)]).realize()
|
||||
wave = list(sqtt.values())[0][:-1]
|
||||
assert all(s.dur == 1 for s in wave)
|
||||
assert all(s.stall == 0 for s in wave)
|
||||
|
||||
def test_chain_v_add_1l(self):
|
||||
sqtt = get_sqtt([
|
||||
"v_add_f32_e32 v1 v0 v0",
|
||||
"v_add_f32_e32 v2 v1 v1",
|
||||
])
|
||||
assert all(s.dur == 1 for s in sqtt)
|
||||
assert all(s.stall == 0 for s in sqtt)
|
||||
with save_sqtt() as sqtt:
|
||||
asm_kernel([
|
||||
"v_add_f32_e32 v1 v0 v0",
|
||||
"v_add_f32_e32 v2 v1 v1",
|
||||
]).realize()
|
||||
wave = list(sqtt.values())[0][:-1]
|
||||
assert all(s.dur == 1 for s in wave)
|
||||
assert all(s.stall == 0 for s in wave)
|
||||
|
||||
def test_multi_cycle_inst(self):
|
||||
sqtt = get_sqtt([
|
||||
"v_mov_b32_e32 v4 0x3f800000",
|
||||
"v_rcp_f32_e32 v5 v4",
|
||||
"v_mul_f32_e32 v6 v5 v4",
|
||||
])
|
||||
rcp, mul = sqtt[1], sqtt[2]
|
||||
with save_sqtt() as sqtt:
|
||||
asm_kernel([
|
||||
"v_mov_b32_e32 v4 0x3f800000",
|
||||
"v_rcp_f32_e32 v5 v4",
|
||||
"v_mul_f32_e32 v6 v5 v4",
|
||||
]).realize()
|
||||
w = list(sqtt.values())[0]
|
||||
rcp, mul = w[1], w[2]
|
||||
self.assertGreater(rcp.dur, 1) # 4 cycles on gfx11
|
||||
self.assertEqual(mul.dur, 1)
|
||||
# mul depends on v5, how can it run before rcp is done?
|
||||
self.assertGreaterEqual(mul.time, rcp.time+rcp.dur)
|
||||
|
||||
def test_wmma(self):
|
||||
sqtt = get_sqtt([
|
||||
"v_wmma_f32_16x16x16_f16 v[16:23], v[0:7], v[8:15], v[16:23]",
|
||||
"v_add_f32_e32 v0 v16 v0",
|
||||
], 32*4)
|
||||
wmma = sqtt[0]
|
||||
with save_sqtt() as sqtt:
|
||||
asm_kernel([
|
||||
"v_wmma_f32_16x16x16_f16 v[16:23], v[0:7], v[8:15], v[16:23]",
|
||||
"v_add_f32_e32 v0 v16 v0",
|
||||
], l=32*4).realize()
|
||||
assert len(sqtt) == 2, f"expected two waves, got {len(sqtt)} {list(sqtt.keys())}"
|
||||
wmma = list(sqtt.values())[0][0]
|
||||
self.assertGreater(wmma.dur, 1) # rgp says 32 clocks
|
||||
|
||||
def test_sleep(self):
|
||||
@@ -82,9 +95,9 @@ class TestTiming(unittest.TestCase):
|
||||
return UOp.sink(data0, *ops, arg=KernelInfo(name=f"sleep_{n}"))
|
||||
diff_hw_reg = Tensor.empty(1, dtype=dtypes.ulong)
|
||||
diff_hw_reg = Tensor.custom_kernel(diff_hw_reg, fxn=sleep_kernel)[0]
|
||||
diff_hw_reg.realize()
|
||||
rctx = decode(dev.profile_events+[ProfileDeviceEvent("AMD", props=dev.device_props())])
|
||||
diff_sqtt = list(rctx.inst_execs.values())[0][2]
|
||||
with save_sqtt() as sqtt:
|
||||
diff_hw_reg.realize()
|
||||
diff_sqtt = list(sqtt.values())[0][2]
|
||||
self.assertEqual(diff_sqtt.dur, diff_hw_reg.item()-1) # 1 cycle for reading the counter register
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -203,8 +203,8 @@ def load_sqtt(profile:list[ProfileEvent]) -> None:
|
||||
except Exception: return err("DECODER IMPORT ISSUE")
|
||||
try:
|
||||
rctx = decode(profile)
|
||||
steps = [{"name":x[0], "depth":0, "data":{"rows":[(e.inst, e.time, e.time-x[1][i-1].time if i else 0, e.dur, e.stall, str(e.typ).split("_")[-1])
|
||||
for i,e in enumerate(x[1])],
|
||||
steps = [{"name":str(x[0]), "depth":0, "data":{"rows":[(e.inst, e.time, e.time-x[1][i-1].time if i else 0, e.dur, e.stall,
|
||||
str(e.typ).split("_")[-1]) for i,e in enumerate(x[1])],
|
||||
"cols":["Instruction", "Clk", "Wait", "Duration", "Stall", "Type"], "summary":[]},
|
||||
"query":f"/render?ctx={len(ctxs)}&step={i}&fmt=counters"} for i,x in enumerate(rctx.inst_execs.items())]
|
||||
if not steps: return err("EMPTY SQTT OUTPUT", f"{len(sqtt_events)} SQTT events recorded, none got decoded")
|
||||
|
||||
Reference in New Issue
Block a user