diff --git a/extra/sqtt/test_timing.py b/extra/sqtt/test_timing.py index f412ceda3e..e9a16c6a35 100644 --- a/extra/sqtt/test_timing.py +++ b/extra/sqtt/test_timing.py @@ -18,15 +18,17 @@ from extra.sqtt.roc import decode, InstExec, PrgExec dev = Device["AMD"] +def custom(arg:str, s:UOp|None=None) -> UOp: return UOp(Ops.CUSTOM, src=(s,) if s is not None else (), arg=arg) + def asm_kernel(instrs:list[str], l:int=1, g:int=1) -> Tensor: name = sys._getframe(1).f_code.co_name def fxn(_): L = UOp.special(l, "lidx0") G = UOp.special(g, "gidx0") - ops:list[str] = [UOp(Ops.CUSTOM, arg="asm volatile (")] - for inst in instrs: ops.append(UOp(Ops.CUSTOM, src=(ops[-1],), arg=f' "{inst}\\n\\t"')) - ops.append(UOp(Ops.CUSTOM, src=(ops[-1],), arg=");")) - return UOp.sink(*ops, L, G, arg=KernelInfo(name=name)) + op = custom("asm volatile (") + for inst in instrs: op = custom(f' "{inst}\\n\\t"', op) + op = custom(");", op) + return UOp.sink(op, L, G, arg=KernelInfo(name=name)) k = Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0] return k @@ -87,12 +89,11 @@ class TestTiming(unittest.TestCase): n = 1 def sleep_kernel(data0): assert data0.dtype.base == dtypes.ulong - ops:list[UOp] = [] - ops.append(UOp(Ops.CUSTOM, arg="unsigned long long t0 = __builtin_readcyclecounter();")) - ops.append(UOp(Ops.CUSTOM, arg=f"__builtin_amdgcn_s_sleep({n});", src=(ops[-1],))) - ops.append(UOp(Ops.CUSTOM, arg="unsigned long long t1 = __builtin_readcyclecounter();", src=(ops[-1],))) - ops.append(UOp(Ops.CUSTOM, arg=f"data0_{data0.size}[0] = t1 - t0;", src=(ops[-1],))) - return UOp.sink(data0, *ops, arg=KernelInfo(name=f"sleep_{n}")) + op = custom("unsigned long long t0 = __builtin_readcyclecounter();") + op = custom(f"__builtin_amdgcn_s_sleep({n});", op) + op = custom(f"unsigned long long t1 = __builtin_readcyclecounter();", op) + op = custom(f"data0_{data0.size}[0] = t1 - t0;", op) + return UOp.sink(data0, op, arg=KernelInfo(name=f"sleep_{n}")) diff_hw_reg = Tensor.empty(1, dtype=dtypes.ulong) diff_hw_reg = Tensor.custom_kernel(diff_hw_reg, fxn=sleep_kernel)[0] with save_sqtt() as sqtt: