mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
test case for call precompile multi (#15254)
This commit is contained in:
@@ -237,5 +237,13 @@ class TestCallSchedule(unittest.TestCase):
|
||||
self.assertIsInstance(out.shape[0], UOp)
|
||||
np.testing.assert_allclose(out[:5].numpy(), (np.arange(16*4).reshape(16, 4)[:5] * 2 + 1).astype(np.float32))
|
||||
|
||||
def test_precompile_multi_sharded(self):
|
||||
@function(precompile=True)
|
||||
def f(x:Tensor) -> Tensor: return x + 1
|
||||
devs = ("CPU:0", "CPU:1")
|
||||
a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0)
|
||||
out = f(a) + 2
|
||||
np.testing.assert_allclose(out.numpy(), np.arange(8, dtype=np.float32).reshape(4, 2) + 3)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user