mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
test_full_like_shrink_on_shard_axis (#14870)
* test_full_like_shrink_on_shard_axis add a test case that triggers non-copy branch in mstack_early_shrink * 0
This commit is contained in:
@@ -809,6 +809,14 @@ class TestMultiTensor(unittest.TestCase):
|
||||
t2.realize()
|
||||
def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0)
|
||||
|
||||
def test_full_like_shrink_on_shard_axis(self):
|
||||
t = Tensor.ones(16, 16, dtype=dtypes.int).shard(devices_2, axis=0)
|
||||
out = Tensor.full_like(t, 2)[:, :8]
|
||||
sched = out.schedule()
|
||||
self.assertEqual(len(sched), 2) # TODO: 0. fix mstack_early_shrink
|
||||
run_schedule(sched)
|
||||
self.assertEqual(out.tolist(), [[2]*8]*16)
|
||||
|
||||
def test_dropout_on_shard(self):
|
||||
with Tensor.train():
|
||||
X = Tensor.ones(256).to(devices_2)
|
||||
|
||||
Reference in New Issue
Block a user