Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-07 22:23:55 -05:00)
multi custom kernel: support input mixed with copy and shard (#13748)
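Here "input mixed with copy and shard" means the custom kernel's inputs may combine tensors sharded across devices with tensors copied (broadcast) to every device. A minimal sketch of the two placements, following the new test below (devs, N, c, and custom_gemm come from that test; the comments are a hedged reading of the diff, not quotes from it):

    devs = ("CPU:0", "CPU:1")
    a = Tensor.randn(N, N).shard_(devs, axis=0)  # sharded: wrapped in a MULTI uop, one slice per device
    b = Tensor.randn(N, N).to(devs)              # copied: reaches the multi rewrite as CONTIGUOUS, not MULTI
    out = Tensor.custom_kernel(c, a, b, fxn=custom_gemm)[0]  # this mix is what the commit enables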
@@ -193,6 +193,16 @@ class TestCustomKernel(unittest.TestCase):
     err = (tst - (a@b)).square().max()
     self.assertLess(err.item(), 1e-6)
 
+  def test_gemm_multi(self):
+    devs = ("CPU:0", "CPU:1")
+    N = 16
+    a = Tensor.randn(N, N).shard_(devs, axis=0)
+    b = Tensor.randn(N, N).to(devs)
+    c = Tensor(Tensor.empty(N//2, N, device=devs).uop.multi(0), device=devs)
+    tst = Tensor.custom_kernel(c, a, b, fxn=custom_gemm)[0]
+    err = (tst - (a@b)).square().max()
+    self.assertLess(err.item(), 1e-6)
+
   def test_gemm_backward_custom(self): self.test_gemm_backward(True)
   # NOTE: grad_fxn doesn't work with pyrender
   def test_gemm_backward(self, custom_backward_gemm=False):
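The test builds c, the preallocated output, directly as a MULTI uop sharded on axis 0, shards a the same way, and only copies b. A hedged reading of the per-device shapes (inferred from the test, not stated in the diff):

    # with devs = ("CPU:0", "CPU:1") and N = 16:
    #   a shard on each device: (N//2, N) = (8, 16)
    #   c shard on each device: (N//2, N) = (8, 16)
    #   b copy on each device:  (N, N)    = (16, 16)
    # so each device can compute its own rows of a @ b from local data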
@@ -217,8 +217,8 @@ multi_pm = PatternMatcher([
   (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
     src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
   # multi supports custom kernels with CUSTOM_KERNEL + AFTER
-  (UPat(Ops.CUSTOM_KERNEL, src=UPat(Ops.MULTI), name="ck"),
-   lambda ck: ck.replace(src=tuple(m.src[0] for m in ck.src))),
+  (UPat(Ops.CUSTOM_KERNEL, src=UPat((Ops.MULTI, Ops.CONTIGUOUS)), name="ck"),
+   lambda ck: ck.replace(src=tuple(m.src[0] if m.op is Ops.MULTI else m for m in ck.src))),
   (UPat(Ops.AFTER, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.CUSTOM_KERNEL)), name="a"),
    lambda multi,a: a.replace(src=(multi.src[0],)+a.src[1:]).multi(multi.axis))
 ])+replace_allreduce
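The multi_pm change is the core of the fix: the CUSTOM_KERNEL pattern now also matches CONTIGUOUS sources, and the rewrite only unwraps sources that are actually MULTI. A toy illustration of that conditional unwrap, using mock Ops/UOp classes defined here just for the example (this is not tinygrad code):

    from dataclasses import dataclass
    from enum import Enum, auto

    class Ops(Enum):
      MULTI = auto(); CONTIGUOUS = auto(); CUSTOM_KERNEL = auto()

    @dataclass(frozen=True)
    class UOp:
      op: Ops
      src: tuple = ()

    piece   = UOp(Ops.CONTIGUOUS)                     # stand-in for a per-device shard
    sharded = UOp(Ops.MULTI, (piece,))                # sharded input: wrapped in MULTI
    copied  = UOp(Ops.CONTIGUOUS)                     # copied input: no MULTI wrapper
    ck      = UOp(Ops.CUSTOM_KERNEL, (sharded, copied))

    # old rule required every src to be MULTI, so a mixed CUSTOM_KERNEL like this never matched;
    # the new rule admits CONTIGUOUS srcs and unwraps only the MULTI ones
    new_src = tuple(m.src[0] if m.op is Ops.MULTI else m for m in ck.src)
    assert new_src == (piece, copied)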
@@ -235,7 +235,7 @@ class Tensor(OpMixin):
 
     This API is alpha and may change.
     """
-    return [Tensor(u) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
+    return [Tensor(u, device=u.device) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
 
   def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[str, int]]:
     """
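The tensor.py change passes device=u.device when wrapping the returned UOps back into Tensors, presumably so an output produced on a multi-device placement keeps that placement instead of falling back to the default device. A hedged usage sketch reusing the names from the new test above (custom_kernel appears to return one Tensor per input; the test takes index 0, the output buffer c):

    outs = Tensor.custom_kernel(c, a, b, fxn=custom_gemm)
    tst = outs[0]
    # with device=u.device the wrapped Tensor should keep c's multi-device placement,
    # which is what lets (tst - (a@b)).square().max().item() run across devs in the test
    print(tst.device)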