From 018c01508d425ffa36d437b5d1d246255a14ebb0 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 13 Mar 2026 06:28:43 -0400 Subject: [PATCH] test case for call precompile multi (#15254) --- test/unit/test_call.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/unit/test_call.py b/test/unit/test_call.py index 273e4385de..95c498362b 100644 --- a/test/unit/test_call.py +++ b/test/unit/test_call.py @@ -237,5 +237,13 @@ class TestCallSchedule(unittest.TestCase): self.assertIsInstance(out.shape[0], UOp) np.testing.assert_allclose(out[:5].numpy(), (np.arange(16*4).reshape(16, 4)[:5] * 2 + 1).astype(np.float32)) + def test_precompile_multi_sharded(self): + @function(precompile=True) + def f(x:Tensor) -> Tensor: return x + 1 + devs = ("CPU:0", "CPU:1") + a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0) + out = f(a) + 2 + np.testing.assert_allclose(out.numpy(), np.arange(8, dtype=np.float32).reshape(4, 2) + 3) + if __name__ == '__main__': unittest.main()