diff --git a/extra/assembly/amd/test/test_custom_kernel.py b/extra/assembly/amd/test/test_custom_kernel.py index 199cbc5f5f..6102b217c7 100644 --- a/extra/assembly/amd/test/test_custom_kernel.py +++ b/extra/assembly/amd/test/test_custom_kernel.py @@ -34,6 +34,26 @@ def custom_add_one(A:UOp, arch:str) -> UOp: sink = UOp.sink(A.base, threads, arg=KernelInfo(name:=f"custom_add_one_{A.size}", estimates=Estimates(ops=A.size, mem=A.size*4*2))) return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=(*sink.src, sink)), *assemble_insts(insts, name, arch))) +def custom_add_var(A:UOp, B:UOp, arch:str) -> UOp: + A,B = A.flatten(), B.flatten() + assert A.dtype.base == dtypes.uint32, f"buffer dtype must be uint32, got {A.dtype}" + threads = UOp.special(A.size, "lidx0") + var = UOp.variable("var", 0, 10) + insts = [ + s_load_b128(s[4:7], s[0:1]), + s_load_b32(s[8], s[0:1], offset=0x10), # all threads load the same variable + s_waitcnt(lgkmcnt=0), + v_lshlrev_b32_e32(v[0], 2, v[0]), # element offset, different per thread + global_load_b32(v[1], v[0], saddr=s[6:7]), + s_waitcnt(vmcnt=0), + v_add_nc_u32_e32(v[1], s[8], v[1]), + global_store_b32(addr=v[0], data=v[1], saddr=s[4:5]), + s_endpgm(), + ] + sink = UOp.sink(A.base, B.base, var, threads, arg=KernelInfo(name:=f"custom_add_one_{A.size}")) + return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=(*sink.src, sink)), + *assemble_insts(insts, name, arch, kernarg_size=16))) + class TestCustomKernel(unittest.TestCase): def test_simple(self): a = Tensor.full((16, 16), 1.).contiguous().realize() @@ -44,5 +64,14 @@ class TestCustomKernel(unittest.TestCase): ei.run() self.assertTrue((a.numpy() == 2.).all()) + def test_variable(self): + b = Tensor.full((16, 16), 1, dtype=dtypes.uint32).contiguous().realize() + a = Tensor.zeros_like(b).contiguous().realize() + a = Tensor.custom_kernel(a, b, fxn=functools.partial(custom_add_var, arch=Device[Device.DEFAULT].arch))[0] + ei = a.schedule()[-1].lower() + for i in range(4): + ei.run({"var":i}) + self.assertTrue((a.numpy() == 1+i).all()) + if __name__ == "__main__": unittest.main()