Add return_indices to max_pool (#9506)

* wow argmax is so good

* 1 less line

* clean up and better variable names

* is this torch thing right...?

* add more tests

* slap a TODO on it

* clean ups

* prettier looking code and fix ceil mode test

* add return types and some docs

* ok that was a bad example since indices == value, just no example
Author: geohotstan
Date: 2025-03-20 03:25:37 +08:00
Committed by: GitHub
Parent: 189f62d44f
Commit: 8c0d0a122c
4 changed files with 51 additions and 16 deletions


@@ -409,11 +409,9 @@ def get_onnx_ops():
 def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0,
             storage_order:int=0, strides:list[int]|int=1):
-  ret = X.max_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad), ceil_mode=ceil_mode)
-  # tests expect indices with int64 dtype
-  # TODO: if there are repeated values, this is wrong
-  indices = ((ret.reshape(-1, 1) == X.reshape(1, -1)) * Tensor.arange(X.numel(), dtype=dtypes.int64).unsqueeze(0)).sum(1).reshape(ret.shape)
-  return ret.cast(X.dtype), indices.transpose(-2, -1) if storage_order else indices
+  pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad)
+  ret, idx = X.max_pool2d(kernel_shape, strides, dilations, pads, ceil_mode=ceil_mode, return_indices=True)
+  return ret, idx.transpose(-2, -1).cast(dtypes.int64) if storage_order else idx.cast(dtypes.int64)
 def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
          kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
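
For illustration, a minimal sketch of the new two-output behaviour, calling `Tensor.max_pool2d` directly (assumes a tinygrad build that includes this change). The returned indices are flat row-major positions within each spatial plane of the input:

```python
from tinygrad import Tensor

x = Tensor([[[[1., 5., 2., 0.],
              [3., 4., 8., 6.],
              [0., 2., 7., 9.],
              [1., 1., 3., 2.]]]])
# stride defaults to kernel_size, so these are non-overlapping 2x2 windows
vals, idx = x.max_pool2d(kernel_size=(2, 2), return_indices=True)
print(vals.numpy())  # [[[[5. 8.]
                     #    [2. 9.]]]]
print(idx.numpy())   # [[[[ 1  6]
                     #    [ 9 11]]]]  (flat positions in the 4x4 plane)
```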


@@ -162,15 +162,14 @@ def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=F
 def max_pool2d_with_indices(self:torch.Tensor, kernel_size:tuple[int, ...], stride=None, padding=0, dilation=1, ceil_mode=False):
   # TODO: supprt stride [] in tinygrad?
   if stride is not None and len(stride) == 0: stride = None
-  # TODO: support return_indices in tinygrad
-  ret = unwrap(self).max_pool2d(kernel_size, stride, dilation, padding, ceil_mode)
-  # TODO: this is wrong
-  return (wrap(ret), wrap(Tensor.zeros_like(ret, dtype=dtypes.int64)))
+  ret, idx = unwrap(self).max_pool2d(kernel_size, stride, dilation, padding, ceil_mode, return_indices=True)
+  return (wrap(ret), wrap(idx.cast(dtypes.int64)))
 @torch.library.impl("aten::max_pool2d_with_indices_backward", "privateuseone")
 def max_pool2d_with_indices_backward(grad_out:torch.Tensor, self:torch.Tensor, kernel_size:tuple[int, ...], stride=None, padding=0, dilation=1, ceil_mode=False, indices=None):
   if stride is not None and len(stride) == 0: stride = None
-  # TODO: utilize input indices once they are correct
+  # TODO: implement maxunpool
   self_ = unwrap(self)
   out = Tensor.max_pool2d(self_, kernel_size, stride, dilation, padding, ceil_mode)
   return wrap(out.gradient(self_, gradient=unwrap(grad_out))[0])
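
For reference, the behaviour this now mirrors on stock PyTorch (plain CPU tensors, not the privateuseone device), along with the unpooling that the remaining backward TODO alludes to:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[[[ 4., 11.,  2.,  6.],
                    [ 9.,  3.,  7., 10.],
                    [ 5.,  8., 12.,  1.],
                    [ 0., 13.,  6.,  2.]]]])
vals, idx = F.max_pool2d(x, kernel_size=2, return_indices=True)
print(vals)  # values [[11., 10.], [13., 12.]]
print(idx)   # indices [[1, 7], [13, 10]] -- int64, flat positions within the 4x4 plane
# max_unpool2d uses those indices to scatter the maxima back into a zero-filled 4x4 map
print(F.max_unpool2d(vals, idx, kernel_size=2))
```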


@@ -2327,6 +2327,35 @@ class TestOps(unittest.TestCase):
       lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True),
       lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True))
+  def test_max_pool2d_return_indices(self):
+    # batch and multi-channel
+    helper_test_op([(2,3,6,6)],
+      lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), return_indices=True)[1].type(torch.int32),
+      lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), return_indices=True)[1], forward_only=True)
+    # dilation
+    helper_test_op([(1,1,10,10)],
+      lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,2), dilation=(2,3), return_indices=True)[1].type(torch.int32),
+      lambda x: Tensor.max_pool2d(x, kernel_size=(3,2), dilation=(2,3), return_indices=True)[1], forward_only=True)
+    # padding
+    helper_test_op([(1,1,5,5)],
+      lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), padding=1, return_indices=True)[1].type(torch.int32),
+      lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), padding=1, return_indices=True)[1], forward_only=True)
+    # ceil mode padding
+    helper_test_op([(1, 1, 7, 7)],
+      lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2), ceil_mode=True, return_indices=True)[1].type(torch.int32),
+      lambda x: Tensor.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2), ceil_mode=True, return_indices=True)[1],
+      forward_only=True)
+    # global maxpool
+    helper_test_op([(1,1,12,13)],
+      lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(12, 13), return_indices=True)[1].type(torch.int32),
+      lambda x: Tensor.max_pool2d(x, kernel_size=(12, 13), return_indices=True)[1],
+      forward_only=True)
+    # multiple identical values in same window and overlapping windows
+    helper_test_op(None,
+      lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), stride=1, return_indices=True)[1].type(torch.int32),
+      lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=1, return_indices=True)[1],
+      vals=[[[[[1]*6]*6]]], forward_only=True)  # Tensor.ones(1,1,6,6)
   def test_avg_pool2d(self):
     shape = (32,2,111,28)
     for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
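
The last case in test_max_pool2d_return_indices is the interesting one: with an all-ones input every element of every window ties, and the test asserts both frameworks agree, i.e. the returned index is the first (lowest) flat position in each window. For a 6x6 plane with a 3x3 kernel at stride 1 that is just the window's top-left corner; a small sanity sketch of the expected indices (not part of the test suite):

```python
# expected indices for the all-ones tie-breaking case: kernel 3x3, stride 1 on a 6x6 plane
H = W = 6
expected = [[r * W + c for c in range(W - 2)] for r in range(H - 2)]
print(expected)  # [[0, 1, 2, 3], [6, 7, 8, 9], [12, 13, 14, 15], [18, 19, 20, 21]]
```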


@@ -2110,7 +2110,7 @@ class Tensor(SimpleMathTrait):
   # NOTE: these work for more than 2D
   def avg_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0,
-                 ceil_mode=False, count_include_pad=True):
+                 ceil_mode=False, count_include_pad=True) -> Tensor:
     """
     Applies average pooling over a tensor.
@@ -2158,7 +2158,7 @@ class Tensor(SimpleMathTrait):
     return pool(self, ceil_pads).sum(axis) / pool(self.pad(reg_pads).ones_like(), tuple(cp-rp for cp,rp in zip(ceil_pads, reg_pads))).sum(axis)
   def max_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0,
-                 ceil_mode=False):
+                 ceil_mode=False, return_indices=False) -> Tensor | tuple[Tensor, Tensor]:
     """
     Applies max pooling over a tensor.
@@ -2175,6 +2175,7 @@ class Tensor(SimpleMathTrait):
     `(padding_left, padding_right, padding_top, padding_bottom, ...)`.
     When `ceil_mode` is set to `True`, output shape will be determined using ceil division.
+    When `return_indices` is set to `True`, the argmax will be returned along with the max values.
     NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
@@ -2191,9 +2192,16 @@ class Tensor(SimpleMathTrait):
     print(t.max_pool2d(padding=1).numpy())
     ```
     """
-    pads = self._resolve_pool_pads(padding, len(k_ := make_tuple(kernel_size, 2)))
+    axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0))
+    pads = self._resolve_pool_pads(padding, len(k_))
     if ceil_mode: pads = self._apply_ceil_mode(pads, k_, stride if stride is not None else k_, dilation)
-    return self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))
+    pooled = self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation)
+    if not return_indices: return pooled.max(axis)
+    spatial_sz = math.prod(spatial_shape := self.shape[-len(k_):])
+    idx = Tensor.arange(spatial_sz,0,-1, requires_grad=False, device=self.device).reshape(spatial_shape)
+    m = pooled == pooled.max(axis, keepdim=True)
+    idx = m * idx.pad(pads, value=dtypes.min(idx.dtype))._pool(k_, stride if stride is not None else k_, dilation)
+    return pooled.max(axis), spatial_sz - idx.max(axis)
   def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
              dtype:DTypeLike|None=None) -> Tensor:
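
A plain-Python sketch of the index-recovery trick used above (illustrative names, no tinygrad required): positions are labelled with a descending `arange(spatial_sz, 0, -1)`, masked to where the pooled value equals the window max, and reduced with `max`; subtracting from `spatial_sz` turns the result back into an ascending flat index, and ties resolve to the first (lowest) position:

```python
# one 1D window of pooled values; the real code does this for every window at once,
# with the index labels built over the whole spatial plane and pooled alongside the data
window = [3.0, 7.0, 7.0, 1.0]
spatial_sz = len(window)
desc = list(range(spatial_sz, 0, -1))                       # [4, 3, 2, 1], like arange(spatial_sz, 0, -1)
mask = [v == max(window) for v in window]                   # positions that hit the max
masked = [d if keep else 0 for d, keep in zip(desc, mask)]  # m * idx, zero where masked out
print(spatial_sz - max(masked))                             # 1 -> index of the first 7.0
```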
@@ -2577,7 +2585,7 @@ class Tensor(SimpleMathTrait):
     return mask.where(src, 0).sum(-1, dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count)
     raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'")
-  def sort(self, dim:int=-1, descending:bool=False):
+  def sort(self, dim:int=-1, descending:bool=False) -> tuple[Tensor, Tensor]:
     """
     Performs a bitonic sort on the tensor along the specified dimension.
@@ -2621,14 +2629,15 @@ class Tensor(SimpleMathTrait):
     x = blue_box.cat(flipped_green_box.flip(flip_dims), dim=crossover_dim)
     x = x.flatten(dim, dim+n_stages-1).shrink(tuple((0, orig_len) if i == dim else None for i in range(x.ndim)))
     # compute indices for sorted values
-    idx = Tensor.arange(orig_len, device=self.device).reshape(tuple(orig_len if i == dim else 1 for i in range(x.ndim))).expand(x.shape)
+    idx = Tensor.arange(orig_len, requires_grad=False, device=self.device).reshape(tuple(orig_len if i == dim else 1 for i in range(x.ndim)))
+    idx = idx.expand(x.shape)
     def compute_counts(t:Tensor): return ((idx.unsqueeze(dim) <= idx.unsqueeze(dim+1)) & (t.unsqueeze(dim) == t.unsqueeze(dim+1))).sum(dim+1)
     count_orig, count_sorted = compute_counts(self), compute_counts(x)
     cond = (self.unsqueeze(dim+1) == x.unsqueeze(dim)) & (count_orig.unsqueeze(dim+1) == count_sorted.unsqueeze(dim))
     idx = (cond * idx.unsqueeze(dim+1)).sum(dim)
     return x, idx
-  def topk(self, k:int, dim:int=-1, largest:bool=True, sorted_:bool=True):
+  def topk(self, k:int, dim:int=-1, largest:bool=True, sorted_:bool=True) -> tuple[Tensor, Tensor]:
     """
     Computes the top-k elements of the tensor along the specified `dim`.
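
The sort/topk changes above only make the existing tuple returns explicit (plus `requires_grad=False` on the index helper); a minimal usage sketch of the tuples those annotations describe (assumes tinygrad):

```python
from tinygrad import Tensor

t = Tensor([3., 1., 2.])
vals, idx = t.sort()              # ascending by default
print(vals.numpy(), idx.numpy())  # [1. 2. 3.] [1 2 0]
vals, idx = t.topk(2)             # largest=True by default
print(vals.numpy(), idx.numpy())  # [3. 2.] [0 2]
```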