Add return_indices to max_pool (#9506)

* wow argmax is so good

* 1 less line

* clean up and better variable names

* is this torch thing right...?

* add more tests

* slap a TODO on it

* clean ups

* prettier looking code and fix ceil mode test

* add return types and some docs

* ok that was a bad example since indices == value, just no example
This commit is contained in:
geohotstan
2025-03-20 03:25:37 +08:00
committed by GitHub
parent 189f62d44f
commit 8c0d0a122c
4 changed files with 51 additions and 16 deletions

View File

@@ -2327,6 +2327,35 @@ class TestOps(unittest.TestCase):
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True),
lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True))
def test_max_pool2d_return_indices(self):
# batch and multi-channel
helper_test_op([(2,3,6,6)],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), return_indices=True)[1].type(torch.int32),
lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), return_indices=True)[1], forward_only=True)
# dilation
helper_test_op([(1,1,10,10)],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,2), dilation=(2,3), return_indices=True)[1].type(torch.int32),
lambda x: Tensor.max_pool2d(x, kernel_size=(3,2), dilation=(2,3), return_indices=True)[1], forward_only=True)
# padding
helper_test_op([(1,1,5,5)],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), padding=1, return_indices=True)[1].type(torch.int32),
lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), padding=1, return_indices=True)[1], forward_only=True)
# ceil mode padding
helper_test_op([(1, 1, 7, 7)],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2), ceil_mode=True, return_indices=True)[1].type(torch.int32),
lambda x: Tensor.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2), ceil_mode=True, return_indices=True)[1],
forward_only=True)
# global maxpool
helper_test_op([(1,1,12,13)],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(12, 13), return_indices=True)[1].type(torch.int32),
lambda x: Tensor.max_pool2d(x, kernel_size=(12, 13), return_indices=True)[1],
forward_only=True)
# multiple identical values in same window and overlapping windows
helper_test_op(None,
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), stride=1, return_indices=True)[1].type(torch.int32),
lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=1, return_indices=True)[1],
vals=[[[[[1]*6]*6]]], forward_only=True) # Tensor.ones(1,1,6,6)
def test_avg_pool2d(self):
shape = (32,2,111,28)
for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]: