diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py
index 87fd9a1203..80fc038c76 100644
--- a/extra/torch_backend/backend.py
+++ b/extra/torch_backend/backend.py
@@ -479,6 +479,7 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
   "aten.fmax.out": lambda input,other: Tensor.where(input.isnan() & ~other.isnan(), other, Tensor.where(~input.isnan() & other.isnan(), input, Tensor.maximum(input, other))),
   "aten.fmin.out": lambda input,other: Tensor.where(input.isnan() & ~other.isnan(), other, Tensor.where(~input.isnan() & other.isnan(), input, Tensor.minimum(input, other))),
   "aten.amax.out": lambda self,dim=None: self.max(axis=dim),
+  "aten.amin.out": lambda self,dim=None: self.min(axis=dim),
   # TODO: this gets the shape wrong
   #"aten.arange.start_out": Tensor.arange,
   "aten.lerp.Scalar_out": Tensor.lerp,
diff --git a/extra/torch_backend/test.py b/extra/torch_backend/test.py
index 236da279b9..10c077001d 100644
--- a/extra/torch_backend/test.py
+++ b/extra/torch_backend/test.py
@@ -103,6 +103,27 @@ class TestTorchBackend(unittest.TestCase):
     expected = np.array([[4.7, 12.9, 12.3], [16.9, 24.9, 23.6]], dtype=np.float32)
     np.testing.assert_equal(y3.cpu().numpy(), expected)
+
+  def test_amin(self):
+    x = torch.tensor([[[ 1.5,  2.3,  3.1,  4.7],
+                       [ 5.2,  6.8,  7.4, 12.9],
+                       [ 9.0, 12.3, 11.6, 10.1]],
+                      [[13.2, 16.9, 15.5, 14.1],
+                       [17.1, 24.9, 19.8, 20.2],
+                       [21.0, 22.3, 23.6, 18.4]]], device=device)
+
+    y1 = torch.amin(x)
+    expected = np.array([1.5], dtype=np.float32)
+    np.testing.assert_equal(y1.cpu().numpy(), expected)
+
+    y2 = torch.amin(x, dim=(1,2))
+    expected = np.array([1.5, 13.2], dtype=np.float32)
+    np.testing.assert_equal(y2.cpu().numpy(), expected)
+
+    y3 = torch.amin(x, dim=2)
+    expected = np.array([[1.5, 5.2, 9.0], [13.2, 17.1, 18.4]], dtype=np.float32)
+    np.testing.assert_equal(y3.cpu().numpy(), expected)
+

   def test_isfinite(self):
     a = torch.ones(4, device=device)
     np.testing.assert_equal(torch.isfinite(a).cpu().numpy(), [True, True, True, True])
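
For reference, a minimal standalone sketch (not part of the diff) of the reduction the new "aten.amin.out" entry delegates to, assuming tinygrad's public Tensor.min(axis=...) API; it mirrors the existing "aten.amax.out" -> Tensor.max mapping one line above it:

    # illustrative sketch, not part of the diff: the tinygrad reduction behind aten.amin.out
    from tinygrad import Tensor

    x = Tensor([[1.5, 2.3], [0.4, 6.8]])
    print(x.min().numpy())        # 0.4 -- full reduction, like torch.amin(x)
    print(x.min(axis=1).numpy())  # [1.5 0.4] -- per-row, like torch.amin(x, dim=1)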