Tensor.shrink arg cleanup (#7700)

removed duplicated logic
This commit is contained in:
chenyu
2024-11-14 13:01:22 -05:00
committed by GitHub
parent 9fb396f660
commit 888fcb3643

View File

@@ -1001,8 +1001,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
print(t.shrink((((0, 2), (0, 2)))).numpy())
```
"""
if all(x is None or x == (0,s) for x,s in zip(arg, self.shape)): return self
return F.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape)))
if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
return F.Shrink.apply(self, arg=tuple(shrink_arg))
def pad(self, padding:Union[Sequence[sint], Sequence[Optional[Tuple[sint, sint]]]], value:float=0.0) -> Tensor:
"""