mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
sqtt: update rcp timing test (#13231)
* sqtt: assert correct output in timing test
* found why
This commit is contained in:
@@ -10,7 +10,7 @@ import sys, contextlib
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AddrSpace
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.device import Device, ProfileDeviceEvent
|
||||
|
||||
@@ -63,18 +63,24 @@ class TestTiming(unittest.TestCase):
|
||||
assert all(s.stall == 0 for s in wave)
|
||||
|
||||
def test_multi_cycle_inst(self):
|
||||
def custom_vrcp(A, B):
|
||||
op = custom("float a = 0.0;")
|
||||
op = custom("float b = (*(data1_1+0));", op)
|
||||
#op = custom('asm volatile("v_mul_f32_e32 %2 %2 %1" : "+v"(a) : "v"(b));', op)
|
||||
op = custom('asm volatile("v_rcp_f32_e32 %2 %1" : "+v"(a) : "v"(b));', op)
|
||||
op = custom('asm volatile("v_add_f32_e64 %1 %1 1.0" : "+v"(a));', op)
|
||||
op = custom("*(data0_1+0) = a;", op)
|
||||
return UOp.sink(op, A, B, arg=KernelInfo(name="custom_vrcp"))
|
||||
out = Tensor([0.]).realize()
|
||||
inp = Tensor([-2.0]).realize()
|
||||
with save_sqtt() as sqtt:
|
||||
asm_kernel([
|
||||
"v_mov_b32_e32 v4 0x3f800000",
|
||||
"v_rcp_f32_e32 v5 v4",
|
||||
"v_mul_f32_e32 v6 v5 v4",
|
||||
]).realize()
|
||||
w = list(sqtt.values())[0]
|
||||
rcp, mul = w[1], w[2]
|
||||
self.assertGreater(rcp.dur, 1) # 4 cycles on gfx11
|
||||
self.assertEqual(mul.dur, 1)
|
||||
# mul depends on v5, how can it run before rcp is done?
|
||||
self.assertGreaterEqual(mul.time, rcp.time+rcp.dur)
|
||||
Tensor.custom_kernel(out, inp, fxn=custom_vrcp)[0].realize()
|
||||
|
||||
wave = list(sqtt.values())[0][0]
|
||||
for i in range(len(wave.insts)):
|
||||
if wave.insts[i].inst.startswith("global_store"):
|
||||
print(f"store diff {wave.insts[i].time-(wave.insts[i-1].time)}")
|
||||
self.assertEqual(out.item(), 0.5)
|
||||
|
||||
def test_wmma(self):
|
||||
with save_sqtt() as sqtt:
|
||||
|
||||
Reference in New Issue
Block a user