mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 23:18:04 -05:00)
add ceil_mode for avg_pool and max_pool (#7579)
* wip pool
* check CI for remove alternative implementation
* Revert "check CI for remove alternative implementation"
  This reverts commit 7b1bb900e5.
* fix test
* tests tests tests
* slap a resolve on it
* fix comment
* a little simpler pool
* check CI for removal again
* Revert "check CI for removal again"
  This reverts commit be798b7857.
* small
* update
* some ez tests
* english
* clean up code
* fix ruff
* how did I +25 lines?
* small clean ups
* moar clean ups
* try test_avgpool2d_failure2 in CI
* final clean up
* exclude bug fix
* avg underscore pool
* no more edge case stuff
* add better comments for explanation
* add test cases for decreasing end padding
* address feedback
* improve test coverage
* tiny more polish as we wait for lines :D
* more readable code ordering
* add to documentation
* oops
* set to False instead

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
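For context, a minimal sketch of what the new flag does, shown with shape math only (`Tensor.arange`, `reshape`, and `max_pool2d` are existing tinygrad APIs; `ceil_mode` is the parameter this commit adds):

```python
from tinygrad import Tensor

t = Tensor.arange(25).reshape(1, 1, 5, 5).float()
# default (floor mode): the trailing partial window is dropped -> (1, 1, 2, 2)
print(t.max_pool2d(kernel_size=2, stride=2).shape)
# ceil_mode=True: the trailing partial window is kept -> (1, 1, 3, 3)
print(t.max_pool2d(kernel_size=2, stride=2, ceil_mode=True).shape)
```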
@@ -2033,6 +2033,20 @@ class TestOps(unittest.TestCase):
       lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), dilation=dilation),
       lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), dilation=dilation))
 
+  def test_max_pool2d_ceil_mode(self):
+    shape = (1,1,6,6)
+    for ksz in [(3,3), 3, (3,2), 4]:
+      with self.subTest(kernel_size=ksz):
+        helper_test_op([shape],
+          lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True),
+          lambda x: Tensor.max_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True))
+
+  def test_max_pool2d_ceil_mode_output_size_reduce_by_one(self):
+    # sliding window ignored from end region
+    helper_test_op([(1,1,5,5)],
+      lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True),
+      lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True))
+
   def test_avg_pool2d(self):
     shape = (32,2,111,28)
     for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
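The `_output_size_reduce_by_one` case encodes PyTorch's rule that a ceil-mode window may only start inside the input or the left padding. A plain-Python trace of the two shapes above (k=3, s=3, p=1; formulas taken from the hunk, values checked by hand):

```python
from math import ceil

# i=6: ceil((6 + 2*1 - 3)/3) + 1 = 3 -> last window starts at 6, still inside input+left pad
print(ceil((6 + 2 - 3) / 3) + 1)  # 3
# i=5: naive ceil gives 3, but the last window would start at 6 > p + i - 1 = 5,
# i.e. entirely in end padding, so the real output size is 2
print(ceil((5 + 2 - 3) / 3) + 1)  # 3 (naive); actual ceil-mode output is 2
```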
@@ -2062,6 +2076,34 @@ class TestOps(unittest.TestCase):
       lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1, count_include_pad=False),
       lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1, count_include_pad=False), rtol=1e-5)
 
+  def test_avg_pool2d_ceil_mode(self):
+    shape = (1,1,6,6)
+    for ksz in [(3,3), 3, (3,2), 4]:
+      with self.subTest(kernel_size=ksz):
+        helper_test_op([shape],
+          lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True, count_include_pad=False),
+          lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True, count_include_pad=False), rtol=1e-5)
+
+  def test_avg_pool2d_ceil_mode_output_size_reduce_by_one(self):
+    # sliding window ignored from end region
+    helper_test_op([(1,1,5,5)],
+      lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True),
+      lambda x: Tensor.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True))
+
+  def test_avg_pool2d_ceil_mode_include_pad(self):
+    shape = (1,1,6,6)
+    for ksz in [(3,3), 3, (3,2), 4]:
+      with self.subTest(kernel_size=ksz):
+        helper_test_op([shape],
+          lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True, count_include_pad=True),
+          lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True, count_include_pad=True), rtol=1e-5)
+
+  def test_avg_pool2d_ceil_mode_include_pad_output_size_reduce_by_one(self):
+    # sliding window ignored from end region
+    helper_test_op([(1,1,5,5)],
+      lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True, count_include_pad=True),
+      lambda x: Tensor.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True, count_include_pad=True))
+
   def test_global_avg_pool2d(self):
     helper_test_op([(32,2,111,28)],
       lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(111,28)),
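The `count_include_pad` split in these tests comes down to what the divisor counts. A small worked example, values computed by hand for a 2x2 input of ones with `padding=1` and the default stride (equal to the kernel size):

```python
from tinygrad import Tensor

x = Tensor.ones(1, 1, 2, 2)
# padded to 4x4; a single 3x3 window sees 4 ones and 5 pad zeros
print(x.avg_pool2d(kernel_size=3, padding=1).numpy())                           # 4/9 ~ 0.444, pad zeros counted
print(x.avg_pool2d(kernel_size=3, padding=1, count_include_pad=False).numpy())  # 4/4 = 1.0, pad zeros excluded
```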
@@ -2003,11 +2003,27 @@ class Tensor(SimpleMathTrait):
   def _padding2d(self, padding:Union[int, Sequence[int]], dims:int) -> Sequence[int]:
     return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])
 
+  def _ceil_mode_padding2d(self, k_:Tuple[sint, ...], s_:Union[Tuple[int, ...], int], d_:Union[Tuple[int, ...], int],
+                           p_:Union[Tuple[int, ...], int]) -> Sequence[int]:
+    (d_,s_,p_), i_ = (make_tuple(x, len(k_)) for x in (d_,s_,p_)), self.shape[-len(k_):]
+    # https://arxiv.org/pdf/1603.07285 section 5.1, relationship 15.
+    o_ = [ceildiv(i+2*p - (d*(k-1)+1), s) + 1 for i,d,k,s,p in zip(i_,d_,k_,s_,p_)]
+    pads = list(self._padding2d(p_, len(k_)))
+    # we have to do additional padding before `_pool` so that `o_` in `_pool` is calculated correctly
+    # `s*(o-1) + (d*(k-1)+1) - (i+2*p)` -> last_sliding_window_start + full_kernel_size - padded_input_shape
+    # we decrease padding in the case that a sliding window starts in the end padded region, thereby decreasing `o_` in `_pool`
+    # `smax(s*(o-1) - (p+i-1), 0)` -> last_sliding_window_start - (left_pad + input_size - zero_offset)
+    for dim,(o,i,s,p,k,d) in enumerate(zip(o_,i_,s_,p_,k_,d_)): pads[-1-dim*2] += s*(o-1) + (d*(k-1)+1) - (i+2*p) - smax(s*(o-1) - (p+i-1), 0)
+    return pads
+
   # NOTE: these work for more than 2D
-  def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, count_include_pad=True):
+  def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, ceil_mode=False, count_include_pad=True):
     """
     Applies average pooling over a tensor.
 
+    When `ceil_mode` is set to True, output shape will be determined using ceil division.
     When `count_include_pad` is set to False, zero padding will not be included in the averaging calculation.
 
     NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
 
     See: https://paperswithcode.com/method/average-pooling
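A standalone trace of the arithmetic in `_ceil_mode_padding2d`, with plain-Python stand-ins for tinygrad's `ceildiv` and `smax`, run on the shapes the tests use (k=3, s=3, p=1):

```python
def ceildiv(a, b): return -(-a // b)  # stand-in for tinygrad's ceildiv
def smax(a, b): return max(a, b)      # stand-in for the symbolic smax

def extra_end_pad(i, k, s, p, d=1):
    o = ceildiv(i + 2*p - (d*(k-1) + 1), s) + 1   # ceil-mode output size (relationship 15)
    over = s*(o-1) + (d*(k-1) + 1) - (i + 2*p)    # last window's overhang past the padded input
    drop = smax(s*(o-1) - (p + i - 1), 0)         # undo pad if the last window starts entirely in end padding
    return over - drop

def pooled(i, k, s, p, extra, d=1):               # floor-mode size `_pool` then computes
    return (i + 2*p + extra - (d*(k-1) + 1)) // s + 1

for i in (6, 5):
    e = extra_end_pad(i, 3, 3, 1)
    print(i, e, pooled(i, 3, 3, 1, e))
# 6 1 3 -> plain ceil division applies
# 5 1 2 -> last window started in end padding, so the output shrinks by one
```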
@@ -2017,17 +2033,30 @@ class Tensor(SimpleMathTrait):
     print(t.avg_pool2d().numpy())
     ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.avg_pool2d(ceil_mode=True).numpy())
+    ```
     ```python exec="true" source="above" session="tensor" result="python"
     print(t.avg_pool2d(padding=1).numpy())
     ```
     ```python exec="true" source="above" session="tensor" result="python"
     print(t.avg_pool2d(padding=1, count_include_pad=False).numpy())
     ```
     """
-    padding_, axis = self._padding2d(padding, len(k_ := make_tuple(kernel_size, 2))), tuple(range(-len(k_), 0))
-    def pool(x:Tensor) -> Tensor: return x.pad(padding_)._pool(k_, stride if stride is not None else k_, dilation)
-    return pool(self).mean(axis=axis) if count_include_pad else pool(self).sum(axis=axis) / pool(self.ones_like()).sum(axis=axis)
+    axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0))
+    reg_pads, ceil_pads = self._padding2d(padding, len(k_)), self._ceil_mode_padding2d(k_, stride if stride is not None else k_, dilation, padding)
+    def pool(x:Tensor, padding_:Sequence[int]) -> Tensor: return x.pad(padding_)._pool(k_, stride if stride is not None else k_, dilation)
+    if not count_include_pad:
+      pads = ceil_pads if ceil_mode else reg_pads
+      return pool(self, pads).sum(axis) / pool(self.ones_like(), pads).sum(axis)
+    if not ceil_mode: return pool(self, reg_pads).mean(axis)
+    return pool(self, ceil_pads).sum(axis) / pool(self.pad(reg_pads).ones_like(), tuple(cp-rp for cp,rp in zip(ceil_pads, reg_pads))).sum(axis)
 
-  def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0):
+  def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, ceil_mode=False):
     """
     Applies max pooling over a tensor.
 
+    When `ceil_mode` is set to True, output shape will be determined using ceil division.
+
     NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
 
     See: https://paperswithcode.com/method/max-pooling
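The final `return` in the new `avg_pool2d` body handles the subtle `ceil_mode=True, count_include_pad=True` interaction: the divisor counts window positions over the input plus the user-specified padding (`reg_pads`), but not the extra padding introduced for ceil mode, which matches PyTorch. A hand-checked illustration (6x6 input of ones, same parameters as the tests; the printed values are approximate):

```python
from tinygrad import Tensor

x = Tensor.ones(1, 1, 6, 6)
out = x.avg_pool2d(kernel_size=3, stride=3, padding=1, ceil_mode=True, count_include_pad=True)
# bottom-right window: only 4 of its 9 positions fall inside input + user padding
# (the extra ceil-mode row/column is excluded from the count), and just 1 of them is an input one
print(out.numpy())
# [[0.444 0.667 0.333]
#  [0.667 1.    0.5  ]
#  [0.333 0.5   0.25 ]]
```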
@@ -2037,11 +2066,15 @@ class Tensor(SimpleMathTrait):
     print(t.max_pool2d().numpy())
     ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.max_pool2d(ceil_mode=True).numpy())
+    ```
     ```python exec="true" source="above" session="tensor" result="python"
     print(t.max_pool2d(padding=1).numpy())
     ```
     """
-    padding_ = self._padding2d(padding, len(k_ := make_tuple(kernel_size, 2)))
-    return self.pad(padding_, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))
+    k_ = make_tuple(kernel_size, 2)
+    pads = self._ceil_mode_padding2d(k_, stride if stride is not None else k_, dilation, padding) if ceil_mode else self._padding2d(padding, len(k_))
+    return self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))
 
   def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding:int|Tuple[int, ...]=0,
              acc_dtype:Optional[DTypeLike]=None) -> Tensor:
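As the NOTE in both docstrings says, these methods are not limited to two spatial dimensions, and `ceil_mode` composes with that since `_ceil_mode_padding2d` iterates over `len(k_)` dims. A small sketch, shapes only:

```python
from tinygrad import Tensor

t = Tensor.arange(2 * 5 * 5 * 5).reshape(2, 5, 5, 5).float()
# a 3-tuple kernel pools over the last three dims; per dim: ceildiv(5 - 3, 2) + 1 = 3
print(t.max_pool2d(kernel_size=(2, 2, 2), stride=2, ceil_mode=True).shape)  # (2, 3, 3, 3)
```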