mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Missing features from rearrange (#6184)
* fixes and tests * typo in test
This commit is contained in:
@@ -61,6 +61,14 @@ class test_rearrange_examples(unittest.TestCase):
|
||||
y = y.rearrange("b c -> c b () ()")
|
||||
assert tuple(y.shape) == (20, 10, 1, 1)
|
||||
|
||||
def test9(self):
|
||||
x = Tensor(np.arange(10 * 20 * 1 * 1).reshape([10, 20, 1, 1]))
|
||||
# squeeze - unsqueeze
|
||||
y = x.rearrange("b c 1 1 -> b c")
|
||||
assert tuple(y.shape) == (10, 20)
|
||||
y = y.rearrange("b1 c -> c b1 1 1")
|
||||
assert tuple(y.shape) == (20, 10, 1, 1)
|
||||
|
||||
def test_tensor_train_example_numpy(self):
|
||||
# kept here just for a collection, only tested for numpy
|
||||
# https://arxiv.org/pdf/1509.06569.pdf, (5)
|
||||
@@ -131,6 +139,9 @@ class test_rearrange_ops(unittest.TestCase):
|
||||
with self.assertRaises(AssertionError):
|
||||
## incorrect dimension provided for an axis that is only permuted
|
||||
y.rearrange("(a1 a2 a3) b -> b a3 a2 a1", a1=2, a2=2, b=2)
|
||||
with self.assertRaises(AssertionError):
|
||||
## unused axis provided
|
||||
y.rearrange("(a b c) d -> a b c d", b=2, c=2, e=2)
|
||||
|
||||
def test_rearrange_ellipsis_ops(self):
|
||||
identity_patterns = [
|
||||
|
||||
Reference in New Issue
Block a user