multi custom kernel: support input mixed with copy and shard (#13748)

This commit is contained in:
b1tg
2025-12-30 01:54:27 +08:00
committed by GitHub
parent 0a98fd38b3
commit 63a1bb8507
3 changed files with 13 additions and 3 deletions

View File

@@ -193,6 +193,16 @@ class TestCustomKernel(unittest.TestCase):
err = (tst - (a@b)).square().max()
self.assertLess(err.item(), 1e-6)
def test_gemm_multi(self):
devs = ("CPU:0", "CPU:1")
N = 16
a = Tensor.randn(N, N).shard_(devs, axis=0)
b = Tensor.randn(N, N).to(devs)
c = Tensor(Tensor.empty(N//2, N, device=devs).uop.multi(0), device=devs)
tst = Tensor.custom_kernel(c, a, b, fxn=custom_gemm)[0]
err = (tst - (a@b)).square().max()
self.assertLess(err.item(), 1e-6)
def test_gemm_backward_custom(self): self.test_gemm_backward(True)
# NOTE: grad_fxn doesn't work with pyrender
def test_gemm_backward(self, custom_backward_gemm=False):

View File

@@ -217,8 +217,8 @@ multi_pm = PatternMatcher([
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
# multi supports custom kernels with CUSTOM_KERNEL + AFTER
(UPat(Ops.CUSTOM_KERNEL, src=UPat(Ops.MULTI), name="ck"),
lambda ck: ck.replace(src=tuple(m.src[0] for m in ck.src))),
(UPat(Ops.CUSTOM_KERNEL, src=UPat((Ops.MULTI, Ops.CONTIGUOUS)), name="ck"),
lambda ck: ck.replace(src=tuple(m.src[0] if m.op is Ops.MULTI else m for m in ck.src))),
(UPat(Ops.AFTER, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.CUSTOM_KERNEL)), name="a"),
lambda multi,a: a.replace(src=(multi.src[0],)+a.src[1:]).multi(multi.axis))
])+replace_allreduce

View File

@@ -235,7 +235,7 @@ class Tensor(OpMixin):
This API is alpha and may change.
"""
return [Tensor(u) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
return [Tensor(u, device=u.device) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[str, int]]:
"""