don't allow MLB assigns with different axes (#3483)

* allow LB <- MLB assign, but don't reuse buffer

* update test

* update test

* assign assert axes are the same
This commit is contained in:
David Hou
2024-02-29 20:04:12 -08:00
committed by GitHub
parent 35d998efa8
commit f19d8bb7b4
2 changed files with 11 additions and 0 deletions

View File

@@ -355,6 +355,15 @@ class TestMultiTensor(unittest.TestCase):
np.testing.assert_allclose(t0.numpy().flatten(), t1.numpy().flatten())
assert t1.lazydata.axis == 2
def test_mlb_assign_change_axis(self):
devices = (d0, d1)
t_none = Tensor.zeros((16, 16)).shard(devices).contiguous().realize()
t_zero = Tensor.ones((16, 16)).shard(devices, axis=0)
with self.assertRaises(AssertionError):
# don't allow assigns that change axes
t_none.assign(t_zero)
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
# shrink a multitensor on sharded axis