From 2d7c28de6ac5dd6712152a50606f025b31999eca Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 22 Jul 2025 12:21:57 -0400 Subject: [PATCH] clean up dup lambdas in helper_test_exception (#11325) --- test/test_ops.py | 57 ++++++++++++++++++++++++------------------------ 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 575399a606..7634edda6e 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -236,10 +236,10 @@ class TestOps(unittest.TestCase): helper_test_op([(3,3,3)], lambda x: x.unfold(1, 0, 8)) helper_test_op([(3,3,3,3,3)], lambda x: x.unfold(-1, 2, 2)) - self.helper_test_exception([(8,)], lambda x: x.unfold(0, 9, 3), lambda x: x.unfold(0, 9, 3), expected=RuntimeError) - self.helper_test_exception([(8,)], lambda x: x.unfold(1, 8, 3), lambda x: x.unfold(1, 8, 3), expected=IndexError) - self.helper_test_exception([(8,)], lambda x: x.unfold(0, -1, 3), lambda x: x.unfold(0, 9, 3), expected=RuntimeError) - self.helper_test_exception([(8,)], lambda x: x.unfold(0, 1, -1), lambda x: x.unfold(0, 9, 3), expected=RuntimeError) + self.helper_test_exception([(8,)], lambda x: x.unfold(0, 9, 3), expected=RuntimeError) + self.helper_test_exception([(8,)], lambda x: x.unfold(1, 8, 3), expected=IndexError) + self.helper_test_exception([(8,)], lambda x: x.unfold(0, 9, 3), expected=RuntimeError) + self.helper_test_exception([(8,)], lambda x: x.unfold(0, 1, -1), expected=RuntimeError) def test_meshgrid(self): x, xt = torch.tensor([0.,1.,2.], requires_grad=True), Tensor([0.,1.,2.], requires_grad=True) @@ -548,7 +548,7 @@ class TestOps(unittest.TestCase): helper_test_op([(45,65), (45,65)], lambda x,y: x/y) helper_test_op([(), ()], lambda x,y: x/y) - @unittest.skipIf(AMD_LLVM, "AMD with LLVM backend generate rcp in FP division causes trunc/floor errors") + @unittest.skipIf(Device.DEFAULT == "AMD" and AMD_LLVM, "AMD with LLVM backend generate rcp in FP division causes trunc/floor errors") def test_div_rounding_mode(self): for denominator in [-10, -5, -3, -2, -1, 1, 2, 3, 5, 10]: # int numerator @@ -576,8 +576,7 @@ class TestOps(unittest.TestCase): helper_test_op(None, lambda x,y: x.div(y, rounding_mode="trunc"), forward_only=True, vals=[[numerator], [denominator]]) helper_test_op(None, lambda x,y: x.div(y, rounding_mode="floor"), forward_only=True, vals=[[numerator], [denominator]]) - self.helper_test_exception(None, lambda x,y: x.div(y, rounding_mode="typo"), lambda x,y: x.div(y, rounding_mode="typo"), forward_only=True, - vals=[[5], [0]], expected=RuntimeError) + self.helper_test_exception(None, lambda x,y: x.div(y, rounding_mode="typo"), forward_only=True, vals=[[5], [0]], expected=RuntimeError) def test_div_int(self): helper_test_op(None, lambda x,y: x/y, Tensor.div, forward_only=True, vals=[[5, 6, 7],[1, 2, 3]]) @@ -737,7 +736,7 @@ class TestOps(unittest.TestCase): helper_test_op([], lambda: tor^0x1337, lambda: ten^0x1337, forward_only=True) helper_test_op([], lambda: 0x1337^tor, lambda: 0x1337^ten, forward_only=True) - self.helper_test_exception([(4), (4)], torch.bitwise_xor, Tensor.bitwise_xor, expected=RuntimeError) + self.helper_test_exception([(4), (4)], lambda x,y: x.bitwise_xor(y), expected=RuntimeError) def test_and(self): data = [[1,-8,1],[32,1,6]] @@ -754,7 +753,7 @@ class TestOps(unittest.TestCase): helper_test_op(None, lambda x: (1 < x) & (x < 2), forward_only=True, vals=[[1.2, 1.2, 1.2, 3.2]]) - self.helper_test_exception([(4), (4)], torch.bitwise_and, Tensor.bitwise_and, expected=RuntimeError) + self.helper_test_exception([(4), (4)], lambda x,y: x.bitwise_and(y), expected=RuntimeError) def test_or(self): data = [[1,-8,1],[32,1,6]] @@ -769,7 +768,7 @@ class TestOps(unittest.TestCase): ten0, ten1 = Tensor(data[0], dtype=dtypes.bool), Tensor(data[1], dtype=dtypes.bool) helper_test_op([], lambda: tor0|tor1, lambda: ten0|ten1, forward_only=True) - self.helper_test_exception([(4), (4)], torch.bitwise_or, Tensor.bitwise_or, expected=RuntimeError) + self.helper_test_exception([(4), (4)], lambda x,y: x.bitwise_or(y), expected=RuntimeError) def test_bitwise_not(self): data = [[1,-8,1],[32,1,6]] @@ -784,7 +783,7 @@ class TestOps(unittest.TestCase): helper_test_op([], lambda: tor.bitwise_not(), lambda: ten.bitwise_not(), forward_only=True) helper_test_op([], lambda: ~tor, lambda: ~ten, forward_only=True) - self.helper_test_exception([(4)], torch.bitwise_not, Tensor.bitwise_not, expected=RuntimeError) + self.helper_test_exception([(4)], lambda x: x.bitwise_not(), expected=RuntimeError) def test_lshift(self): data = [[0,1,2],[1<<8,1<<16,1<<31-1]] @@ -1131,7 +1130,7 @@ class TestOps(unittest.TestCase): value, indices = Tensor([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0]).topk(3, largest=False) np.testing.assert_equal(value.numpy(), [0, 0, 0]) np.testing.assert_equal(indices.numpy(), [2, 4, 6]) - self.helper_test_exception([(4)], lambda x: x.topk(5), lambda x: x.topk(5), expected=(RuntimeError, ValueError)) + self.helper_test_exception([(4)], lambda x: x.topk(5), expected=(RuntimeError, ValueError)) def test_einsum(self): # matrix transpose @@ -1334,9 +1333,9 @@ class TestOps(unittest.TestCase): helper_test_op([()], lambda x: x.sum(0)) helper_test_op([()], lambda x: x.sum(-1)) helper_test_op([()], lambda x: x.sum(())) - self.helper_test_exception([(3,4,5,6)], lambda x: x.sum(5), lambda x: x.sum(5), expected=IndexError) - self.helper_test_exception([()], lambda x: x.sum(1), lambda x: x.sum(1), expected=IndexError) - self.helper_test_exception([()], lambda x: x.sum((1,)), lambda x: x.sum((1,)), expected=IndexError) + self.helper_test_exception([(3,4,5,6)], lambda x: x.sum(5), expected=IndexError) + self.helper_test_exception([()], lambda x: x.sum(1), expected=IndexError) + self.helper_test_exception([()], lambda x: x.sum((1,)), expected=IndexError) def test_sum_dtype_arg(self): helper_test_op([(45,3)], lambda x: x.sum(), lambda x: x.sum(dtype=dtypes.float32)) @@ -1851,9 +1850,9 @@ class TestOps(unittest.TestCase): helper_test_op([(3,4,5,6)], lambda x: x.permute((3,2,1,0))) helper_test_op([(3,4,5,6)], lambda x: x.permute((-2,-1,1,0))) helper_test_op([()], lambda x: x.permute(())) - self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,2)), lambda x: x.permute((0,2)), expected=RuntimeError) - self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,1,2,3,3,3)), lambda x: x.permute((0,1,2,3,3,3)), expected=RuntimeError) - self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,0,1,2,3)), lambda x: x.permute((0,0,1,2,3)), expected=RuntimeError) + self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,2)), expected=RuntimeError) + self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,1,2,3,3,3)), expected=RuntimeError) + self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,0,1,2,3)), expected=RuntimeError) def test_reshape(self): helper_test_op([(4,3,6,6)], lambda x: x.reshape((12,6,6))) @@ -1864,8 +1863,8 @@ class TestOps(unittest.TestCase): helper_test_op([(1,)], lambda x: x.reshape(())) helper_test_op([()], lambda x: x.reshape((1,))) helper_test_op([()], lambda x: x.reshape((1,1,1))) - self.helper_test_exception([(3,4)], lambda x: x.reshape((-1,-1,2)), lambda x: x.reshape((-1,-1,2)), expected=RuntimeError) - self.helper_test_exception([(3,4)], lambda x: x.reshape((-1,-1,-1,2)), lambda x: x.reshape((-1,-1,-1,2)), expected=RuntimeError) + self.helper_test_exception([(3,4)], lambda x: x.reshape((-1,-1,2)), expected=RuntimeError) + self.helper_test_exception([(3,4)], lambda x: x.reshape((-1,-1,-1,2)), expected=RuntimeError) with self.assertRaises(ValueError): x = Tensor.ones((4,3,6,6)) @@ -1890,16 +1889,16 @@ class TestOps(unittest.TestCase): helper_test_op([()], lambda x: x.flip(())) helper_test_op([(1,)], lambda x: x.flip(())) helper_test_op([(4,3,6,6)], lambda x: x.flip(())) - self.helper_test_exception([(3,4)], lambda x: x.flip((0,0)), lambda x: x.flip((0,0)), expected=RuntimeError) - self.helper_test_exception([(3,4)], lambda x: x.flip((1,1)), lambda x: x.flip((1,1)), expected=RuntimeError) - self.helper_test_exception([(3,4)], lambda x: x.flip((1,-1)), lambda x: x.flip((1,-1)), expected=RuntimeError) + self.helper_test_exception([(3,4)], lambda x: x.flip((0,0)), expected=RuntimeError) + self.helper_test_exception([(3,4)], lambda x: x.flip((1,1)), expected=RuntimeError) + self.helper_test_exception([(3,4)], lambda x: x.flip((1,-1)), expected=RuntimeError) def test_squeeze(self): helper_test_op([(1,3,6,6)], lambda x: x.squeeze(0)) helper_test_op([(4,3,1,6)], lambda x: x.squeeze(1)) helper_test_op([(4,3,6,6)], lambda x: x.squeeze(3)) - self.helper_test_exception([(4,3,6,6)], lambda x: torch.squeeze(x, 50), lambda x: x.squeeze(dim=50), expected=IndexError) - self.helper_test_exception([(4,3,6,6)], lambda x: torch.squeeze(x, -50), lambda x: x.squeeze(dim=-50), expected=IndexError) + self.helper_test_exception([(4,3,6,6)], lambda x: x.squeeze(50), expected=IndexError) + self.helper_test_exception([(4,3,6,6)], lambda x: x.squeeze(50), expected=IndexError) helper_test_op([(4,3,6,1)], lambda x: x.squeeze(-1)) helper_test_op([(4,3,6,6)], lambda x: x.squeeze()) helper_test_op([(1,3,6,6)], lambda x: x.squeeze()) @@ -1907,9 +1906,9 @@ class TestOps(unittest.TestCase): helper_test_op([()], lambda x: x.squeeze(-1)) helper_test_op([()], lambda x: x.squeeze(0)) helper_test_op([()], lambda x: x.squeeze()) - self.helper_test_exception([()], lambda x: torch.squeeze(x, 10), lambda x: x.squeeze(dim=10), expected=IndexError) - self.helper_test_exception([()], lambda x: torch.squeeze(x, 1), lambda x: x.squeeze(dim=1), expected=IndexError) - self.helper_test_exception([()], lambda x: torch.squeeze(x, -2), lambda x: x.squeeze(dim=-2), expected=IndexError) + self.helper_test_exception([()], lambda x: x.squeeze(10), expected=IndexError) + self.helper_test_exception([()], lambda x: x.squeeze(1), expected=IndexError) + self.helper_test_exception([()], lambda x: x.squeeze(-2), expected=IndexError) def test_unsqueeze(self): helper_test_op([(4,3,6,6)], lambda x: x.unsqueeze(0)) @@ -2655,7 +2654,7 @@ class TestOps(unittest.TestCase): helper_test_op([(45,65)], lambda x: x.clip(3, 0)) # min > max helper_test_op([(45,65)], lambda x: x.clip(None, 0)) helper_test_op([(45,65)], lambda x: x.clip(0, None)) - self.helper_test_exception([(45,65)], lambda x: x.clip(None, None), lambda x: x.clip(None, None), RuntimeError) + self.helper_test_exception([(45,65)], lambda x: x.clip(None, None), expected=RuntimeError) def test_matvecmat(self): helper_test_op([(1,128), (128,128), (128,128)], lambda x,y,z: (x@y).relu()@z)