mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
add CONTRACT support to UOp programs (#13043)
* add contract support
* use contract
* 342 tflops
This commit is contained in:
@@ -87,6 +87,7 @@ def custom_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
|
||||
BN_Bs_stride = (BLOCK_N + 0)
|
||||
INNER_SLICE = 8
|
||||
As = UOp.placeholder((BLOCK_K//INNER_SLICE, BM_As_stride, INNER_SLICE), dtypes.half, slot=slot, addrspace=AddrSpace.LOCAL)
|
||||
INNER_SLICE = 1
|
||||
Bs = UOp.placeholder((BLOCK_K//INNER_SLICE, BN_Bs_stride, INNER_SLICE), dtypes.half, slot=slot+1, addrspace=AddrSpace.LOCAL)
|
||||
As = As.permute((0,2,1)).reshape((BLOCK_K, BM_As_stride)).shrink_to((BLOCK_K, BLOCK_M))
|
||||
Bs = Bs.permute((0,2,1)).reshape((BLOCK_K, BN_Bs_stride)).shrink_to((BLOCK_K, BLOCK_N))
|
||||
@@ -116,18 +117,20 @@ def custom_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
|
||||
Ar = UOp.placeholder((BLOCK_M//TC_M//WARPGROUP_SIZE,), dtypes.half.vec(8), slot=1, addrspace=AddrSpace.REG)
|
||||
Br = UOp.placeholder((BLOCK_N//TC_N,), dtypes.half.vec(8), slot=2, addrspace=AddrSpace.REG)
|
||||
|
||||
M_load_loop = UOp.range(BLOCK_M//TC_M//WARPGROUP_SIZE, rng+1)
|
||||
M_load_loop = UOp.range(BLOCK_M//TC_M//WARPGROUP_SIZE, rng+10)
|
||||
Asl = Asl.reshape((BLOCK_K//TC_K, TC_K, BLOCK_M//TC_M//WARPGROUP_SIZE, WARPGROUP_SIZE, TC_M))
|
||||
A_in = UOp.vectorize(*[Asl[K_inner_loop, (warp//16)*8+i, M_load_loop, warpgroup, warp%16] for i in range(8)])
|
||||
load_rng = UOp.range(8, rng+11, axis_type=AxisType.UPCAST)
|
||||
A_in = Asl[K_inner_loop, (warp//16)*8+load_rng, M_load_loop, warpgroup, warp%16].contract(load_rng)
|
||||
Ar = Ar[M_load_loop].set(A_in, end=M_load_loop)
|
||||
|
||||
N_load_loop = UOp.range(BLOCK_N//TC_N, rng+2)
|
||||
N_load_loop = UOp.range(BLOCK_N//TC_N, rng+20)
|
||||
Bsl = Bsl.reshape((BLOCK_K//TC_K, TC_K, BLOCK_N//TC_N, TC_N))
|
||||
B_in = UOp.vectorize(*[Bsl[K_inner_loop, (warp//16)*8+i, N_load_loop, warp%16] for i in range(8)])
|
||||
load_rng = UOp.range(8, rng+21, axis_type=AxisType.UPCAST)
|
||||
B_in = Bsl[K_inner_loop, (warp//16)*8+load_rng, N_load_loop, warp%16].contract(load_rng)
|
||||
Br = Br[N_load_loop].set(B_in, end=N_load_loop)
|
||||
|
||||
M_inner_loop = UOp.range(BLOCK_M//TC_M//WARPGROUP_SIZE, rng+3)
|
||||
N_inner_loop = UOp.range(BLOCK_N//TC_N, rng+4)
|
||||
M_inner_loop = UOp.range(BLOCK_M//TC_M//WARPGROUP_SIZE, rng+30)
|
||||
N_inner_loop = UOp.range(BLOCK_N//TC_N, rng+31)
|
||||
|
||||
# load values
|
||||
acc_after = acc.after(*afters, M_inner_loop, N_inner_loop, K_inner_loop)
|
||||
|
||||
@@ -37,6 +37,14 @@ def custom_sum(B:UOp, A:UOp) -> UOp:
|
||||
i = UOp.range(A.shape[0], 0, axis_type=AxisType.REDUCE)
|
||||
return B[0].store(A[i].reduce(i, arg=Ops.ADD)).sink(arg=KernelInfo(name=f"custom_sum_{A.shape[0]}", opts_to_apply=()))
|
||||
|
||||
def flip_contract_kernel(dest:UOp, src:UOp):
  """Build a kernel that copies `src` into `dest` in groups of 4, reversing each group.

  For every group index g, the four values src[g*4+0..3] are gathered into one
  vector via `contract` over an UPCAST range, and dest[g*4+k] is stored from
  element 3-k of that vector (via `gep`; presumably a lane select -- confirm).

  Asserts that dest.size is a multiple of 4.
  """
  assert dest.size%4 == 0
  # UPCAST range of width 4: this is the axis that gets contracted into a vector
  lane = UOp.range(4, 1, AxisType.UPCAST)
  # ordinary range over the groups of 4
  group = UOp.range(dest.size//4, 0)
  base = group*4
  packed = src[base+lane].contract(lane)
  # one store per output slot, reading the vector back to front
  stores = [dest[base+k].store(packed.gep(3-k)) for k in range(4)]
  info = KernelInfo(name=f"flip_contract_{dest.size}", opts_to_apply=())
  return UOp.group(*stores).end(group).sink(arg=info)
|
||||
|
||||
# **** backward callbacks ****
|
||||
|
||||
def backward_gemm(gradient:UOp, kernel:UOp) -> tuple[UOp, UOp]:
|
||||
@@ -82,6 +90,12 @@ class TestCustomKernel(unittest.TestCase):
|
||||
tst = tst.custom_kernel(fxn=custom_arange_kernel)[0]
|
||||
self.assertTrue((ref == tst).all().item())
|
||||
|
||||
def test_flip_contract(self):
  # run flip_contract_kernel on a (10,4) tensor and compare against flip along axis 1
  src = Tensor.randn(10,4)
  out = Tensor.empty_like(src)
  out = out.custom_kernel(src, fxn=flip_contract_kernel)[0]
  matches = (src.flip(1) == out).all().item()
  self.assertTrue(matches)
|
||||
|
||||
def test_noncontig(self):
|
||||
a = Tensor.ones(16, 16).contiguous()
|
||||
tst = Tensor.empty_like(a)
|
||||
|
||||
@@ -375,6 +375,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs)
|
||||
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
|
||||
def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)
|
||||
def contract(self, *rngs:UOp):
  """Wrap this UOp in an Ops.CONTRACT over the given ranges.

  The result dtype is self.dtype vectorized by the product of (vmax+1) over
  all ranges, and the CONTRACT arg holds one (rng.arg[0], rng.vmax+1) pair per
  range (arg[0] is presumably the range's axis id -- confirm).

  Asserts that the last entry of every range's arg is AxisType.UPCAST.
  """
  assert all(r.arg[-1] == AxisType.UPCAST for r in rngs), "all contract ranges must be upcast"
  # (first arg entry, extent) for each contracted range
  axes = tuple((r.arg[0], r.vmax+1) for r in rngs)
  vec_dtype = self.dtype.vec(prod([extent for _, extent in axes]))
  return UOp(Ops.CONTRACT, dtype=vec_dtype, src=(self,), arg=axes)
|
||||
def alu(self, op, *src:UOp, **kwargs):
|
||||
out_dtype = (self, *src)[-1].dtype
|
||||
if op in {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
|
||||
|
||||
Reference in New Issue
Block a user