From 8c830c5b4427744464665bfb906bfd86d86cbbd6 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 18 Feb 2026 19:23:44 -0500 Subject: [PATCH] test_full_like_shrink_on_shard_axis (#14870) * test_full_like_shrink_on_shard_axis add a test case that triggers non-copy branch in mstack_early_shrink * 0 --- test/backend/test_multitensor.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/backend/test_multitensor.py b/test/backend/test_multitensor.py index 5e9b8fdfb1..2cfc7b1986 100644 --- a/test/backend/test_multitensor.py +++ b/test/backend/test_multitensor.py @@ -809,6 +809,14 @@ class TestMultiTensor(unittest.TestCase): t2.realize() def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0) + def test_full_like_shrink_on_shard_axis(self): + t = Tensor.ones(16, 16, dtype=dtypes.int).shard(devices_2, axis=0) + out = Tensor.full_like(t, 2)[:, :8] + sched = out.schedule() + self.assertEqual(len(sched), 2) # TODO: 0. fix mstack_early_shrink + run_schedule(sched) + self.assertEqual(out.tolist(), [[2]*8]*16) + def test_dropout_on_shard(self): with Tensor.train(): X = Tensor.ones(256).to(devices_2)