diff --git a/test/backend/test_multitensor.py b/test/backend/test_multitensor.py index 5e9b8fdfb1..2cfc7b1986 100644 --- a/test/backend/test_multitensor.py +++ b/test/backend/test_multitensor.py @@ -809,6 +809,14 @@ class TestMultiTensor(unittest.TestCase): t2.realize() def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0) + def test_full_like_shrink_on_shard_axis(self): + t = Tensor.ones(16, 16, dtype=dtypes.int).shard(devices_2, axis=0) + out = Tensor.full_like(t, 2)[:, :8] + sched = out.schedule() + self.assertEqual(len(sched), 2) # TODO: 0. fix mstack_early_shrink + run_schedule(sched) + self.assertEqual(out.tolist(), [[2]*8]*16) + def test_dropout_on_shard(self): with Tensor.train(): X = Tensor.ones(256).to(devices_2)