mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
slicing + allclose
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, UOp, Context
|
||||
from tinygrad.dtype import AddrSpace
|
||||
from tinygrad.uop.ops import KernelInfo, AxisType
|
||||
|
||||
# **** kernels ****
|
||||
@@ -9,15 +10,18 @@ def custom_arange_kernel(C:UOp) -> UOp:
|
||||
return C[i].store(i.cast(C.dtype.base)).end(i).sink(arg=KernelInfo(name=f"custom_arange_{C.size}"))
|
||||
|
||||
def custom_add_one_kernel(B:UOp, A:UOp) -> UOp:
|
||||
A,B = A.flatten(), B.flatten()
|
||||
assert B.size == A.size
|
||||
i = UOp.range(A.size, 0)
|
||||
return B[i].store(A[i] + 1).end(i).sink(arg=KernelInfo(name=f"add_one_{A.size}"))
|
||||
|
||||
def custom_elementwise_add_kernel(C:UOp, A:UOp, B:UOp) -> UOp:
|
||||
C,A,B = C.flatten(), A.flatten(), B.flatten()
|
||||
i = UOp.range(C.size, 0)
|
||||
return C[i].store(A[i]+B[i]).end(i).sink(arg=KernelInfo(name=f"custom_add_kernel_{C.size}")).simplify()
|
||||
|
||||
def custom_elementwise_addmul_kernel(C:UOp, D:UOp, A:UOp, B:UOp) -> UOp:
|
||||
C,D,A,B = C.flatten(), D.flatten(), A.flatten(), B.flatten()
|
||||
assert C.size == D.size
|
||||
i = UOp.range(C.size, 0)
|
||||
store_c = C[i].store(A[i]+B[i])
|
||||
@@ -39,13 +43,22 @@ def custom_sum(B:UOp, A:UOp) -> UOp:
|
||||
return B.sink(arg=KernelInfo(name=f"custom_sum_{A.shape[0]}", opts_to_apply=()))
|
||||
|
||||
def flip_contract_kernel(dest:UOp, src:UOp):
|
||||
assert dest.size%4 == 0
|
||||
i = UOp.range(dest.size//4, 0)
|
||||
j = UOp.range(4, 1, AxisType.UPCAST)
|
||||
vec = src[i*4+j].contract(j)
|
||||
store = UOp.group(*[dest[i*4+k].store(vec.gep(3-k)) for k in range(4)])
|
||||
i = UOp.range(dest.shape[0], 0)
|
||||
j = UOp.range(dest.shape[1], 1, AxisType.UPCAST)
|
||||
vec = src[i, j].contract(j)
|
||||
store = UOp.group(*[dest[i, k].store(vec.gep(3-k)) for k in range(4)])
|
||||
return store.end(i).sink(arg=KernelInfo(name=f"flip_contract_{dest.size}", opts_to_apply=()))
|
||||
|
||||
def slice_sum_kernel(dest:UOp, src:UOp):
|
||||
G = UOp.range(src.shape[0], 0, AxisType.GLOBAL)
|
||||
slice_src = src[G, :]
|
||||
reg = UOp.placeholder((1,), dest.dtype.base, 0, addrspace=AddrSpace.REG)
|
||||
reg = reg.after(G)[0].set(0)
|
||||
R = UOp.range(src.shape[1], 1, AxisType.REDUCE)
|
||||
reg = reg[0].set(reg[0] + slice_src[R], end=R)
|
||||
ast = dest[G].set(reg[0], end=G)
|
||||
return ast.sink(arg=KernelInfo(name=f"slice_sum_{src.shape[0]}_{src.shape[1]}", opts_to_apply=()))
|
||||
|
||||
# **** backward callbacks ****
|
||||
|
||||
def backward_gemm(gradient:UOp, kernel:UOp) -> tuple[UOp, UOp]:
|
||||
@@ -111,6 +124,12 @@ class TestCustomKernel(unittest.TestCase):
|
||||
b = Tensor.custom_kernel(tst, a, fxn=custom_sum)[0]
|
||||
self.assertEqual(b.item(), 15)
|
||||
|
||||
def test_slice_sum(self):
|
||||
A = Tensor.randn(16, 16)
|
||||
B = Tensor.empty(16)
|
||||
B = Tensor.custom_kernel(B, A, fxn=slice_sum_kernel)[0]
|
||||
self.assertTrue(B.allclose(A.sum(1)))
|
||||
|
||||
def test_gemm(self):
|
||||
N = 16
|
||||
a = Tensor.randn(N, N)
|
||||
|
||||
Reference in New Issue
Block a user