mirror of https://github.com/tinygrad/tinygrad.git
few more jit tests with multi tensor inputs (#14047)
@@ -541,6 +541,37 @@ class TestMultiTensor(unittest.TestCase):
     np.testing.assert_allclose(r.numpy(), np.ones(256)+np.ones(256), atol=1e-4, rtol=1e-5)
     assert len(jf.jit_cache) > 0
 
+  def test_multitensor_jit_in_list(self):
+    # test MULTI tensor inside a list container - exercises the container unpacking + MULTI unpacking
+    @TinyJit
+    def f(a, arr): return (a + arr[0]).realize()
+    for i in range(5):
+      a = Tensor.full((4,), i).contiguous().realize().shard(devices_2, 0).realize()
+      b = Tensor.ones(4).contiguous().realize().shard(devices_2, 0).realize()
+      out = f(a, [b])
+      np.testing.assert_allclose(out.numpy(), np.full(4, i) + np.ones(4), atol=1e-4, rtol=1e-5)
+
+  def test_multitensor_jit_multiple_inputs(self):
+    # test multiple MULTI tensors as inputs - each gets unpacked to component UOps
+    @TinyJit
+    def f(a, b, c): return (a + b + c).realize()
+    for i in range(5):
+      a = Tensor.full((4,), i).contiguous().realize().shard(devices_2, 0).realize()
+      b = Tensor.full((4,), i*2).contiguous().realize().shard(devices_2, 0).realize()
+      c = Tensor.ones(4).contiguous().realize().shard(devices_2, 0).realize()
+      out = f(a, b, c)
+      np.testing.assert_allclose(out.numpy(), np.full(4, i) + np.full(4, i*2) + np.ones(4), atol=1e-4, rtol=1e-5)
+
+  def test_multitensor_jit_different_sharding(self):
+    # test MULTI tensors with different sharding - one sharded on axis 0, one broadcast (axis=None)
+    @TinyJit
+    def f(a, b): return (a + b).realize()
+    for i in range(5):
+      a = Tensor.full((4, 4), i).contiguous().realize().shard(devices_2, 0).realize()
+      b = Tensor.full((4, 4), i*2).contiguous().realize().shard(devices_2, None).realize()
+      out = f(a, b)
+      np.testing.assert_allclose(out.numpy(), np.full((4, 4), i) + np.full((4, 4), i*2), atol=1e-4, rtol=1e-5)
+
   @unittest.skip("test broken")
   def test_multi_device_jit_graph(self):
     if Device[d0].graph is None or Device[d1].graph is None: raise unittest.SkipTest("only test graphs")
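A note on the sharding these tests exercise: below is a minimal standalone sketch of the two modes used in test_multitensor_jit_different_sharding. The devices_2 tuple here is an assumption standing in for the test fixture (which builds it from Device.DEFAULT); exact device names depend on your backend.

from tinygrad import Tensor

devices_2 = ("CPU:0", "CPU:1")  # assumed two-device tuple, for illustration only

# shard(devices_2, 0) splits axis 0 across the devices: each holds a (2, 4) shard
a = Tensor.full((4, 4), 3.0).contiguous().realize().shard(devices_2, 0).realize()
# shard(devices_2, None) broadcasts instead: each device holds a full (4, 4) copy
b = Tensor.full((4, 4), 6.0).contiguous().realize().shard(devices_2, None).realize()

# elementwise ops pair each shard of a with the local copy of b
print((a + b).numpy())  # every entry is 9.0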
@@ -233,7 +233,6 @@ def _prepare_jit_inputs(args, kwargs):
     it = x if isinstance(x, (tuple,list)) else x.values() if isinstance(x, dict) else []
     tensors += [t for t in it if t.__class__ is Tensor and not any(t is y for y in tensors)]
   if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
-  # TODO: this multi unpack stuff is not well tested.
   lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
   if any(lb.base.op is Ops.CONST for lb in lbs):
     raise JitError("JIT inputs cannot be const, create a buffer with .contiguous()")
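The flatten line above is what the new tests target: a sharded tensor's uop is an Ops.MULTI node whose src entries are the per-device component UOps, so every shard becomes its own JIT input. Written out as a plain loop, the unpacking amounts to the sketch below (equivalent in spirit, not the actual jit.py code; the Ops import path varies across tinygrad versions).

from tinygrad.uop.ops import Ops  # older versions: from tinygrad.ops import Ops

def unpack_multi(tensors):
  lbs = []
  for t in tensors:
    if t.uop.op is Ops.MULTI: lbs.extend(t.uop.src)  # one component UOp per device shard
    else: lbs.append(t.uop)                          # single-device tensor passes through
  return lbs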
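The JitError branch also explains the .contiguous().realize() calls sprinkled through the tests above: an unrealized Tensor.ones or Tensor.full is backed by a CONST uop rather than a buffer, so it cannot be captured as a JIT input. A hedged usage sketch:

from tinygrad import Tensor, TinyJit

@TinyJit
def f(x): return (x + 1).realize()

# Tensor.ones(4) on its own stays CONST-backed and would trip the JitError above;
# .contiguous() forces a real buffer first
x = Tensor.ones(4).contiguous().realize()
for _ in range(3): out = f(x)  # TinyJit captures on the second call and replays after
print(out.numpy())  # [2. 2. 2. 2.]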