mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-08 22:48:25 -05:00
hotfix: types and names for custom kernel test
@@ -4,27 +4,27 @@ from tinygrad.uop.ops import KernelInfo, AxisType
 
 # **** kernels ****
 
-def custom_arange_kernel(C:UOp):
+def custom_arange_kernel(C:UOp) -> UOp:
   i = UOp.range(C.size, 0)
   return C[i].store(i.cast(C.dtype.base)).end(i).sink(arg=KernelInfo(name=f"custom_arange_{C.size}"))
 
-def custom_add_one_kernel(B:UOp, A:UOp):
+def custom_add_one_kernel(B:UOp, A:UOp) -> UOp:
   assert B.size == A.size
   i = UOp.range(A.size, 0)
   return B[i].store(A[i] + 1).end(i).sink(arg=KernelInfo(name=f"add_one_{A.size}"))
 
-def custom_elementwise_add_kernel(C:UOp, A:UOp, B:UOp):
+def custom_elementwise_add_kernel(C:UOp, A:UOp, B:UOp) -> UOp:
   i = UOp.range(C.size, 0)
   return C[i].store(A[i]+B[i]).end(i).sink(arg=KernelInfo(name=f"custom_add_kernel_{C.size}")).simplify()
 
-def custom_elementwise_addmul_kernel(C:UOp, D:UOp, A:UOp, B:UOp):
+def custom_elementwise_addmul_kernel(C:UOp, D:UOp, A:UOp, B:UOp) -> UOp:
   assert C.size == D.size
   i = UOp.range(C.size, 0)
   store_c = C[i].store(A[i]+B[i])
   store_d = D[i].store(A[i]*B[i])
   return UOp.group(store_c, store_d).end(i).sink(arg=KernelInfo(name=f"custom_addmul_kernel_{C.size}")).simplify()
 
-def custom_gemm(C:UOp, A:UOp, B:UOp):
+def custom_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
   assert A.shape[1] == B.shape[0]
   i, j, k = UOp.range(C.shape[0], 0), UOp.range(C.shape[1], 1), UOp.range(A.shape[1], 2, axis_type=AxisType.REDUCE)
   C = C[i, j].set(0.0)
@@ -34,14 +34,14 @@ def custom_gemm(C:UOp, A:UOp, B:UOp):
 
 # **** backward callbacks ****
 
-def backward_gemm(gradient:UOp, k:UOp) -> tuple[UOp, UOp]:
-  out, a, b = k.src
+def backward_gemm(gradient:UOp, kernel:UOp) -> tuple[UOp, UOp]:
+  out, a, b = kernel.src
   grad_a = (Tensor(gradient) @ Tensor(b).T).uop
   grad_b = (Tensor(a).T @ Tensor(gradient)).uop
   return (None, grad_a, grad_b)
 
-def backward_gemm_custom(gradient:UOp, k:UOp) -> tuple[UOp, UOp]:
-  out, a, b = k.src
+def backward_gemm_custom(gradient:UOp, kernel:UOp) -> tuple[UOp, UOp]:
+  out, a, b = kernel.src
   grad_a = Tensor.empty_like(Tensor(a)).custom_kernel(Tensor(gradient), Tensor(b).T, fxn=custom_gemm)[0].uop
   grad_b = Tensor.empty_like(Tensor(b)).custom_kernel(Tensor(a).T, Tensor(gradient), fxn=custom_gemm)[0].uop
   return (None, grad_a, grad_b)
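
For context, a minimal sketch of how one of these kernel builders might be driven from the Tensor side. This assumes Tensor.custom_kernel dispatches fxn over the output buffer (the receiver) and the input buffers and returns the output tensors, which is the same calling pattern backward_gemm_custom uses above:

    from tinygrad import Tensor
    # custom_elementwise_add_kernel is the builder defined in the diff above
    a, b = Tensor.ones(16), Tensor.ones(16)
    out = Tensor.empty(16).custom_kernel(a, b, fxn=custom_elementwise_add_kernel)[0]
    # out should realize to a + b, i.e. all 2.0s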
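
As a sanity check on the backward_gemm formulas above (standard matmul calculus, independent of this commit), the shapes line up as follows:

    from tinygrad import Tensor
    m, k, n = 2, 3, 4
    A, B = Tensor.ones(m, k), Tensor.ones(k, n)
    dC = Tensor.ones(m, n)   # upstream gradient of C = A @ B
    grad_A = dC @ B.T        # (m,n) @ (n,k) -> (m,k), matches A
    grad_B = A.T @ dC        # (k,m) @ (m,n) -> (k,n), matches B
    assert grad_A.shape == A.shape and grad_B.shape == B.shape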