mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
multi custom kernel: support input mixed with copy and shard (#13748)
This commit is contained in:
@@ -193,6 +193,16 @@ class TestCustomKernel(unittest.TestCase):
|
||||
err = (tst - (a@b)).square().max()
|
||||
self.assertLess(err.item(), 1e-6)
|
||||
|
||||
def test_gemm_multi(self):
  """Run custom_gemm across two CPU devices and compare against `a @ b`.

  One input is sharded along axis 0 and the other is moved onto the same
  device tuple with `.to`, so the kernel receives both kinds of
  multi-device input in one call.
  """
  devices = ("CPU:0", "CPU:1")
  size = 16
  # Keep the randn calls in this order so the random draws are unchanged.
  lhs = Tensor.randn(size, size).shard_(devices, axis=0)
  rhs = Tensor.randn(size, size).to(devices)
  # Output placeholder with size//2 rows, wrapped via `.uop.multi(0)`.
  # NOTE(review): presumably size//2 is the per-device slice of an axis-0
  # shard over the two devices -- confirm against the multi UOp semantics.
  out_uop = Tensor.empty(size // 2, size, device=devices).uop.multi(0)
  out = Tensor(out_uop, device=devices)
  result = Tensor.custom_kernel(out, lhs, rhs, fxn=custom_gemm)[0]
  # Largest squared element-wise difference from the reference product.
  reference = lhs @ rhs
  max_sq_err = (result - reference).square().max()
  self.assertLess(max_sq_err.item(), 1e-6)
|
||||
|
||||
def test_gemm_backward_custom(self): self.test_gemm_backward(True)
|
||||
# NOTE: grad_fxn doesn't work with pyrender
|
||||
def test_gemm_backward(self, custom_backward_gemm=False):
|
||||
|
||||
Reference in New Issue
Block a user