mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
Argmax/Argmin Feature (#1576)
* implemented argmax and argmin * lint * lint * match torch behaviour * format * removed flip
This commit is contained in:
@@ -392,6 +392,20 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(20,30)], lambda x: torch.cumsum(x, dim=1), lambda x: Tensor.cumsum(x, axis=1), atol=1e-6)
|
||||
helper_test_op([(20,30,40)], lambda x: torch.cumsum(x, dim=2), lambda x: Tensor.cumsum(x, axis=2), atol=1e-6)
|
||||
helper_test_op([(20,30,40)], lambda x: torch.cumsum(x, dim=-1), lambda x: Tensor.cumsum(x, axis=-1), atol=1e-6)
|
||||
|
||||
def test_argmax(self):
|
||||
self.assertEqual(torch.Tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy()) # check if returns first index for same max
|
||||
helper_test_op([(10,20)], lambda x: x.argmax(), lambda x: x.argmax(), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmax(0, False), lambda x: x.argmax(0, False), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmax(1, False), lambda x: x.argmax(1, False), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmax(1, True), lambda x: x.argmax(1, True), forward_only=True)
|
||||
def test_argmin(self):
|
||||
self.assertEqual(torch.Tensor([2, 2]).argmin().numpy(), Tensor([2, 2]).argmin().numpy())
|
||||
helper_test_op([(10,20)], lambda x: x.argmin(), lambda x: x.argmin(), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmin(0, False), lambda x: x.argmin(0, False), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmin(1, False), lambda x: x.argmin(1, False), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmin(1, True), lambda x: x.argmin(1, True), forward_only=True)
|
||||
|
||||
def test_matmul_simple(self):
|
||||
helper_test_op([(4), (4,4)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
def test_matmul(self):
|
||||
|
||||
@@ -449,6 +449,14 @@ class Tensor:
|
||||
m, _, ss = self._softmax(axis)
|
||||
return m - ss.log()
|
||||
|
||||
def argmax(self, axis=None, keepdim=False):
|
||||
if axis is None: return prod(self.shape) - ((self == self.max(axis)).flatten() * Tensor.arange(prod(self.shape)-1,-1,-1)).max() - 1
|
||||
axis = axis + self.ndim if axis < 0 else axis
|
||||
m = self == (self.max(axis=axis, keepdim=keepdim) if keepdim else self.max(axis=axis, keepdim=keepdim).unsqueeze(axis))
|
||||
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1).reshape(*[1]*axis, self.shape[axis], *[1]*(self.ndim-(axis+1)))
|
||||
return self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1
|
||||
def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)
|
||||
|
||||
# ***** processing ops *****
|
||||
|
||||
def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user