test custom eye function (#13134)

this version is also faster with NOOPT
This commit is contained in:
chenyu
2025-11-06 14:51:55 -05:00
committed by GitHub
parent 290441dd44
commit bfb0c0391f

View File

@@ -9,6 +9,11 @@ def custom_arange_kernel(C:UOp) -> UOp:
i = UOp.range(C.size, 0)
return C[i].store(i.cast(C.dtype.base)).end(i).sink(arg=KernelInfo(name=f"custom_arange_{C.size}"))
def custom_eye_kernel(C:UOp) -> UOp:
  # Kernel that fills the 2-D buffer C with ones where the two indices match and zeros elsewhere.
  # One range per dimension of C; the second argument (0 / 1) is presumably the range's axis id.
  row = UOp.range(C.shape[0], 0)
  col = UOp.range(C.shape[1], 1)
  # The index comparison is cast to C's base dtype so it can be stored directly into C.
  on_diagonal = row.eq(col).cast(C.dtype.base)
  write = C[row, col].store(on_diagonal)
  # Close both ranges and wrap the store in a sink carrying the kernel name.
  return write.end(row, col).sink(arg=KernelInfo(name=f"custom_eye_{C.size}"))
def custom_add_one_kernel(B:UOp, A:UOp) -> UOp:
A,B = A.flatten(), B.flatten()
assert B.size == A.size
@@ -125,6 +130,12 @@ class TestCustomKernel(unittest.TestCase):
tst = tst.custom_kernel(fxn=custom_arange_kernel)[0]
self.assertTrue((ref == tst).all().item())
def test_eye(self):
  # Reference result: the built-in 1024x1024 identity, made contiguous and realized.
  expected = Tensor.eye(1024).contiguous().realize()
  # Run the custom eye kernel on an empty tensor shaped like the reference; [0] takes its first output.
  actual = Tensor.empty_like(expected).custom_kernel(fxn=custom_eye_kernel)[0]
  # Every element produced by the custom kernel must equal the reference.
  self.assertTrue((expected == actual).all().item())
def test_flip_contract(self):
a = Tensor.randn(10,4)
b = Tensor.empty_like(a)