From 8b26cf2b3d784b11b3b5e6c3114a7c43c101c4cd Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 13 Nov 2025 02:01:54 +0800 Subject: [PATCH] sqtt: update rcp timing test (#13231) * sqtt: assert correct output in timing test * found why --- extra/sqtt/test_timing.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/extra/sqtt/test_timing.py b/extra/sqtt/test_timing.py index 8c6feda2ae..f0f51516c3 100644 --- a/extra/sqtt/test_timing.py +++ b/extra/sqtt/test_timing.py @@ -10,7 +10,7 @@ import sys, contextlib from tinygrad import Tensor from tinygrad.dtype import dtypes from tinygrad.renderer import ProgramSpec -from tinygrad.uop.ops import UOp, Ops, KernelInfo +from tinygrad.uop.ops import UOp, Ops, KernelInfo, AddrSpace from tinygrad.engine.realize import CompiledRunner from tinygrad.device import Device, ProfileDeviceEvent @@ -63,18 +63,24 @@ class TestTiming(unittest.TestCase): assert all(s.stall == 0 for s in wave) def test_multi_cycle_inst(self): + def custom_vrcp(A, B): + op = custom("float a = 0.0;") + op = custom("float b = (*(data1_1+0));", op) + #op = custom('asm volatile("v_mul_f32_e32 %2 %2 %1" : "+v"(a) : "v"(b));', op) + op = custom('asm volatile("v_rcp_f32_e32 %2 %1" : "+v"(a) : "v"(b));', op) + op = custom('asm volatile("v_add_f32_e64 %1 %1 1.0" : "+v"(a));', op) + op = custom("*(data0_1+0) = a;", op) + return UOp.sink(op, A, B, arg=KernelInfo(name="custom_vrcp")) + out = Tensor([0.]).realize() + inp = Tensor([-2.0]).realize() with save_sqtt() as sqtt: - asm_kernel([ - "v_mov_b32_e32 v4 0x3f800000", - "v_rcp_f32_e32 v5 v4", - "v_mul_f32_e32 v6 v5 v4", - ]).realize() - w = list(sqtt.values())[0] - rcp, mul = w[1], w[2] - self.assertGreater(rcp.dur, 1) # 4 cycles on gfx11 - self.assertEqual(mul.dur, 1) - # mul depends on v5, how can it run before rcp is done? - self.assertGreaterEqual(mul.time, rcp.time+rcp.dur) + Tensor.custom_kernel(out, inp, fxn=custom_vrcp)[0].realize() + + wave = list(sqtt.values())[0][0] + for i in range(len(wave.insts)): + if wave.insts[i].inst.startswith("global_store"): + print(f"store diff {wave.insts[i].time-(wave.insts[i-1].time)}") + self.assertEqual(out.item(), 0.5) def test_wmma(self): with save_sqtt() as sqtt: