minor cleanup to reshape arg handling (#5070)

moved None handle to be with argfix, and only resolve -1 if there's a -1
This commit is contained in:
chenyu
2024-06-20 10:27:27 -04:00
committed by GitHub
parent f4355d0f1b
commit 50700171ef
2 changed files with 11 additions and 7 deletions

View File

@@ -1123,12 +1123,14 @@ class TestOps(unittest.TestCase):
self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,0,1,2,3)), lambda x: x.permute((0,0,1,2,3)), expected=RuntimeError)
def test_reshape(self):
helper_test_op([(4,3,6,6)], lambda x: x.reshape((12,6,6)))
helper_test_op([(4,3,6,6)], lambda x: x.reshape((-1,3,6,6)))
helper_test_op([(4,3,6,6)], lambda x: x.reshape((-1,1,6,6)))
helper_test_op([()], lambda x: x.reshape([]))
helper_test_op([(1,)], lambda x: x.reshape([]))
helper_test_op([()], lambda x: x.reshape([1]))
helper_test_op([()], lambda x: x.reshape([1,1,1]))
helper_test_op([(4,3,6,6)], lambda x: x.reshape((4,3,6,6)), lambda x: x.reshape((None,None,6,6)))
helper_test_op([()], lambda x: x.reshape(()))
helper_test_op([(1,)], lambda x: x.reshape(()))
helper_test_op([()], lambda x: x.reshape((1,)))
helper_test_op([()], lambda x: x.reshape((1,1,1)))
self.helper_test_exception([(3,4)], lambda x: x.reshape((-1,-1,2)), lambda x: x.reshape((-1,-1,2)), expected=RuntimeError)
self.helper_test_exception([(3,4)], lambda x: x.reshape((-1,-1,-1,2)), lambda x: x.reshape((-1,-1,-1,2)), expected=RuntimeError)

View File

@@ -766,9 +766,11 @@ class Tensor:
print(t.reshape(2, 3).numpy())
```
"""
new_shape = argfix(shape, *args)
if new_shape.count(-1) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)])
# resolve None and args
new_shape = tuple([s if s is not None else self.shape[i] for i,s in enumerate(argfix(shape, *args))])
# resolve -1
if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
elif c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
def expand(self, shape, *args) -> Tensor: