add ceil_mode for avg_pool and max_pool (#7579)

* wip pool

* check CI for remove alternative implementation

* Revert "check CI for remove alternative implementation"

This reverts commit 7b1bb900e5.

* fix test

* tests tests tests

* slap a resolve on it

* fix comment

* a little simpler pool

* check CI for removal again

* Revert "check CI for removal again"

This reverts commit be798b7857.

* small

* update

* some ez tests

* english

* clean up code

* fix ruff

* how did I +25 lines?

* small clean ups

* moar clean ups

* try test_avgpool2d_failure2 in CI

* final clean up

* exclude bug fix

* avg underscore pool

* no more edge case stuff

* add better comments for explanation

* add test cases for decreasing end padding

* address feedback

* improve test coverage

* tiny more polish as we wait for lines :D

* more readable code ordering

* add to documentation

* oops

* set to False instead

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
commit a684d72e55 (parent b73d9a7d24)
Author: geohotstan
Date:   2024-12-06 21:34:14 +08:00
Committed by: GitHub
2 changed files with 82 additions and 7 deletions

test/test_ops.py

@@ -2033,6 +2033,20 @@ class TestOps(unittest.TestCase):
       lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), dilation=dilation),
       lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), dilation=dilation))
+  def test_max_pool2d_ceil_mode(self):
+    shape = (1,1,6,6)
+    for ksz in [(3,3), 3, (3,2), 4]:
+      with self.subTest(kernel_size=ksz):
+        helper_test_op([shape],
+          lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True),
+          lambda x: Tensor.max_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True))
+  def test_max_pool2d_ceil_mode_output_size_reduce_by_one(self):
+    # sliding window ignored from end region
+    helper_test_op([(1,1,5,5)],
+      lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True),
+      lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True))
   def test_avg_pool2d(self):
     shape = (32,2,111,28)
     for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
@@ -2062,6 +2076,34 @@ class TestOps(unittest.TestCase):
         lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1, count_include_pad=False),
         lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1, count_include_pad=False), rtol=1e-5)
+  def test_avg_pool2d_ceil_mode(self):
+    shape = (1,1,6,6)
+    for ksz in [(3,3), 3, (3,2), 4]:
+      with self.subTest(kernel_size=ksz):
+        helper_test_op([shape],
+          lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True, count_include_pad=False),
+          lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True, count_include_pad=False), rtol=1e-5)
+  def test_avg_pool2d_ceil_mode_output_size_reduce_by_one(self):
+    # sliding window ignored from end region
+    helper_test_op([(1,1,5,5)],
+      lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True),
+      lambda x: Tensor.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True))
+  def test_avg_pool2d_ceil_mode_include_pad(self):
+    shape = (1,1,6,6)
+    for ksz in [(3,3), 3, (3,2), 4]:
+      with self.subTest(kernel_size=ksz):
+        helper_test_op([shape],
+          lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True, count_include_pad=True),
+          lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True, count_include_pad=True), rtol=1e-5)
+  def test_avg_pool2d_ceil_mode_include_pad_output_size_reduce_by_one(self):
+    # sliding window ignored from end region
+    helper_test_op([(1,1,5,5)],
+      lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True, count_include_pad=True),
+      lambda x: Tensor.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True, count_include_pad=True))
   def test_global_avg_pool2d(self):
     helper_test_op([(32,2,111,28)],
       lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(111,28)),
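For concreteness, the two `output_size_reduce_by_one` tests above pin down PyTorch's ceil_mode rule that a sliding window starting entirely inside the end padding is ignored. A quick standalone check of the shapes (not part of the commit; requires torch):

```python
import torch

x5 = torch.randn(1, 1, 5, 5)
x6 = torch.randn(1, 1, 6, 6)
# plain ceil division would give a 3x3 output for the 5x5 input, but the last
# window would start at index s*(o-1) = 6 >= p+i = 6, i.e. entirely inside the
# end padding, so it is dropped and the output is 2x2
print(torch.nn.functional.max_pool2d(x5, kernel_size=3, stride=3, padding=1, ceil_mode=True).shape)  # (1, 1, 2, 2)
print(torch.nn.functional.max_pool2d(x6, kernel_size=3, stride=3, padding=1, ceil_mode=True).shape)  # (1, 1, 3, 3)
```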

tinygrad/tensor.py

@@ -2003,11 +2003,27 @@ class Tensor(SimpleMathTrait):
   def _padding2d(self, padding:Union[int, Sequence[int]], dims:int) -> Sequence[int]:
     return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])
+  def _ceil_mode_padding2d(self, k_:Tuple[sint, ...], s_:Union[Tuple[int, ...], int], d_:Union[Tuple[int, ...], int],
+                           p_:Union[Tuple[int, ...], int]) -> Sequence[int]:
+    (d_,s_,p_), i_ = (make_tuple(x, len(k_)) for x in (d_,s_,p_)), self.shape[-len(k_):]
+    # https://arxiv.org/pdf/1603.07285 section 5.1, relationship 15.
+    o_ = [ceildiv(i+2*p - (d*(k-1)+1), s) + 1 for i,d,k,s,p in zip(i_,d_,k_,s_,p_)]
+    pads = list(self._padding2d(p_, len(k_)))
+    # we have to do additional padding before `_pool` so that `o_` in `_pool` is calculated correctly
+    # `s*(o-1) + (d*(k-1)+1) - (i+2*p)` -> last_sliding_window_start + full_kernel_size - padded_input_shape
+    # we decrease padding in the case that a sliding window starts in the end padded region, thereby decreasing `o_` in `_pool`
+    # `smax(s*(o-1) - (p+i-1), 0)` -> last_sliding_window_start - (left_pad + input_size - zero_offset)
+    for dim,(o,i,s,p,k,d) in enumerate(zip(o_,i_,s_,p_,k_,d_)): pads[-1-dim*2] += s*(o-1) + (d*(k-1)+1) - (i+2*p) - smax(s*(o-1) - (p+i-1), 0)
+    return pads
   # NOTE: these work for more than 2D
-  def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, count_include_pad=True):
+  def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, ceil_mode=False, count_include_pad=True):
     """
     Applies average pooling over a tensor.
+    When `ceil_mode` is set to True, output shape will be determined using ceil division.
     When `count_include_pad` is set to False, zero padding will not be included in the averaging calculation.
     NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
     See: https://paperswithcode.com/method/average-pooling
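To make the padding arithmetic in `_ceil_mode_padding2d` concrete, here is a standalone single-dimension trace of the same two formulas (not part of the commit; `ceil_mode_pads_1d` is a hypothetical name), using the 5-wide case from the tests:

```python
import math

def ceil_mode_pads_1d(i: int, k: int, s: int, p: int, d: int = 1) -> tuple[int, int]:
  # ceil-mode output size: relationship 15 of arxiv.org/pdf/1603.07285 with ceil division
  o = math.ceil((i + 2*p - (d*(k-1)+1)) / s) + 1
  # grow the end pad so _pool's floor division also yields o windows...
  extra = s*(o-1) + (d*(k-1)+1) - (i + 2*p)
  # ...then shrink it again if the last window would start entirely in the end padding
  extra -= max(s*(o-1) - (p + i - 1), 0)
  return (p, p + extra)  # (begin pad, end pad)

# i=5, k=3, s=3, p=1: o=3, but the third window starts in the end padding, so the
# end pad becomes 1 + (2-1) = 2; the padded width is 1+5+2 = 8 and _pool computes
# (8-3)//3 + 1 = 2 windows, matching the reduce_by_one tests
print(ceil_mode_pads_1d(i=5, k=3, s=3, p=1))  # (1, 2)
print(ceil_mode_pads_1d(i=6, k=3, s=3, p=1))  # (1, 2) -> width 9 -> (9-3)//3 + 1 = 3 windows
```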
@@ -2017,17 +2033,30 @@ class Tensor(SimpleMathTrait):
     print(t.avg_pool2d().numpy())
     ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.avg_pool2d(ceil_mode=True).numpy())
+    ```
     ```python exec="true" source="above" session="tensor" result="python"
     print(t.avg_pool2d(padding=1).numpy())
     ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.avg_pool2d(padding=1, count_include_pad=False).numpy())
+    ```
     """
-    padding_, axis = self._padding2d(padding, len(k_ := make_tuple(kernel_size, 2))), tuple(range(-len(k_), 0))
-    def pool(x:Tensor) -> Tensor: return x.pad(padding_)._pool(k_, stride if stride is not None else k_, dilation)
-    return pool(self).mean(axis=axis) if count_include_pad else pool(self).sum(axis=axis) / pool(self.ones_like()).sum(axis=axis)
+    axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0))
+    reg_pads, ceil_pads = self._padding2d(padding, len(k_)), self._ceil_mode_padding2d(k_, stride if stride is not None else k_, dilation, padding)
+    def pool(x:Tensor, padding_:Sequence[int]) -> Tensor: return x.pad(padding_)._pool(k_, stride if stride is not None else k_, dilation)
+    if not count_include_pad:
+      pads = ceil_pads if ceil_mode else reg_pads
+      return pool(self, pads).sum(axis) / pool(self.ones_like(), pads).sum(axis)
+    if not ceil_mode: return pool(self, reg_pads).mean(axis)
+    return pool(self, ceil_pads).sum(axis) / pool(self.pad(reg_pads).ones_like(), tuple(cp-rp for cp,rp in zip(ceil_pads, reg_pads))).sum(axis)
-  def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0):
+  def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, ceil_mode=False):
     """
     Applies max pooling over a tensor.
+    When `ceil_mode` is set to True, output shape will be determined using ceil division.
     NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
     See: https://paperswithcode.com/method/max-pooling
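The last return line of `avg_pool2d` handles the subtle `ceil_mode=True` plus `count_include_pad=True` combination: the regular zero padding counts toward each window's divisor, but the extra padding introduced for ceil mode must not. Pooling ones over the regularly padded input, then zero-padding by the ceil-mode surplus, yields exactly that per-window count. A small cross-check against torch (not part of the commit; the printed values are worked out from the formulas above):

```python
import torch

x = torch.ones(1, 1, 6, 6)
out = torch.nn.functional.avg_pool2d(x, kernel_size=3, stride=3, padding=1,
                                     ceil_mode=True, count_include_pad=True)
# interior windows divide by k*k = 9; the bottom-right window hangs one step past
# the padded input, so only a 2x2 patch (1 real + 3 zero-pad positions) counts:
# sum = 1, divisor = 4
print(out[0, 0, 1, 1].item())    # 1.0
print(out[0, 0, -1, -1].item())  # 0.25
```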
@@ -2037,11 +2066,15 @@ class Tensor(SimpleMathTrait):
     print(t.max_pool2d().numpy())
     ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.max_pool2d(ceil_mode=True).numpy())
+    ```
     ```python exec="true" source="above" session="tensor" result="python"
     print(t.max_pool2d(padding=1).numpy())
     ```
     """
-    padding_ = self._padding2d(padding, len(k_ := make_tuple(kernel_size, 2)))
-    return self.pad(padding_, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))
+    k_ = make_tuple(kernel_size, 2)
+    pads = self._ceil_mode_padding2d(k_, stride if stride is not None else k_, dilation, padding) if ceil_mode else self._padding2d(padding, len(k_))
+    return self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))
   def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding:int|Tuple[int, ...]=0,
              acc_dtype:Optional[DTypeLike]=None) -> Tensor:
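For max pooling no divisor bookkeeping is needed: padding with `dtypes.min(self.dtype)` guarantees that neither the regular nor the extra ceil-mode padding can ever win the max. A usage sketch (not part of the commit; assumes a tinygrad build containing this change):

```python
from tinygrad import Tensor

t = Tensor.arange(25).reshape(1, 1, 5, 5)
# floor mode drops the ragged tail; ceil mode keeps it, minus any window that
# would start entirely in the end padding
print(t.max_pool2d(kernel_size=3, stride=3).shape)                              # (1, 1, 1, 1)
print(t.max_pool2d(kernel_size=3, stride=3, padding=1, ceil_mode=True).shape)   # (1, 1, 2, 2)
```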