Argmax/Argmin Feature (#1576)

* implemented argmax and argmin

* lint

* lint

* match torch behaviour

* format

* removed flip
This commit is contained in:
Umut Zengin
2023-08-21 04:46:46 +03:00
committed by GitHub
parent 1900acda09
commit 35bf21276f
2 changed files with 22 additions and 0 deletions

View File

@@ -449,6 +449,14 @@ class Tensor:
m, _, ss = self._softmax(axis)
return m - ss.log()
def argmax(self, axis=None, keepdim=False):
if axis is None: return prod(self.shape) - ((self == self.max(axis)).flatten() * Tensor.arange(prod(self.shape)-1,-1,-1)).max() - 1
axis = axis + self.ndim if axis < 0 else axis
m = self == (self.max(axis=axis, keepdim=keepdim) if keepdim else self.max(axis=axis, keepdim=keepdim).unsqueeze(axis))
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1).reshape(*[1]*axis, self.shape[axis], *[1]*(self.ndim-(axis+1)))
return self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1
def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)
# ***** processing ops *****
def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor: