mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Add return_indices to max_pool (#9506)
* wow argmax is so good * 1 less line * clean up and better variable names * is this torch thing right...? * add more tests * slap a TODO on it * clean ups * prettier looking code and fix ceil mode test * add return types and some docs * ok that was a bad example since indices == value, just no example
This commit is contained in:
@@ -2327,6 +2327,35 @@ class TestOps(unittest.TestCase):
|
||||
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True),
|
||||
lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True))
|
||||
|
||||
def test_max_pool2d_return_indices(self):
|
||||
# batch and multi-channel
|
||||
helper_test_op([(2,3,6,6)],
|
||||
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), return_indices=True)[1].type(torch.int32),
|
||||
lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), return_indices=True)[1], forward_only=True)
|
||||
# dilation
|
||||
helper_test_op([(1,1,10,10)],
|
||||
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,2), dilation=(2,3), return_indices=True)[1].type(torch.int32),
|
||||
lambda x: Tensor.max_pool2d(x, kernel_size=(3,2), dilation=(2,3), return_indices=True)[1], forward_only=True)
|
||||
# padding
|
||||
helper_test_op([(1,1,5,5)],
|
||||
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), padding=1, return_indices=True)[1].type(torch.int32),
|
||||
lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), padding=1, return_indices=True)[1], forward_only=True)
|
||||
# ceil mode padding
|
||||
helper_test_op([(1, 1, 7, 7)],
|
||||
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2), ceil_mode=True, return_indices=True)[1].type(torch.int32),
|
||||
lambda x: Tensor.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2), ceil_mode=True, return_indices=True)[1],
|
||||
forward_only=True)
|
||||
# global maxpool
|
||||
helper_test_op([(1,1,12,13)],
|
||||
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(12, 13), return_indices=True)[1].type(torch.int32),
|
||||
lambda x: Tensor.max_pool2d(x, kernel_size=(12, 13), return_indices=True)[1],
|
||||
forward_only=True)
|
||||
# multiple identical values in same window and overlapping windows
|
||||
helper_test_op(None,
|
||||
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), stride=1, return_indices=True)[1].type(torch.int32),
|
||||
lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=1, return_indices=True)[1],
|
||||
vals=[[[[[1]*6]*6]]], forward_only=True) # Tensor.ones(1,1,6,6)
|
||||
|
||||
def test_avg_pool2d(self):
|
||||
shape = (32,2,111,28)
|
||||
for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
|
||||
|
||||
Reference in New Issue
Block a user