diff --git a/test/test_custom_kernel.py b/test/test_custom_kernel.py
index 40a6d2a71c..7defac9c20 100644
--- a/test/test_custom_kernel.py
+++ b/test/test_custom_kernel.py
@@ -193,6 +193,16 @@ class TestCustomKernel(unittest.TestCase):
     err = (tst - (a@b)).square().max()
     self.assertLess(err.item(), 1e-6)
 
+  def test_gemm_multi(self):
+    devs = ("CPU:0", "CPU:1")
+    N = 16
+    a = Tensor.randn(N, N).shard_(devs, axis=0)
+    b = Tensor.randn(N, N).to(devs)
+    c = Tensor(Tensor.empty(N//2, N, device=devs).uop.multi(0), device=devs)
+    tst = Tensor.custom_kernel(c, a, b, fxn=custom_gemm)[0]
+    err = (tst - (a@b)).square().max()
+    self.assertLess(err.item(), 1e-6)
+
   def test_gemm_backward_custom(self): self.test_gemm_backward(True)
   # NOTE: grad_fxn doesn't work with pyrender
   def test_gemm_backward(self, custom_backward_gemm=False):
diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py
index 145b3850be..23a81ea6c7 100644
--- a/tinygrad/schedule/multi.py
+++ b/tinygrad/schedule/multi.py
@@ -217,8 +217,8 @@ multi_pm = PatternMatcher([
   (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), src=(UPat(Ops.MULTI, name="multi"), ), name="root"),
    passthrough_multi),
   # multi supports custom kernels with CUSTOM_KERNEL + AFTER
-  (UPat(Ops.CUSTOM_KERNEL, src=UPat(Ops.MULTI), name="ck"),
-   lambda ck: ck.replace(src=tuple(m.src[0] for m in ck.src))),
+  (UPat(Ops.CUSTOM_KERNEL, src=UPat((Ops.MULTI, Ops.CONTIGUOUS)), name="ck"),
+   lambda ck: ck.replace(src=tuple(m.src[0] if m.op is Ops.MULTI else m for m in ck.src))),
   (UPat(Ops.AFTER, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.CUSTOM_KERNEL)), name="a"),
    lambda multi,a: a.replace(src=(multi.src[0],)+a.src[1:]).multi(multi.axis))
 ])+replace_allreduce
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 447eba179a..d0c531bb95 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -235,7 +235,7 @@ class Tensor(OpMixin):
 
     This API is alpha and may change.
     """
-    return [Tensor(u) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
+    return [Tensor(u, device=u.device) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
 
   def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[str, int]]:
     """
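
Note: taken together, these three hunks let Tensor.custom_kernel accept sharded (MULTI) tensors: the multi_pm pattern now also matches CONTIGUOUS sources of a CUSTOM_KERNEL (unwrapping only the MULTI srcs and passing other srcs through unchanged), and the returned Tensors keep the device of each result UOp instead of falling back to the default device. A minimal usage sketch follows, mirroring the new test_gemm_multi; it assumes a user-supplied custom_gemm kernel function like the one test/test_custom_kernel.py defines, which is not part of this diff:

    # Sketch of multi-device custom_kernel usage, assuming `custom_gemm` is
    # a user-provided fxn as in test/test_custom_kernel.py (hypothetical here).
    from tinygrad import Tensor

    devs = ("CPU:0", "CPU:1")
    N = 16
    a = Tensor.randn(N, N).shard_(devs, axis=0)  # input sharded along axis 0
    b = Tensor.randn(N, N).to(devs)              # second operand replicated on both devices
    # per-shard output buffer: each device holds an (N//2, N) slice, stitched back via .multi(0)
    c = Tensor(Tensor.empty(N//2, N, device=devs).uop.multi(0), device=devs)
    out = Tensor.custom_kernel(c, a, b, fxn=custom_gemm)[0]
    assert (out - (a@b)).square().max().item() < 1e-6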