mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
test custom eye function (#13134)
this version is also faster with NOOPT
This commit is contained in:
@@ -9,6 +9,11 @@ def custom_arange_kernel(C:UOp) -> UOp:
|
||||
i = UOp.range(C.size, 0)
|
||||
return C[i].store(i.cast(C.dtype.base)).end(i).sink(arg=KernelInfo(name=f"custom_arange_{C.size}"))
|
||||
|
||||
def custom_eye_kernel(C:UOp) -> UOp:
|
||||
i = UOp.range(C.shape[0], 0)
|
||||
j = UOp.range(C.shape[1], 1)
|
||||
return C[i, j].store((i.eq(j)).cast(C.dtype.base)).end(i, j).sink(arg=KernelInfo(name=f"custom_eye_{C.size}"))
|
||||
|
||||
def custom_add_one_kernel(B:UOp, A:UOp) -> UOp:
|
||||
A,B = A.flatten(), B.flatten()
|
||||
assert B.size == A.size
|
||||
@@ -125,6 +130,12 @@ class TestCustomKernel(unittest.TestCase):
|
||||
tst = tst.custom_kernel(fxn=custom_arange_kernel)[0]
|
||||
self.assertTrue((ref == tst).all().item())
|
||||
|
||||
def test_eye(self):
|
||||
ref = Tensor.eye(1024).contiguous().realize()
|
||||
tst = Tensor.empty_like(ref)
|
||||
tst = tst.custom_kernel(fxn=custom_eye_kernel)[0]
|
||||
self.assertTrue((ref == tst).all().item())
|
||||
|
||||
def test_flip_contract(self):
|
||||
a = Tensor.randn(10,4)
|
||||
b = Tensor.empty_like(a)
|
||||
|
||||
Reference in New Issue
Block a user