enable argmax tests for METAL/WEBGPU in CI (#3027)

not sure why it was skipped but works now in CI
This commit is contained in:
chenyu
2024-01-05 21:43:00 -05:00
committed by GitHub
parent 2a2d3233d2
commit 138c17c094

View File

@@ -466,7 +466,6 @@ class TestOps(unittest.TestCase):
helper_test_op([(20,30,40)], lambda x: torch.cumsum(x, dim=2), lambda x: Tensor.cumsum(x, axis=2), atol=1e-6)
helper_test_op([(20,30,40)], lambda x: torch.cumsum(x, dim=-1), lambda x: Tensor.cumsum(x, axis=-1), atol=1e-6)
@unittest.skipIf(CI and Device.DEFAULT in {"METAL", "WEBGPU"}, "fails in CI, works locally")
def test_argmax(self):
self.assertEqual(torch.Tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy()) # check if returns first index for same max
helper_test_op([(10,20)], lambda x: x.argmax(), lambda x: x.argmax(), forward_only=True)
@@ -474,7 +473,6 @@ class TestOps(unittest.TestCase):
helper_test_op([(10,20)], lambda x: x.argmax(1, False), lambda x: x.argmax(1, False), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmax(1, True), lambda x: x.argmax(1, True), forward_only=True)
@unittest.skipIf(CI and Device.DEFAULT in {"METAL", "WEBGPU"}, "fails in CI, works locally")
def test_argmin(self):
self.assertEqual(torch.Tensor([2, 2]).argmin().numpy(), Tensor([2, 2]).argmin().numpy())
helper_test_op([(10,20)], lambda x: x.argmin(), lambda x: x.argmin(), forward_only=True)