From 8c0d0a122c89ea79e0c4c6594dd201bb1cc33575 Mon Sep 17 00:00:00 2001 From: geohotstan <135171913+geohotstan@users.noreply.github.com> Date: Thu, 20 Mar 2025 03:25:37 +0800 Subject: [PATCH] Add return_indices to max_pool (#9506) * wow argmax is so good * 1 less line * clean up and better variable names * is this torch thing right...? * add more tests * slap a TODO on it * clean ups * prettier looking code and fix ceil mode test * add return types and some docs * ok that was a bad example since indices == value, just no example --- extra/onnx.py | 8 +++----- extra/torch_backend/backend.py | 7 +++---- test/test_ops.py | 29 +++++++++++++++++++++++++++++ tinygrad/tensor.py | 23 ++++++++++++++++------- 4 files changed, 51 insertions(+), 16 deletions(-) diff --git a/extra/onnx.py b/extra/onnx.py index 59ec591c0e..5a6a64534a 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -409,11 +409,9 @@ def get_onnx_ops(): def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0, storage_order:int=0, strides:list[int]|int=1): - ret = X.max_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad), ceil_mode=ceil_mode) - # tests expect indices with int64 dtype - # TODO: if there are repeated values, this is wrong - indices = ((ret.reshape(-1, 1) == X.reshape(1, -1)) * Tensor.arange(X.numel(), dtype=dtypes.int64).unsqueeze(0)).sum(1).reshape(ret.shape) - return ret.cast(X.dtype), indices.transpose(-2, -1) if storage_order else indices + pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad) + ret, idx = X.max_pool2d(kernel_shape, strides, dilations, pads, ceil_mode=ceil_mode, return_indices=True) + return ret, idx.transpose(-2, -1).cast(dtypes.int64) if storage_order else idx.cast(dtypes.int64) def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1, kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1): diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py index b7ebdc2505..9dab77f345 100644 --- a/extra/torch_backend/backend.py +++ b/extra/torch_backend/backend.py @@ -162,15 +162,14 @@ def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=F def max_pool2d_with_indices(self:torch.Tensor, kernel_size:tuple[int, ...], stride=None, padding=0, dilation=1, ceil_mode=False): # TODO: supprt stride [] in tinygrad? if stride is not None and len(stride) == 0: stride = None - # TODO: support return_indices in tinygrad - ret = unwrap(self).max_pool2d(kernel_size, stride, dilation, padding, ceil_mode) - # TODO: this is wrong - return (wrap(ret), wrap(Tensor.zeros_like(ret, dtype=dtypes.int64))) + ret, idx = unwrap(self).max_pool2d(kernel_size, stride, dilation, padding, ceil_mode, return_indices=True) + return (wrap(ret), wrap(idx.cast(dtypes.int64))) @torch.library.impl("aten::max_pool2d_with_indices_backward", "privateuseone") def max_pool2d_with_indices_backward(grad_out:torch.Tensor, self:torch.Tensor, kernel_size:tuple[int, ...], stride=None, padding=0, dilation=1, ceil_mode=False, indices=None): if stride is not None and len(stride) == 0: stride = None # TODO: utilize input indices once they are correct + # TODO: implement maxunpool self_ = unwrap(self) out = Tensor.max_pool2d(self_, kernel_size, stride, dilation, padding, ceil_mode) return wrap(out.gradient(self_, gradient=unwrap(grad_out))[0]) diff --git a/test/test_ops.py b/test/test_ops.py index 9a6b8e7634..34700276a4 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2327,6 +2327,35 @@ class TestOps(unittest.TestCase): lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True), lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True)) + def test_max_pool2d_return_indices(self): + # batch and multi-channel + helper_test_op([(2,3,6,6)], + lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), return_indices=True)[1].type(torch.int32), + lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), return_indices=True)[1], forward_only=True) + # dilation + helper_test_op([(1,1,10,10)], + lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,2), dilation=(2,3), return_indices=True)[1].type(torch.int32), + lambda x: Tensor.max_pool2d(x, kernel_size=(3,2), dilation=(2,3), return_indices=True)[1], forward_only=True) + # padding + helper_test_op([(1,1,5,5)], + lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), padding=1, return_indices=True)[1].type(torch.int32), + lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), padding=1, return_indices=True)[1], forward_only=True) + # ceil mode padding + helper_test_op([(1, 1, 7, 7)], + lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2), ceil_mode=True, return_indices=True)[1].type(torch.int32), + lambda x: Tensor.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2), ceil_mode=True, return_indices=True)[1], + forward_only=True) + # global maxpool + helper_test_op([(1,1,12,13)], + lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(12, 13), return_indices=True)[1].type(torch.int32), + lambda x: Tensor.max_pool2d(x, kernel_size=(12, 13), return_indices=True)[1], + forward_only=True) + # multiple identical values in same window and overlapping windows + helper_test_op(None, + lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), stride=1, return_indices=True)[1].type(torch.int32), + lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=1, return_indices=True)[1], + vals=[[[[[1]*6]*6]]], forward_only=True) # Tensor.ones(1,1,6,6) + def test_avg_pool2d(self): shape = (32,2,111,28) for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]: diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 7935484643..6b8ddab448 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2110,7 +2110,7 @@ class Tensor(SimpleMathTrait): # NOTE: these work for more than 2D def avg_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0, - ceil_mode=False, count_include_pad=True): + ceil_mode=False, count_include_pad=True) -> Tensor: """ Applies average pooling over a tensor. @@ -2158,7 +2158,7 @@ class Tensor(SimpleMathTrait): return pool(self, ceil_pads).sum(axis) / pool(self.pad(reg_pads).ones_like(), tuple(cp-rp for cp,rp in zip(ceil_pads, reg_pads))).sum(axis) def max_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0, - ceil_mode=False): + ceil_mode=False, return_indices=False) -> Tensor | tuple[Tensor, Tensor]: """ Applies max pooling over a tensor. @@ -2175,6 +2175,7 @@ class Tensor(SimpleMathTrait): `(padding_left, padding_right, padding_top, padding_bottom, ...)`. When `ceil_mode` is set to `True`, output shape will be determined using ceil division. + When `return_indices` is set to `True`, the argmax will be returned along with the max values. NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions. @@ -2191,9 +2192,16 @@ class Tensor(SimpleMathTrait): print(t.max_pool2d(padding=1).numpy()) ``` """ - pads = self._resolve_pool_pads(padding, len(k_ := make_tuple(kernel_size, 2))) + axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0)) + pads = self._resolve_pool_pads(padding, len(k_)) if ceil_mode: pads = self._apply_ceil_mode(pads, k_, stride if stride is not None else k_, dilation) - return self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0))) + pooled = self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation) + if not return_indices: return pooled.max(axis) + spatial_sz = math.prod(spatial_shape := self.shape[-len(k_):]) + idx = Tensor.arange(spatial_sz,0,-1, requires_grad=False, device=self.device).reshape(spatial_shape) + m = pooled == pooled.max(axis, keepdim=True) + idx = m * idx.pad(pads, value=dtypes.min(idx.dtype))._pool(k_, stride if stride is not None else k_, dilation) + return pooled.max(axis), spatial_sz - idx.max(axis) def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0, dtype:DTypeLike|None=None) -> Tensor: @@ -2577,7 +2585,7 @@ class Tensor(SimpleMathTrait): return mask.where(src, 0).sum(-1, dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count) raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'") - def sort(self, dim:int=-1, descending:bool=False): + def sort(self, dim:int=-1, descending:bool=False) -> tuple[Tensor, Tensor]: """ Performs a bitonic sort on the tensor along the specified dimension. @@ -2621,14 +2629,15 @@ class Tensor(SimpleMathTrait): x = blue_box.cat(flipped_green_box.flip(flip_dims), dim=crossover_dim) x = x.flatten(dim, dim+n_stages-1).shrink(tuple((0, orig_len) if i == dim else None for i in range(x.ndim))) # compute indices for sorted values - idx = Tensor.arange(orig_len, device=self.device).reshape(tuple(orig_len if i == dim else 1 for i in range(x.ndim))).expand(x.shape) + idx = Tensor.arange(orig_len, requires_grad=False, device=self.device).reshape(tuple(orig_len if i == dim else 1 for i in range(x.ndim))) + idx = idx.expand(x.shape) def compute_counts(t:Tensor): return ((idx.unsqueeze(dim) <= idx.unsqueeze(dim+1)) & (t.unsqueeze(dim) == t.unsqueeze(dim+1))).sum(dim+1) count_orig, count_sorted = compute_counts(self), compute_counts(x) cond = (self.unsqueeze(dim+1) == x.unsqueeze(dim)) & (count_orig.unsqueeze(dim+1) == count_sorted.unsqueeze(dim)) idx = (cond * idx.unsqueeze(dim+1)).sum(dim) return x, idx - def topk(self, k:int, dim:int=-1, largest:bool=True, sorted_:bool=True): + def topk(self, k:int, dim:int=-1, largest:bool=True, sorted_:bool=True) -> tuple[Tensor, Tensor]: """ Computes the top-k elements of the tensor along the specified `dim`.