mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
multi custom kernel support (#13716)
* multi custom kernel support
* custom kernel xform
* works
* no SPEC=2 on ck
* panic
* touchups
@@ -1,5 +1,5 @@
 import unittest
-from tinygrad import Tensor, UOp, Context
+from tinygrad import Tensor, UOp
 from tinygrad.dtype import AddrSpace
 from tinygrad.uop.ops import KernelInfo, AxisType
 
@@ -117,6 +117,17 @@ class TestCustomKernel(unittest.TestCase):
     out = c.flatten().tolist()
     assert all(x == 2 for x in out), "all 2"
 
+  def test_simple_sharded(self):
+    devs = ("CPU:0", "CPU:1")
+
+    a = Tensor.ones(16, 16).contiguous().shard(devs, axis=0)
+    b = Tensor.ones(16, 16).contiguous().shard(devs, axis=0)
+    # ugly construction to get a sharded empty tensor
+    c = Tensor(Tensor.empty(8, 16, device=devs).uop.multi(0), device=devs)
+    c = Tensor.custom_kernel(c,a,b, fxn=custom_elementwise_add_kernel)[0]
+    out = c.flatten().tolist()
+    assert all(x == 2 for x in out), "all 2"
+
   def test_multioutput(self):
     a = Tensor.full((16, 16), 3.).contiguous()
     b = Tensor.full((16, 16), 3.).contiguous()
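Aside, for readers who only see this hunk: custom_elementwise_add_kernel is defined earlier in the test file and is not shown here. As a rough semantic model only (plain numpy, not the tinygrad API; the real kernel is built at the UOp level), the sharded test above amounts to something like:

import numpy as np

# Stand-in for custom_elementwise_add_kernel; the real one is a UOp program.
def elementwise_add_kernel(out, a, b):
  for i in range(out.shape[0]):
    out[i] = a[i] + b[i]

n_devs = 2                        # models devs = ("CPU:0", "CPU:1")
a, b = np.ones((16, 16)), np.ones((16, 16))
c = np.empty((16, 16))            # models the sharded empty output tensor
for d in range(n_devs):           # one kernel launch per shard along axis 0
  sl = slice(d * 8, (d + 1) * 8)  # each "device" owns an 8x16 slice
  elementwise_add_kernel(c[sl], a[sl], b[sl])  # numpy slices are views, so writes land in c
assert all(x == 2 for x in c.flatten().tolist()), "all 2"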
@@ -184,7 +195,6 @@ class TestCustomKernel(unittest.TestCase):
 
   def test_gemm_backward_custom(self): self.test_gemm_backward(True)
   # NOTE: grad_fxn doesn't work with pyrender
-  @Context(SPEC=1)
   def test_gemm_backward(self, custom_backward_gemm=False):
     N = 4
     a_rand = Tensor.randn(N, 8)
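For reference, the removed decorator comes from tinygrad's Context helper (also dropped from the imports in the first hunk), which overrides ContextVars such as SPEC for a scope; it works both as a decorator and as a context manager. A minimal sketch of the context-manager form, assuming the scope-override behavior of tinygrad's Context:

from tinygrad import Tensor, Context

# Run one op with the SPEC ContextVar set to 1 for this scope only;
# the with-block is equivalent to decorating a function with @Context(SPEC=1).
with Context(SPEC=1):
  out = (Tensor.ones(4, 4) + Tensor.ones(4, 4)).tolist()
assert all(x == 2 for x in sum(out, []))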