clean the long lines in avg_pool2d and max_pool2d (#5091)

This commit is contained in:
chenyu
2024-06-21 14:46:56 -04:00
committed by GitHub
parent a971dc6218
commit 00593d6095

View File

@@ -1623,8 +1623,9 @@ class Tensor:
print(t.avg_pool2d().numpy())
```
"""
return self._pool(
make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
kernel_size = make_pair(kernel_size)
return self._pool(kernel_size, stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(-len(kernel_size), 0)))
def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
"""
Applies max pooling over a tensor.
@@ -1638,8 +1639,8 @@ class Tensor:
print(t.max_pool2d().numpy())
```
"""
return self._pool(
make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
kernel_size = make_pair(kernel_size)
return self._pool(kernel_size, stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(-len(kernel_size), 0)))
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:Optional[DType]=None) -> Tensor:
"""