multi custom kernel: support input mixed with copy and shard (#13748)

Author: b1tg
Date: 2025-12-30 01:54:27 +08:00
Committed by: GitHub
Parent: 0a98fd38b3
Commit: 63a1bb8507
3 changed files with 13 additions and 3 deletions


@@ -193,6 +193,16 @@ class TestCustomKernel(unittest.TestCase):
    err = (tst - (a@b)).square().max()
    self.assertLess(err.item(), 1e-6)

  def test_gemm_multi(self):
    devs = ("CPU:0", "CPU:1")
    N = 16
    a = Tensor.randn(N, N).shard_(devs, axis=0)
    b = Tensor.randn(N, N).to(devs)
    c = Tensor(Tensor.empty(N//2, N, device=devs).uop.multi(0), device=devs)
    tst = Tensor.custom_kernel(c, a, b, fxn=custom_gemm)[0]
    err = (tst - (a@b)).square().max()
    self.assertLess(err.item(), 1e-6)

  def test_gemm_backward_custom(self): self.test_gemm_backward(True)
  # NOTE: grad_fxn doesn't work with pyrender
  def test_gemm_backward(self, custom_backward_gemm=False):
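
For context, "input mixed with copy and shard" refers to the placements of the two inputs in the new test: `a` is sharded, i.e. split across the two devices along axis 0, while `b` is copied, i.e. replicated in full on both devices. A minimal sketch of those two placements, using the same tinygrad calls as the test above (this is an illustration of the setup, not part of the commit):

    from tinygrad import Tensor

    devs = ("CPU:0", "CPU:1")
    N = 16

    # shard: `a` is split along axis 0, so each device holds an N//2 x N slice
    a = Tensor.randn(N, N).shard_(devs, axis=0)

    # copy: `b` is replicated, so each device holds the full N x N matrix
    b = Tensor.randn(N, N).to(devs)

    # both multi-device tensors still report the global shape and the device tuple
    print(a.shape, a.device)
    print(b.shape, b.device)

With this change, Tensor.custom_kernel accepts both placements in one call; the output placeholder `c` in the test appears to be built the same way as the sharded input, as an N//2 x N buffer per device combined along axis 0 via .uop.multi(0).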