update tests that use UOp.size (#15753)

commit 10c262ced8 (parent 96092d110c)
Author: chenyu
Date: 2026-04-15 21:58:27 -04:00
Committed by: GitHub
4 changed files with 21 additions and 21 deletions
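
The mechanical pattern is the same at every call site: reads of the old UOp.size property become UOp.numel() calls on flattened buffers, or an explicit UOp.shape[0] where the buffer is known to be one-dimensional. A minimal before/after sketch, assuming a tinygrad checkout at this commit and a buffer UOp A; everything except the methods shown in this diff is illustrative:

    # before: flat element count read as a property
    threads = UOp.special(A.size, "lidx0")
    # after: flatten first, then call numel() for the flat element count
    A = A.flatten()
    threads = UOp.special(A.numel(), "lidx0")
    # for buffers already known to be 1-D, the tests now index the shape directly
    i = UOp.range(A.shape[0], 0)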

View File

@@ -15,7 +15,7 @@ from extra.gemm.amd_asm_matmul import Kernel
 def custom_add_one(A:UOp) -> UOp:
   A = A.flatten()
   assert dtypes.is_float(A.dtype.base), f"buffer dtype must be float32, got {A.dtype}"
-  threads = UOp.special(A.size, "lidx0")
+  threads = UOp.special(A.numel(), "lidx0")
   insts = [
     s_load_b64(s[0:1], s[0:1], soffset=NULL),
     s_waitcnt_lgkmcnt(sdst=NULL, simm16=0),
@@ -27,13 +27,13 @@ def custom_add_one(A:UOp) -> UOp:
     global_store_b32(addr=v[0], data=v[1], saddr=s[0:1]),
     s_endpgm(),
   ]
-  sink = UOp.sink(A.base, threads, arg=KernelInfo(f"custom_add_one_{A.size}", estimates=Estimates(ops=A.size, mem=A.size*4*2)))
+  sink = UOp.sink(A.base, threads, arg=KernelInfo(f"custom_add_one_{A.numel()}", estimates=Estimates(ops=A.numel(), mem=A.numel()*4*2)))
   return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))

 def custom_add_var(A:UOp, B:UOp) -> UOp:
   A,B = A.flatten(), B.flatten()
   assert A.dtype.base == dtypes.uint32, f"buffer dtype must be uint32, got {A.dtype}"
-  threads = UOp.special(A.size, "lidx0")
+  threads = UOp.special(A.numel(), "lidx0")
   var = UOp.variable("var", 0, 10)
   insts = [
     s_load_b128(s[4:7], s[0:1]),
@@ -46,7 +46,7 @@ def custom_add_var(A:UOp, B:UOp) -> UOp:
     global_store_b32(addr=v[0], data=v[1], saddr=s[4:5]),
     s_endpgm(),
   ]
-  sink = UOp.sink(A.base, B.base, var, threads, arg=KernelInfo(f"custom_add_var_{A.size}"))
+  sink = UOp.sink(A.base, B.base, var, threads, arg=KernelInfo(f"custom_add_var_{A.numel()}"))
   return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))

 def custom_wave_sync(A:UOp, arch:str) -> UOp:
@@ -132,7 +132,7 @@ def custom_handwritten(A:UOp, arch:str) -> UOp:

 def custom_data_deps(A:UOp, arch:str) -> UOp:
   A = A.flatten()
-  threads = UOp.special(A.size, "lidx0")
+  threads = UOp.special(A.numel(), "lidx0")
   k = Kernel(arch)
   k.emit(s_load_b64(s[0:1], s[0:1], soffset=NULL))
   k.emit(s_waitcnt_lgkmcnt(sdst=NULL, simm16=0))

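A note on the estimate arithmetic in custom_add_one above: with the flat element count coming from A.numel(), the kernel claims one op per element and A.numel()*4*2 bytes of traffic (4 bytes per float32 element, one load plus one store). A hedged worked example for a 1024-element buffer:

    ops = 1024          # one add per element
    mem = 1024 * 4 * 2  # 8192 bytes: each float32 element is loaded once and stored once
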
View File

@@ -6,8 +6,8 @@ from tinygrad.uop.ops import KernelInfo, AxisType

 # **** kernels ****
 def custom_arange_kernel(C:UOp) -> UOp:
-  i = UOp.range(C.size, 0)
-  return C[i].store(i.cast(C.dtype.base)).end(i).sink(arg=KernelInfo(name=f"custom_arange_{C.size}"))
+  i = UOp.range(C.shape[0], 0)
+  return C[i].store(i.cast(C.dtype.base)).end(i).sink(arg=KernelInfo(name=f"custom_arange_{C.shape[0]}"))

 def custom_eye_kernel(C:UOp) -> UOp:
   i = UOp.range(C.shape[0], 0)
@@ -16,22 +16,22 @@ def custom_eye_kernel(C:UOp) -> UOp:

 def custom_add_one_kernel(B:UOp, A:UOp) -> UOp:
   A,B = A.flatten(), B.flatten()
-  assert B.size == A.size
-  i = UOp.range(A.size, 0)
-  return B[i].store(A[i] + 1).end(i).sink(arg=KernelInfo(name=f"add_one_{A.size}"))
+  assert B.numel() == A.numel()
+  i = UOp.range(A.numel(), 0)
+  return B[i].store(A[i] + 1).end(i).sink(arg=KernelInfo(name=f"add_one_{A.numel()}"))

 def custom_elementwise_add_kernel(C:UOp, A:UOp, B:UOp) -> UOp:
   C,A,B = C.flatten(), A.flatten(), B.flatten()
-  i = UOp.range(C.size, 0)
-  return C[i].store(A[i]+B[i]).end(i).sink(arg=KernelInfo(name=f"custom_add_kernel_{C.size}")).simplify()
+  i = UOp.range(C.numel(), 0)
+  return C[i].store(A[i]+B[i]).end(i).sink(arg=KernelInfo(name=f"custom_add_kernel_{C.numel()}")).simplify()

 def custom_elementwise_addmul_kernel(C:UOp, D:UOp, A:UOp, B:UOp) -> UOp:
   C,D,A,B = C.flatten(), D.flatten(), A.flatten(), B.flatten()
-  assert C.size == D.size
-  i = UOp.range(C.size, 0)
+  assert C.numel() == D.numel()
+  i = UOp.range(C.numel(), 0)
   store_c = C[i].store(A[i]+B[i])
   store_d = D[i].store(A[i]*B[i])
-  return UOp.group(store_c, store_d).end(i).sink(arg=KernelInfo(name=f"custom_addmul_kernel_{C.size}")).simplify()
+  return UOp.group(store_c, store_d).end(i).sink(arg=KernelInfo(name=f"custom_addmul_kernel_{C.numel()}")).simplify()

 def custom_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
   assert A.shape[1] == B.shape[0]
@@ -291,10 +291,10 @@ class TestCustomKernel(unittest.TestCase):

     def custom_add_with_tmp(o1:UOp, o2:UOp, A:UOp, B:UOp) -> UOp:
       o1,o2,A,B = o1.flatten(), o2.flatten(), A.flatten(), B.flatten()
-      i = UOp.range(o1.size, 0)
+      i = UOp.range(o1.numel(), 0)
       store_o1 = o1[i].store(A[i]+B[i])
       store_o2 = o2[i].store(A[i]+B[i]+2)
-      return UOp.group(store_o1, store_o2).end(i).sink(arg=KernelInfo(name=f"add_with_tmp_{o1.size}")).simplify()
+      return UOp.group(store_o1, store_o2).end(i).sink(arg=KernelInfo(name=f"add_with_tmp_{o1.numel()}")).simplify()

     from tinygrad import function
     @function(precompile=True)

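Every builder in this file follows the same recipe: flatten the buffer UOps so numel() gives the flat element count, open a RANGE over it, store per element, close the range with end, and sink with a KernelInfo name. A minimal sketch of a new kernel in the same style (custom_scale_kernel and the factor 2.0 are illustrative, not part of this diff):

    def custom_scale_kernel(B:UOp, A:UOp) -> UOp:
      A,B = A.flatten(), B.flatten()  # flatten so numel() is the flat element count
      i = UOp.range(A.numel(), 0)     # loop axis over every element
      # store A[i]*2 into B[i], close the range, name the kernel after its size
      return B[i].store(A[i] * 2.0).end(i).sink(arg=KernelInfo(name=f"custom_scale_{A.numel()}"))
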
View File

@@ -146,7 +146,7 @@ class TestSchedule(unittest.TestCase):
   def test_create_schedule_handles_multi_kernel_after_and_after_deps(self):
     def named_copy(name:str):
       def fxn(out:UOp, src:UOp) -> UOp:
-        i = UOp.range(src.size, 0)
+        i = UOp.range(src.shape[0], 0)
         return out[i].store(src[i]).end(i).sink(arg=KernelInfo(name=name))
       return fxn

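named_copy is a closure returning a kernel builder parameterized by its kernel name; the builder now sizes its range off src.shape[0] rather than src.size. A hedged usage sketch (buffer creation elided, out and src assumed to be 1-D buffer UOps):

    copy_a = named_copy("copy_a")  # builder whose emitted kernel is named "copy_a"
    sink = copy_a(out, src)        # RANGE over src.shape[0], one store per element
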
View File

@@ -418,8 +418,8 @@ class TestFunctionTuple(unittest.TestCase):

   def test_custom_kernel_save_unused_output(self):
     def my_kernel(C:UOp, D:UOp, A:UOp) -> UOp:
-      i = UOp.range(A.size, 0)
-      j = UOp.range(D.size, 1)
+      i = UOp.range(A.shape[0], 0)
+      j = UOp.range(D.shape[0], 1)
       store_c = C[i].store(A[i] * 2.0).end(i)
       store_d = D[j].store(A[j]).end(j)
       return UOp.group(store_c, store_d).sink(arg=KernelInfo(name="my_kernel"))
@@ -444,7 +444,7 @@ class TestFunctionTuple(unittest.TestCase):

   def test_custom_kernel_both_outputs_used(self):
     def my_kernel(C:UOp, D:UOp, A:UOp) -> UOp:
-      i = UOp.range(A.size, 0)
+      i = UOp.range(A.shape[0], 0)
       store_c = C[i].store(A[i] * 2.0)
       store_d = D[i].store(A[i] * 3.0)
       return UOp.group(store_c, store_d).end(i).sink(arg=KernelInfo(name="my_kernel"))
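
The two tests above also show the two ways multi-output kernels close their ranges: when both stores share one axis, group them first and end the range once; when each store has its own axis, end each range before grouping. A compact hedged sketch of the contrast, using only calls from this diff:

    # shared axis i: group first, end once
    UOp.group(store_c, store_d).end(i).sink(arg=KernelInfo(name="my_kernel"))
    # independent axes i and j: end each store's own range, then group
    UOp.group(store_c.end(i), store_d.end(j)).sink(arg=KernelInfo(name="my_kernel"))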