mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
minor cleanup to reshape arg handling (#5070)
moved None handle to be with argfix, and only resolve -1 if there's a -1
This commit is contained in:
@@ -1123,12 +1123,14 @@ class TestOps(unittest.TestCase):
|
||||
self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,0,1,2,3)), lambda x: x.permute((0,0,1,2,3)), expected=RuntimeError)
|
||||
|
||||
def test_reshape(self):
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.reshape((12,6,6)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.reshape((-1,3,6,6)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.reshape((-1,1,6,6)))
|
||||
helper_test_op([()], lambda x: x.reshape([]))
|
||||
helper_test_op([(1,)], lambda x: x.reshape([]))
|
||||
helper_test_op([()], lambda x: x.reshape([1]))
|
||||
helper_test_op([()], lambda x: x.reshape([1,1,1]))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.reshape((4,3,6,6)), lambda x: x.reshape((None,None,6,6)))
|
||||
helper_test_op([()], lambda x: x.reshape(()))
|
||||
helper_test_op([(1,)], lambda x: x.reshape(()))
|
||||
helper_test_op([()], lambda x: x.reshape((1,)))
|
||||
helper_test_op([()], lambda x: x.reshape((1,1,1)))
|
||||
self.helper_test_exception([(3,4)], lambda x: x.reshape((-1,-1,2)), lambda x: x.reshape((-1,-1,2)), expected=RuntimeError)
|
||||
self.helper_test_exception([(3,4)], lambda x: x.reshape((-1,-1,-1,2)), lambda x: x.reshape((-1,-1,-1,2)), expected=RuntimeError)
|
||||
|
||||
|
||||
@@ -766,9 +766,11 @@ class Tensor:
|
||||
print(t.reshape(2, 3).numpy())
|
||||
```
|
||||
"""
|
||||
new_shape = argfix(shape, *args)
|
||||
if new_shape.count(-1) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
|
||||
new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)])
|
||||
# resolve None and args
|
||||
new_shape = tuple([s if s is not None else self.shape[i] for i,s in enumerate(argfix(shape, *args))])
|
||||
# resolve -1
|
||||
if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
|
||||
elif c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
|
||||
return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
|
||||
|
||||
def expand(self, shape, *args) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user