test_full_like_shrink_on_shard_axis (#14870)

* test_full_like_shrink_on_shard_axis

add a test case that triggers non-copy branch in mstack_early_shrink

* 0
This commit is contained in:
chenyu
2026-02-18 19:23:44 -05:00
committed by GitHub
parent 4005e9db6d
commit 8c830c5b44

View File

@@ -809,6 +809,14 @@ class TestMultiTensor(unittest.TestCase):
t2.realize()
def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0)
def test_full_like_shrink_on_shard_axis(self):
t = Tensor.ones(16, 16, dtype=dtypes.int).shard(devices_2, axis=0)
out = Tensor.full_like(t, 2)[:, :8]
sched = out.schedule()
self.assertEqual(len(sched), 2) # TODO: 0. fix mstack_early_shrink
run_schedule(sched)
self.assertEqual(out.tolist(), [[2]*8]*16)
def test_dropout_on_shard(self):
with Tensor.train():
X = Tensor.ones(256).to(devices_2)