mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
move multi axis to property (#8879)
also updated tests so that axis is known prior to realize
This commit is contained in:
@@ -554,24 +554,24 @@ class TestMultiTensor(unittest.TestCase):
|
||||
t4 = t2.reshape((26, 105,))
|
||||
|
||||
for t in [t0, t1, t2, t3, t4]:
|
||||
np.testing.assert_allclose(t.numpy().flatten(), t0.numpy().flatten())
|
||||
assert t.lazydata.axis == 1
|
||||
np.testing.assert_allclose(t.numpy().flatten(), t0.numpy().flatten())
|
||||
|
||||
# test shape-one axis
|
||||
t5 = t4.reshape((26, 1, 105))
|
||||
np.testing.assert_allclose(t.numpy().flatten(), t5.numpy().flatten())
|
||||
assert t5.lazydata.axis == 2
|
||||
np.testing.assert_allclose(t.numpy().flatten(), t5.numpy().flatten())
|
||||
|
||||
# test split and rejoin to the right and reshape to the left
|
||||
t5 = t0.reshape((2, 13, 3, 5, 7))
|
||||
t6 = t0.reshape((13, 2, 3, 7, 5))
|
||||
t7 = t0.reshape((1, 13, 2, 3, 1, 7, 5))
|
||||
np.testing.assert_allclose(t5.numpy().flatten(), t0.numpy().flatten())
|
||||
assert t5.lazydata.axis == 2
|
||||
np.testing.assert_allclose(t6.numpy().flatten(), t0.numpy().flatten())
|
||||
assert t6.lazydata.axis == 2
|
||||
np.testing.assert_allclose(t7.numpy().flatten(), t0.numpy().flatten())
|
||||
assert t7.lazydata.axis == 3
|
||||
np.testing.assert_allclose(t5.numpy().flatten(), t0.numpy().flatten())
|
||||
np.testing.assert_allclose(t6.numpy().flatten(), t0.numpy().flatten())
|
||||
np.testing.assert_allclose(t7.numpy().flatten(), t0.numpy().flatten())
|
||||
|
||||
# test no left join
|
||||
with self.assertRaises((AssertionError, ValueError)):
|
||||
@@ -580,8 +580,8 @@ class TestMultiTensor(unittest.TestCase):
|
||||
@unittest.skip("no longer supports uneven shard")
|
||||
def test_reshape_on_axis_uneven(self):
|
||||
def reshape_helper(t0, t, t_axis):
|
||||
np.testing.assert_allclose(t0.reshape(t.shape).numpy(), t.numpy())
|
||||
assert t.lazydata.axis == t_axis
|
||||
np.testing.assert_allclose(t0.reshape(t.shape).numpy(), t.numpy())
|
||||
|
||||
t0 = Tensor.rand((4, 42, 15)).shard(devices_3, axis=1, splits=[14, 7, 21])
|
||||
|
||||
@@ -653,11 +653,11 @@ class TestMultiTensor(unittest.TestCase):
|
||||
def test_rand_like_from_alu(self):
|
||||
# TODO: fix this, which will also fix multi device dropout
|
||||
a = Tensor.ones(4, 4).shard(devices_2, axis=0)
|
||||
with self.assertRaises(AssertionError):
|
||||
with self.assertRaises(ValueError):
|
||||
(a + a).rand_like()
|
||||
|
||||
b = Tensor.empty(4, 4).shard(devices_2, axis=None)
|
||||
with self.assertRaises(AssertionError):
|
||||
with self.assertRaises(ValueError):
|
||||
(a + b).rand_like()
|
||||
|
||||
@unittest.skip("no longer supports uneven shard")
|
||||
|
||||
Reference in New Issue
Block a user