mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
argmax(axis=None) is argmax.flatten().argmax(0) (#5090)
removed the alternative code path
This commit is contained in:
@@ -1518,9 +1518,7 @@ class Tensor:
|
||||
print(t.argmax(axis=1).numpy()) # Returns the indices of the maximum values along axis 1.
|
||||
```
|
||||
"""
|
||||
if axis is None:
|
||||
idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape)
|
||||
return (prod(self.shape) - idx.max() - 1).cast(dtypes.int32)
|
||||
if axis is None: return self.flatten().argmax(0)
|
||||
axis = self._resolve_dim(axis)
|
||||
m = self == self.max(axis=axis, keepdim=True)
|
||||
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
|
||||
|
||||
Reference in New Issue
Block a user