mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Add return_indices to max_pool (#9506)
* wow argmax is so good * 1 less line * clean up and better variable names * is this torch thing right...? * add more tests * slap a TODO on it * clean ups * prettier looking code and fix ceil mode test * add return types and some docs * ok that was a bad example since indices == value, just no example
This commit is contained in:
@@ -409,11 +409,9 @@ def get_onnx_ops():
|
||||
|
||||
def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0,
|
||||
storage_order:int=0, strides:list[int]|int=1):
|
||||
ret = X.max_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad), ceil_mode=ceil_mode)
|
||||
# tests expect indices with int64 dtype
|
||||
# TODO: if there are repeated values, this is wrong
|
||||
indices = ((ret.reshape(-1, 1) == X.reshape(1, -1)) * Tensor.arange(X.numel(), dtype=dtypes.int64).unsqueeze(0)).sum(1).reshape(ret.shape)
|
||||
return ret.cast(X.dtype), indices.transpose(-2, -1) if storage_order else indices
|
||||
pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad)
|
||||
ret, idx = X.max_pool2d(kernel_shape, strides, dilations, pads, ceil_mode=ceil_mode, return_indices=True)
|
||||
return ret, idx.transpose(-2, -1).cast(dtypes.int64) if storage_order else idx.cast(dtypes.int64)
|
||||
|
||||
def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
|
||||
kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
|
||||
|
||||
@@ -162,15 +162,14 @@ def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=F
|
||||
def max_pool2d_with_indices(self:torch.Tensor, kernel_size:tuple[int, ...], stride=None, padding=0, dilation=1, ceil_mode=False):
|
||||
# TODO: supprt stride [] in tinygrad?
|
||||
if stride is not None and len(stride) == 0: stride = None
|
||||
# TODO: support return_indices in tinygrad
|
||||
ret = unwrap(self).max_pool2d(kernel_size, stride, dilation, padding, ceil_mode)
|
||||
# TODO: this is wrong
|
||||
return (wrap(ret), wrap(Tensor.zeros_like(ret, dtype=dtypes.int64)))
|
||||
ret, idx = unwrap(self).max_pool2d(kernel_size, stride, dilation, padding, ceil_mode, return_indices=True)
|
||||
return (wrap(ret), wrap(idx.cast(dtypes.int64)))
|
||||
|
||||
@torch.library.impl("aten::max_pool2d_with_indices_backward", "privateuseone")
|
||||
def max_pool2d_with_indices_backward(grad_out:torch.Tensor, self:torch.Tensor, kernel_size:tuple[int, ...], stride=None, padding=0, dilation=1, ceil_mode=False, indices=None):
|
||||
if stride is not None and len(stride) == 0: stride = None
|
||||
# TODO: utilize input indices once they are correct
|
||||
# TODO: implement maxunpool
|
||||
self_ = unwrap(self)
|
||||
out = Tensor.max_pool2d(self_, kernel_size, stride, dilation, padding, ceil_mode)
|
||||
return wrap(out.gradient(self_, gradient=unwrap(grad_out))[0])
|
||||
|
||||
@@ -2327,6 +2327,35 @@ class TestOps(unittest.TestCase):
|
||||
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True),
|
||||
lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True))
|
||||
|
||||
def test_max_pool2d_return_indices(self):
|
||||
# batch and multi-channel
|
||||
helper_test_op([(2,3,6,6)],
|
||||
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), return_indices=True)[1].type(torch.int32),
|
||||
lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), return_indices=True)[1], forward_only=True)
|
||||
# dilation
|
||||
helper_test_op([(1,1,10,10)],
|
||||
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,2), dilation=(2,3), return_indices=True)[1].type(torch.int32),
|
||||
lambda x: Tensor.max_pool2d(x, kernel_size=(3,2), dilation=(2,3), return_indices=True)[1], forward_only=True)
|
||||
# padding
|
||||
helper_test_op([(1,1,5,5)],
|
||||
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), padding=1, return_indices=True)[1].type(torch.int32),
|
||||
lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), padding=1, return_indices=True)[1], forward_only=True)
|
||||
# ceil mode padding
|
||||
helper_test_op([(1, 1, 7, 7)],
|
||||
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2), ceil_mode=True, return_indices=True)[1].type(torch.int32),
|
||||
lambda x: Tensor.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2), ceil_mode=True, return_indices=True)[1],
|
||||
forward_only=True)
|
||||
# global maxpool
|
||||
helper_test_op([(1,1,12,13)],
|
||||
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(12, 13), return_indices=True)[1].type(torch.int32),
|
||||
lambda x: Tensor.max_pool2d(x, kernel_size=(12, 13), return_indices=True)[1],
|
||||
forward_only=True)
|
||||
# multiple identical values in same window and overlapping windows
|
||||
helper_test_op(None,
|
||||
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), stride=1, return_indices=True)[1].type(torch.int32),
|
||||
lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=1, return_indices=True)[1],
|
||||
vals=[[[[[1]*6]*6]]], forward_only=True) # Tensor.ones(1,1,6,6)
|
||||
|
||||
def test_avg_pool2d(self):
|
||||
shape = (32,2,111,28)
|
||||
for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
|
||||
|
||||
@@ -2110,7 +2110,7 @@ class Tensor(SimpleMathTrait):
|
||||
|
||||
# NOTE: these work for more than 2D
|
||||
def avg_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0,
|
||||
ceil_mode=False, count_include_pad=True):
|
||||
ceil_mode=False, count_include_pad=True) -> Tensor:
|
||||
"""
|
||||
Applies average pooling over a tensor.
|
||||
|
||||
@@ -2158,7 +2158,7 @@ class Tensor(SimpleMathTrait):
|
||||
return pool(self, ceil_pads).sum(axis) / pool(self.pad(reg_pads).ones_like(), tuple(cp-rp for cp,rp in zip(ceil_pads, reg_pads))).sum(axis)
|
||||
|
||||
def max_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0,
|
||||
ceil_mode=False):
|
||||
ceil_mode=False, return_indices=False) -> Tensor | tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Applies max pooling over a tensor.
|
||||
|
||||
@@ -2175,6 +2175,7 @@ class Tensor(SimpleMathTrait):
|
||||
`(padding_left, padding_right, padding_top, padding_bottom, ...)`.
|
||||
|
||||
When `ceil_mode` is set to `True`, output shape will be determined using ceil division.
|
||||
When `return_indices` is set to `True`, the argmax will be returned along with the max values.
|
||||
|
||||
NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
|
||||
|
||||
@@ -2191,9 +2192,16 @@ class Tensor(SimpleMathTrait):
|
||||
print(t.max_pool2d(padding=1).numpy())
|
||||
```
|
||||
"""
|
||||
pads = self._resolve_pool_pads(padding, len(k_ := make_tuple(kernel_size, 2)))
|
||||
axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0))
|
||||
pads = self._resolve_pool_pads(padding, len(k_))
|
||||
if ceil_mode: pads = self._apply_ceil_mode(pads, k_, stride if stride is not None else k_, dilation)
|
||||
return self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))
|
||||
pooled = self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation)
|
||||
if not return_indices: return pooled.max(axis)
|
||||
spatial_sz = math.prod(spatial_shape := self.shape[-len(k_):])
|
||||
idx = Tensor.arange(spatial_sz,0,-1, requires_grad=False, device=self.device).reshape(spatial_shape)
|
||||
m = pooled == pooled.max(axis, keepdim=True)
|
||||
idx = m * idx.pad(pads, value=dtypes.min(idx.dtype))._pool(k_, stride if stride is not None else k_, dilation)
|
||||
return pooled.max(axis), spatial_sz - idx.max(axis)
|
||||
|
||||
def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
|
||||
dtype:DTypeLike|None=None) -> Tensor:
|
||||
@@ -2577,7 +2585,7 @@ class Tensor(SimpleMathTrait):
|
||||
return mask.where(src, 0).sum(-1, dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count)
|
||||
raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'")
|
||||
|
||||
def sort(self, dim:int=-1, descending:bool=False):
|
||||
def sort(self, dim:int=-1, descending:bool=False) -> tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Performs a bitonic sort on the tensor along the specified dimension.
|
||||
|
||||
@@ -2621,14 +2629,15 @@ class Tensor(SimpleMathTrait):
|
||||
x = blue_box.cat(flipped_green_box.flip(flip_dims), dim=crossover_dim)
|
||||
x = x.flatten(dim, dim+n_stages-1).shrink(tuple((0, orig_len) if i == dim else None for i in range(x.ndim)))
|
||||
# compute indices for sorted values
|
||||
idx = Tensor.arange(orig_len, device=self.device).reshape(tuple(orig_len if i == dim else 1 for i in range(x.ndim))).expand(x.shape)
|
||||
idx = Tensor.arange(orig_len, requires_grad=False, device=self.device).reshape(tuple(orig_len if i == dim else 1 for i in range(x.ndim)))
|
||||
idx = idx.expand(x.shape)
|
||||
def compute_counts(t:Tensor): return ((idx.unsqueeze(dim) <= idx.unsqueeze(dim+1)) & (t.unsqueeze(dim) == t.unsqueeze(dim+1))).sum(dim+1)
|
||||
count_orig, count_sorted = compute_counts(self), compute_counts(x)
|
||||
cond = (self.unsqueeze(dim+1) == x.unsqueeze(dim)) & (count_orig.unsqueeze(dim+1) == count_sorted.unsqueeze(dim))
|
||||
idx = (cond * idx.unsqueeze(dim+1)).sum(dim)
|
||||
return x, idx
|
||||
|
||||
def topk(self, k:int, dim:int=-1, largest:bool=True, sorted_:bool=True):
|
||||
def topk(self, k:int, dim:int=-1, largest:bool=True, sorted_:bool=True) -> tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Computes the top-k elements of the tensor along the specified `dim`.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user