mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
add ceil_mode for avg_pool and max_pool (#7579)
* wip pool * check CI for remove alternative implementation * Revert "check CI for remove alternative implementation" This reverts commit7b1bb900e5. * fix test * tests tests tests * slap a resolve on it * fix comment * a little simpler pool * check CI for removal again * Revert "check CI for removal again" This reverts commitbe798b7857. * small * update * some ez tests * english * clean up code * fix ruff * how did I +25 lines? * small clean ups * moar clean ups * try test_avgpool2d_failure2 in CI * final clean up * exclude bug fix * avg underscore pool * no more edge case stuff * add better comments for explanation * add test cases for decreasing end padding * address feedback * improve test coverage * tiny more polish as we wait for lines :D * more readable code ordering * add to documentation * oops * set to False instead --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -2033,6 +2033,20 @@ class TestOps(unittest.TestCase):
|
||||
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), dilation=dilation),
|
||||
lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), dilation=dilation))
|
||||
|
||||
def test_max_pool2d_ceil_mode(self):
|
||||
shape = (1,1,6,6)
|
||||
for ksz in [(3,3), 3, (3,2), 4]:
|
||||
with self.subTest(kernel_size=ksz):
|
||||
helper_test_op([shape],
|
||||
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True),
|
||||
lambda x: Tensor.max_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True))
|
||||
|
||||
def test_max_pool2d_ceil_mode_output_size_reduce_by_one(self):
|
||||
# sliding window ignored from end region
|
||||
helper_test_op([(1,1,5,5)],
|
||||
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True),
|
||||
lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True))
|
||||
|
||||
def test_avg_pool2d(self):
|
||||
shape = (32,2,111,28)
|
||||
for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
|
||||
@@ -2062,6 +2076,34 @@ class TestOps(unittest.TestCase):
|
||||
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1, count_include_pad=False),
|
||||
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1, count_include_pad=False), rtol=1e-5)
|
||||
|
||||
def test_avg_pool2d_ceil_mode(self):
|
||||
shape = (1,1,6,6)
|
||||
for ksz in [(3,3), 3, (3,2), 4]:
|
||||
with self.subTest(kernel_size=ksz):
|
||||
helper_test_op([shape],
|
||||
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True, count_include_pad=False),
|
||||
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True, count_include_pad=False), rtol=1e-5)
|
||||
|
||||
def test_avg_pool2d_ceil_mode_output_size_reduce_by_one(self):
|
||||
# sliding window ignored from end region
|
||||
helper_test_op([(1,1,5,5)],
|
||||
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True),
|
||||
lambda x: Tensor.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True))
|
||||
|
||||
def test_avg_pool2d_ceil_mode_include_pad(self):
|
||||
shape = (1,1,6,6)
|
||||
for ksz in [(3,3), 3, (3,2), 4]:
|
||||
with self.subTest(kernel_size=ksz):
|
||||
helper_test_op([shape],
|
||||
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True, count_include_pad=True),
|
||||
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True, count_include_pad=True), rtol=1e-5)
|
||||
|
||||
def test_avg_pool2d_ceil_mode_include_pad_output_size_reduce_by_one(self):
|
||||
# sliding window ignored from end region
|
||||
helper_test_op([(1,1,5,5)],
|
||||
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True, count_include_pad=True),
|
||||
lambda x: Tensor.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True, count_include_pad=True))
|
||||
|
||||
def test_global_avg_pool2d(self):
|
||||
helper_test_op([(32,2,111,28)],
|
||||
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(111,28)),
|
||||
|
||||
Reference in New Issue
Block a user