test case for call precompile multi (#15254)

This commit is contained in:
chenyu
2026-03-13 06:28:43 -04:00
committed by GitHub
parent bc16f80b50
commit 018c01508d

View File

@@ -237,5 +237,13 @@ class TestCallSchedule(unittest.TestCase):
self.assertIsInstance(out.shape[0], UOp)
np.testing.assert_allclose(out[:5].numpy(), (np.arange(16*4).reshape(16, 4)[:5] * 2 + 1).astype(np.float32))
def test_precompile_multi_sharded(self):
@function(precompile=True)
def f(x:Tensor) -> Tensor: return x + 1
devs = ("CPU:0", "CPU:1")
a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0)
out = f(a) + 2
np.testing.assert_allclose(out.numpy(), np.arange(8, dtype=np.float32).reshape(4, 2) + 3)
if __name__ == '__main__':
unittest.main()