few more jit tests with multi tensor inputs (#14047)

This commit is contained in:
chenyu
2026-01-06 22:05:22 -05:00
committed by GitHub
parent 72a3f78d19
commit 2833c5a54b
2 changed files with 31 additions and 1 deletions

View File

@@ -541,6 +541,37 @@ class TestMultiTensor(unittest.TestCase):
np.testing.assert_allclose(r.numpy(), np.ones(256)+np.ones(256), atol=1e-4, rtol=1e-5)
assert len(jf.jit_cache) > 0
def test_multitensor_jit_in_list(self):
# test MULTI tensor inside a list container - exercises the container unpacking + MULTI unpacking
@TinyJit
def f(a, arr): return (a + arr[0]).realize()
for i in range(5):
a = Tensor.full((4,), i).contiguous().realize().shard(devices_2, 0).realize()
b = Tensor.ones(4).contiguous().realize().shard(devices_2, 0).realize()
out = f(a, [b])
np.testing.assert_allclose(out.numpy(), np.full(4, i) + np.ones(4), atol=1e-4, rtol=1e-5)
def test_multitensor_jit_multiple_inputs(self):
# test multiple MULTI tensors as inputs - each gets unpacked to component UOps
@TinyJit
def f(a, b, c): return (a + b + c).realize()
for i in range(5):
a = Tensor.full((4,), i).contiguous().realize().shard(devices_2, 0).realize()
b = Tensor.full((4,), i*2).contiguous().realize().shard(devices_2, 0).realize()
c = Tensor.ones(4).contiguous().realize().shard(devices_2, 0).realize()
out = f(a, b, c)
np.testing.assert_allclose(out.numpy(), np.full(4, i) + np.full(4, i*2) + np.ones(4), atol=1e-4, rtol=1e-5)
def test_multitensor_jit_different_sharding(self):
# test MULTI tensors with different sharding - one sharded on axis 0, one broadcast (axis=None)
@TinyJit
def f(a, b): return (a + b).realize()
for i in range(5):
a = Tensor.full((4, 4), i).contiguous().realize().shard(devices_2, 0).realize()
b = Tensor.full((4, 4), i*2).contiguous().realize().shard(devices_2, None).realize()
out = f(a, b)
np.testing.assert_allclose(out.numpy(), np.full((4, 4), i) + np.full((4, 4), i*2), atol=1e-4, rtol=1e-5)
@unittest.skip("test broken")
def test_multi_device_jit_graph(self):
if Device[d0].graph is None or Device[d1].graph is None: raise unittest.SkipTest("only test graphs")

View File

@@ -233,7 +233,6 @@ def _prepare_jit_inputs(args, kwargs):
it = x if isinstance(x, (tuple,list)) else x.values() if isinstance(x, dict) else []
tensors += [t for t in it if t.__class__ is Tensor and not any(t is y for y in tensors)]
if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
# TODO: this multi unpack stuff is not well tested.
lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
if any(lb.base.op is Ops.CONST for lb in lbs):
raise JitError("JIT inputs cannot be const, create a buffer with .contiguous()")