sqtt test_timing work (#13304)

* sqtt test_timing cleanups

* only the instruction

* v_mfma_f32_16x16x32_f16 16 cycles, only after second one though
This commit is contained in:
qazal
2025-11-16 23:49:24 +08:00
committed by GitHub
parent 8f0e747b3a
commit c70b06ec19

View File

@@ -39,10 +39,10 @@ def save_sqtt():
sqtt:dict[str, list[WaveExec]] = {}
yield sqtt
# decode sqtt
if os.environ["DEV"] == "AMD":
rctx = decode(dev.profile_events+[ProfileDeviceEvent("AMD", props=dev.device_props())])
assert len(rctx.inst_execs) > 0, "empty sqtt output"
sqtt.update(rctx.inst_execs)
if os.environ["DEV"] != "AMD": return
rctx = decode(dev.profile_events+[ProfileDeviceEvent("AMD", props=dev.device_props())])
assert len(rctx.inst_execs) > 0, "empty sqtt output"
sqtt.update(rctx.inst_execs)
class TestTiming(unittest.TestCase):
def test_v_add(self):
@@ -84,13 +84,17 @@ class TestTiming(unittest.TestCase):
def test_wmma(self):
with save_sqtt() as sqtt:
asm_kernel([
"v_wmma_f32_16x16x16_f16 v[16:23], v[0:7], v[8:15], v[16:23]",
"v_add_f32_e32 v0 v16 v0",
], l=32*4).realize()
assert len(sqtt) == 2, f"expected two waves, got {len(sqtt)} {list(sqtt.keys())}"
wmma = list(sqtt.values())[0][0]
self.assertGreater(wmma.dur, 1) # rgp says 32 clocks
for tc in dev.renderer.get_tensor_cores(dev.arch):
M, K, N = tc.dims
s = 32
a = Tensor.empty(M*s, K*s, dtype=tc.dtype_in)@Tensor.empty(K*s, N*s, dtype=tc.dtype_in)
a.realize()
print(a)
for p,waves in sqtt.items():
for e in waves[0].insts:
if (e.inst.startswith("v_wmma")):
instruction = e.inst.split(" ")[0]
print(f"{instruction:<29} : {e.dur} cycles")
def test_sleep(self):
n = 1
@@ -105,8 +109,9 @@ class TestTiming(unittest.TestCase):
diff_hw_reg = Tensor.custom_kernel(diff_hw_reg, fxn=sleep_kernel)[0]
with save_sqtt() as sqtt:
diff_hw_reg.realize()
diff_sqtt = list(sqtt.values())[0][2]
self.assertEqual(diff_sqtt.dur, diff_hw_reg.item()-1) # 1 cycle for reading the counter register
sleep = next((e for e in sqtt[f"sleep_{n}"][0].insts if e.inst.startswith("s_sleep")))
# cycles = sleep dur + overhead of storing hi/lo REG_SHADER_CYCLES
self.assertGreaterEqual(diff_hw_reg.item(), sleep.dur)
if __name__ == "__main__":
unittest.main()