mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
Add amin support to Tensor operations in Torch backend (#11290)
* integer div mod fix
* Revert "integer div mod fix"
This reverts commit d5d2f201bf.
* feat arg_min support
* test update
* test fix
This commit is contained in:
@@ -479,6 +479,7 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
|
||||
"aten.fmax.out": lambda input,other: Tensor.where(input.isnan() & ~other.isnan(), other, Tensor.where(~input.isnan() & other.isnan(), input, Tensor.maximum(input, other))),
|
||||
"aten.fmin.out": lambda input,other: Tensor.where(input.isnan() & ~other.isnan(), other, Tensor.where(~input.isnan() & other.isnan(), input, Tensor.minimum(input, other))),
|
||||
"aten.amax.out": lambda self,dim=None: self.max(axis=dim),
|
||||
"aten.amin.out": lambda self,dim=None: self.min(axis=dim),
|
||||
# TODO: this gets the shape wrong
|
||||
#"aten.arange.start_out": Tensor.arange,
|
||||
"aten.lerp.Scalar_out": Tensor.lerp,
|
||||
|
||||
@@ -103,6 +103,27 @@ class TestTorchBackend(unittest.TestCase):
|
||||
expected = np.array([[4.7, 12.9, 12.3], [16.9, 24.9, 23.6]], dtype=np.float32)
|
||||
np.testing.assert_equal(y3.cpu().numpy(), expected)
|
||||
|
||||
|
||||
def test_amin(self):
    """torch.amin agrees with numpy for a full reduce, a dim tuple, and one dim."""
    values = [[[ 1.5,  2.3,  3.1,  4.7],
               [ 5.2,  6.8,  7.4, 12.9],
               [ 9.0, 12.3, 11.6, 10.1]],
              [[13.2, 16.9, 15.5, 14.1],
               [17.1, 24.9, 19.8, 20.2],
               [21.0, 22.3, 23.6, 18.4]]]
    x = torch.tensor(values, device=device)

    # global minimum: no dim argument reduces over every axis
    result = torch.amin(x)
    np.testing.assert_equal(result.cpu().numpy(),
                            np.array([1.5], dtype=np.float32))

    # reducing over a tuple of dims leaves one value per outer slice
    result = torch.amin(x, dim=(1, 2))
    np.testing.assert_equal(result.cpu().numpy(),
                            np.array([1.5, 13.2], dtype=np.float32))

    # reducing over only the innermost dim keeps the leading shape
    result = torch.amin(x, dim=2)
    np.testing.assert_equal(
        result.cpu().numpy(),
        np.array([[1.5, 5.2, 9.0], [13.2, 17.1, 18.4]], dtype=np.float32))
|
||||
|
||||
def test_isfinite(self):
    """torch.isfinite reports True for every element of a finite tensor."""
    ones = torch.ones(4, device=device)
    finite_mask = torch.isfinite(ones).cpu().numpy()
    np.testing.assert_equal(finite_mask, [True, True, True, True])
|
||||
|
||||
Reference in New Issue
Block a user