diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 74a7ec4ff3..9ec0957755 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1001,8 +1001,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method print(t.shrink((((0, 2), (0, 2)))).numpy()) ``` """ - if all(x is None or x == (0,s) for x,s in zip(arg, self.shape)): return self - return F.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape))) + if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self + return F.Shrink.apply(self, arg=tuple(shrink_arg)) def pad(self, padding:Union[Sequence[sint], Sequence[Optional[Tuple[sint, sint]]]], value:float=0.0) -> Tensor: """