mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
multi custom kernel: support input mixed with copy and shard (#13748)
This commit is contained in:
@@ -193,6 +193,16 @@ class TestCustomKernel(unittest.TestCase):
|
||||
err = (tst - (a@b)).square().max()
|
||||
self.assertLess(err.item(), 1e-6)
|
||||
|
||||
def test_gemm_multi(self):
|
||||
devs = ("CPU:0", "CPU:1")
|
||||
N = 16
|
||||
a = Tensor.randn(N, N).shard_(devs, axis=0)
|
||||
b = Tensor.randn(N, N).to(devs)
|
||||
c = Tensor(Tensor.empty(N//2, N, device=devs).uop.multi(0), device=devs)
|
||||
tst = Tensor.custom_kernel(c, a, b, fxn=custom_gemm)[0]
|
||||
err = (tst - (a@b)).square().max()
|
||||
self.assertLess(err.item(), 1e-6)
|
||||
|
||||
def test_gemm_backward_custom(self): self.test_gemm_backward(True)
|
||||
# NOTE: grad_fxn doesn't work with pyrender
|
||||
def test_gemm_backward(self, custom_backward_gemm=False):
|
||||
|
||||
@@ -217,8 +217,8 @@ multi_pm = PatternMatcher([
|
||||
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
|
||||
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
|
||||
# multi supports custom kernels with CUSTOM_KERNEL + AFTER
|
||||
(UPat(Ops.CUSTOM_KERNEL, src=UPat(Ops.MULTI), name="ck"),
|
||||
lambda ck: ck.replace(src=tuple(m.src[0] for m in ck.src))),
|
||||
(UPat(Ops.CUSTOM_KERNEL, src=UPat((Ops.MULTI, Ops.CONTIGUOUS)), name="ck"),
|
||||
lambda ck: ck.replace(src=tuple(m.src[0] if m.op is Ops.MULTI else m for m in ck.src))),
|
||||
(UPat(Ops.AFTER, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.CUSTOM_KERNEL)), name="a"),
|
||||
lambda multi,a: a.replace(src=(multi.src[0],)+a.src[1:]).multi(multi.axis))
|
||||
])+replace_allreduce
|
||||
|
||||
@@ -235,7 +235,7 @@ class Tensor(OpMixin):
|
||||
|
||||
This API is alpha and may change.
|
||||
"""
|
||||
return [Tensor(u) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
|
||||
return [Tensor(u, device=u.device) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
|
||||
|
||||
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[str, int]]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user