move multi axis to property (#8879)

also updated tests so that axis is known prior to realize
Author: chenyu
Date: 2025-02-03 16:02:09 -05:00
Committed by: GitHub
Parent: fa90079370
Commit: 746d899dbd
3 changed files with 32 additions and 36 deletions
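Below is a minimal standalone sketch of the pattern the updated tests move to (the device names and shapes are illustrative assumptions, not taken from this commit): the shard axis is read off lazydata as a property and asserted before .numpy() realizes the tensor.

import numpy as np
from tinygrad import Tensor

devices_2 = ("CPU:0", "CPU:1")  # assumed device pair, for illustration only
t0 = Tensor.rand(26, 15, 7).shard(devices_2, axis=1)
t1 = t0.reshape((26, 105))      # merging the two right dims keeps the shard axis at 1
assert t1.lazydata.axis == 1    # axis is known prior to realize
np.testing.assert_allclose(t1.numpy().flatten(), t0.numpy().flatten())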


@@ -554,24 +554,24 @@ class TestMultiTensor(unittest.TestCase):
    t4 = t2.reshape((26, 105,))
    for t in [t0, t1, t2, t3, t4]:
-      np.testing.assert_allclose(t.numpy().flatten(), t0.numpy().flatten())
      assert t.lazydata.axis == 1
+      np.testing.assert_allclose(t.numpy().flatten(), t0.numpy().flatten())
    # test shape-one axis
    t5 = t4.reshape((26, 1, 105))
-    np.testing.assert_allclose(t.numpy().flatten(), t5.numpy().flatten())
    assert t5.lazydata.axis == 2
+    np.testing.assert_allclose(t.numpy().flatten(), t5.numpy().flatten())
    # test split and rejoin to the right and reshape to the left
    t5 = t0.reshape((2, 13, 3, 5, 7))
    t6 = t0.reshape((13, 2, 3, 7, 5))
    t7 = t0.reshape((1, 13, 2, 3, 1, 7, 5))
-    np.testing.assert_allclose(t5.numpy().flatten(), t0.numpy().flatten())
    assert t5.lazydata.axis == 2
-    np.testing.assert_allclose(t6.numpy().flatten(), t0.numpy().flatten())
    assert t6.lazydata.axis == 2
-    np.testing.assert_allclose(t7.numpy().flatten(), t0.numpy().flatten())
    assert t7.lazydata.axis == 3
+    np.testing.assert_allclose(t5.numpy().flatten(), t0.numpy().flatten())
+    np.testing.assert_allclose(t6.numpy().flatten(), t0.numpy().flatten())
+    np.testing.assert_allclose(t7.numpy().flatten(), t0.numpy().flatten())
    # test no left join
    with self.assertRaises((AssertionError, ValueError)):
@@ -580,8 +580,8 @@ class TestMultiTensor(unittest.TestCase):
  @unittest.skip("no longer supports uneven shard")
  def test_reshape_on_axis_uneven(self):
    def reshape_helper(t0, t, t_axis):
-      np.testing.assert_allclose(t0.reshape(t.shape).numpy(), t.numpy())
      assert t.lazydata.axis == t_axis
+      np.testing.assert_allclose(t0.reshape(t.shape).numpy(), t.numpy())
    t0 = Tensor.rand((4, 42, 15)).shard(devices_3, axis=1, splits=[14, 7, 21])
@@ -653,11 +653,11 @@ class TestMultiTensor(unittest.TestCase):
  def test_rand_like_from_alu(self):
    # TODO: fix this, which will also fix multi device dropout
    a = Tensor.ones(4, 4).shard(devices_2, axis=0)
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(ValueError):
      (a + a).rand_like()
    b = Tensor.empty(4, 4).shard(devices_2, axis=None)
-    with self.assertRaises(AssertionError):
+    with self.assertRaises(ValueError):
      (a + b).rand_like()

  @unittest.skip("no longer supports uneven shard")
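The last hunk also captures the error-type change that accompanies the move: rand_like on a multi tensor built from an ALU op (rather than a realized buffer) now raises ValueError instead of tripping an assert. A minimal sketch of the behavior the updated test expects, again with assumed device names:

from tinygrad import Tensor

devices_2 = ("CPU:0", "CPU:1")  # assumed device pair, for illustration only
a = Tensor.ones(4, 4).shard(devices_2, axis=0)
try:
  (a + a).rand_like()  # result of an ALU op, not a buffer: expected to raise
except ValueError:
  print("raised ValueError, as the updated test expects")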