diff --git a/test/unit/test_call.py b/test/unit/test_call.py index 273e4385de..95c498362b 100644 --- a/test/unit/test_call.py +++ b/test/unit/test_call.py @@ -237,5 +237,13 @@ class TestCallSchedule(unittest.TestCase): self.assertIsInstance(out.shape[0], UOp) np.testing.assert_allclose(out[:5].numpy(), (np.arange(16*4).reshape(16, 4)[:5] * 2 + 1).astype(np.float32)) + def test_precompile_multi_sharded(self): + @function(precompile=True) + def f(x:Tensor) -> Tensor: return x + 1 + devs = ("CPU:0", "CPU:1") + a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0) + out = f(a) + 2 + np.testing.assert_allclose(out.numpy(), np.arange(8, dtype=np.float32).reshape(4, 2) + 3) + if __name__ == '__main__': unittest.main()