add helper for test_timing custom ops (#13140)

This commit is contained in:
qazal
2025-11-07 17:13:55 +08:00
committed by GitHub
parent 95620426d5
commit 7e94369464

View File

@@ -18,15 +18,17 @@ from extra.sqtt.roc import decode, InstExec, PrgExec
dev = Device["AMD"]
def custom(arg:str, s:UOp|None=None) -> UOp: return UOp(Ops.CUSTOM, src=(s,) if s is not None else (), arg=arg)
def asm_kernel(instrs:list[str], l:int=1, g:int=1) -> Tensor:
name = sys._getframe(1).f_code.co_name
def fxn(_):
L = UOp.special(l, "lidx0")
G = UOp.special(g, "gidx0")
ops:list[str] = [UOp(Ops.CUSTOM, arg="asm volatile (")]
for inst in instrs: ops.append(UOp(Ops.CUSTOM, src=(ops[-1],), arg=f' "{inst}\\n\\t"'))
ops.append(UOp(Ops.CUSTOM, src=(ops[-1],), arg=");"))
return UOp.sink(*ops, L, G, arg=KernelInfo(name=name))
op = custom("asm volatile (")
for inst in instrs: op = custom(f' "{inst}\\n\\t"', op)
op = custom(");", op)
return UOp.sink(op, L, G, arg=KernelInfo(name=name))
k = Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0]
return k
@@ -87,12 +89,11 @@ class TestTiming(unittest.TestCase):
n = 1
def sleep_kernel(data0):
assert data0.dtype.base == dtypes.ulong
ops:list[UOp] = []
ops.append(UOp(Ops.CUSTOM, arg="unsigned long long t0 = __builtin_readcyclecounter();"))
ops.append(UOp(Ops.CUSTOM, arg=f"__builtin_amdgcn_s_sleep({n});", src=(ops[-1],)))
ops.append(UOp(Ops.CUSTOM, arg="unsigned long long t1 = __builtin_readcyclecounter();", src=(ops[-1],)))
ops.append(UOp(Ops.CUSTOM, arg=f"data0_{data0.size}[0] = t1 - t0;", src=(ops[-1],)))
return UOp.sink(data0, *ops, arg=KernelInfo(name=f"sleep_{n}"))
op = custom("unsigned long long t0 = __builtin_readcyclecounter();")
op = custom(f"__builtin_amdgcn_s_sleep({n});", op)
op = custom(f"unsigned long long t1 = __builtin_readcyclecounter();", op)
op = custom(f"data0_{data0.size}[0] = t1 - t0;", op)
return UOp.sink(data0, op, arg=KernelInfo(name=f"sleep_{n}"))
diff_hw_reg = Tensor.empty(1, dtype=dtypes.ulong)
diff_hw_reg = Tensor.custom_kernel(diff_hw_reg, fxn=sleep_kernel)[0]
with save_sqtt() as sqtt: