From 06ef8a26b79dbca093e4f3153a47bee2dc4ab102 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 19 Feb 2026 10:45:40 -0500 Subject: [PATCH] add a test case that triggers CALL passthrough_multi (#14887) --- test/unit/test_call.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/unit/test_call.py b/test/unit/test_call.py index f2d9434183..be342d2239 100644 --- a/test/unit/test_call.py +++ b/test/unit/test_call.py @@ -92,5 +92,13 @@ class TestCall(unittest.TestCase): np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5) np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5) + def test_call_plus_sharded(self): + devs = ("CPU:0", "CPU:1") + a = Tensor.ones(10, 10).shard(devs, axis=0) + b = Tensor.ones(10, 10).shard(devs, axis=0) + Tensor.realize(a, b) + c = Tensor.call(a, b, fxn=a.as_param(0) + b.as_param(1)) + np.testing.assert_equal(c.numpy(), 2 * np.ones((10, 10))) + if __name__ == '__main__': unittest.main()