clean up dup lambdas in helper_test_exception (#11325)

This commit is contained in:
chenyu
2025-07-22 12:21:57 -04:00
committed by GitHub
parent c6aa8e58ca
commit 2d7c28de6a

View File

@@ -236,10 +236,10 @@ class TestOps(unittest.TestCase):
helper_test_op([(3,3,3)], lambda x: x.unfold(1, 0, 8))
helper_test_op([(3,3,3,3,3)], lambda x: x.unfold(-1, 2, 2))
self.helper_test_exception([(8,)], lambda x: x.unfold(0, 9, 3), lambda x: x.unfold(0, 9, 3), expected=RuntimeError)
self.helper_test_exception([(8,)], lambda x: x.unfold(1, 8, 3), lambda x: x.unfold(1, 8, 3), expected=IndexError)
self.helper_test_exception([(8,)], lambda x: x.unfold(0, -1, 3), lambda x: x.unfold(0, 9, 3), expected=RuntimeError)
self.helper_test_exception([(8,)], lambda x: x.unfold(0, 1, -1), lambda x: x.unfold(0, 9, 3), expected=RuntimeError)
self.helper_test_exception([(8,)], lambda x: x.unfold(0, 9, 3), expected=RuntimeError)
self.helper_test_exception([(8,)], lambda x: x.unfold(1, 8, 3), expected=IndexError)
self.helper_test_exception([(8,)], lambda x: x.unfold(0, 9, 3), expected=RuntimeError)
self.helper_test_exception([(8,)], lambda x: x.unfold(0, 1, -1), expected=RuntimeError)
def test_meshgrid(self):
x, xt = torch.tensor([0.,1.,2.], requires_grad=True), Tensor([0.,1.,2.], requires_grad=True)
@@ -548,7 +548,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65), (45,65)], lambda x,y: x/y)
helper_test_op([(), ()], lambda x,y: x/y)
@unittest.skipIf(AMD_LLVM, "AMD with LLVM backend generate rcp in FP division causes trunc/floor errors")
@unittest.skipIf(Device.DEFAULT == "AMD" and AMD_LLVM, "AMD with LLVM backend generate rcp in FP division causes trunc/floor errors")
def test_div_rounding_mode(self):
for denominator in [-10, -5, -3, -2, -1, 1, 2, 3, 5, 10]:
# int numerator
@@ -576,8 +576,7 @@ class TestOps(unittest.TestCase):
helper_test_op(None, lambda x,y: x.div(y, rounding_mode="trunc"), forward_only=True, vals=[[numerator], [denominator]])
helper_test_op(None, lambda x,y: x.div(y, rounding_mode="floor"), forward_only=True, vals=[[numerator], [denominator]])
self.helper_test_exception(None, lambda x,y: x.div(y, rounding_mode="typo"), lambda x,y: x.div(y, rounding_mode="typo"), forward_only=True,
vals=[[5], [0]], expected=RuntimeError)
self.helper_test_exception(None, lambda x,y: x.div(y, rounding_mode="typo"), forward_only=True, vals=[[5], [0]], expected=RuntimeError)
def test_div_int(self):
helper_test_op(None, lambda x,y: x/y, Tensor.div, forward_only=True, vals=[[5, 6, 7],[1, 2, 3]])
@@ -737,7 +736,7 @@ class TestOps(unittest.TestCase):
helper_test_op([], lambda: tor^0x1337, lambda: ten^0x1337, forward_only=True)
helper_test_op([], lambda: 0x1337^tor, lambda: 0x1337^ten, forward_only=True)
self.helper_test_exception([(4), (4)], torch.bitwise_xor, Tensor.bitwise_xor, expected=RuntimeError)
self.helper_test_exception([(4), (4)], lambda x,y: x.bitwise_xor(y), expected=RuntimeError)
def test_and(self):
data = [[1,-8,1],[32,1,6]]
@@ -754,7 +753,7 @@ class TestOps(unittest.TestCase):
helper_test_op(None, lambda x: (1 < x) & (x < 2), forward_only=True, vals=[[1.2, 1.2, 1.2, 3.2]])
self.helper_test_exception([(4), (4)], torch.bitwise_and, Tensor.bitwise_and, expected=RuntimeError)
self.helper_test_exception([(4), (4)], lambda x,y: x.bitwise_and(y), expected=RuntimeError)
def test_or(self):
data = [[1,-8,1],[32,1,6]]
@@ -769,7 +768,7 @@ class TestOps(unittest.TestCase):
ten0, ten1 = Tensor(data[0], dtype=dtypes.bool), Tensor(data[1], dtype=dtypes.bool)
helper_test_op([], lambda: tor0|tor1, lambda: ten0|ten1, forward_only=True)
self.helper_test_exception([(4), (4)], torch.bitwise_or, Tensor.bitwise_or, expected=RuntimeError)
self.helper_test_exception([(4), (4)], lambda x,y: x.bitwise_or(y), expected=RuntimeError)
def test_bitwise_not(self):
data = [[1,-8,1],[32,1,6]]
@@ -784,7 +783,7 @@ class TestOps(unittest.TestCase):
helper_test_op([], lambda: tor.bitwise_not(), lambda: ten.bitwise_not(), forward_only=True)
helper_test_op([], lambda: ~tor, lambda: ~ten, forward_only=True)
self.helper_test_exception([(4)], torch.bitwise_not, Tensor.bitwise_not, expected=RuntimeError)
self.helper_test_exception([(4)], lambda x: x.bitwise_not(), expected=RuntimeError)
def test_lshift(self):
data = [[0,1,2],[1<<8,1<<16,1<<31-1]]
@@ -1131,7 +1130,7 @@ class TestOps(unittest.TestCase):
value, indices = Tensor([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0]).topk(3, largest=False)
np.testing.assert_equal(value.numpy(), [0, 0, 0])
np.testing.assert_equal(indices.numpy(), [2, 4, 6])
self.helper_test_exception([(4)], lambda x: x.topk(5), lambda x: x.topk(5), expected=(RuntimeError, ValueError))
self.helper_test_exception([(4)], lambda x: x.topk(5), expected=(RuntimeError, ValueError))
def test_einsum(self):
# matrix transpose
@@ -1334,9 +1333,9 @@ class TestOps(unittest.TestCase):
helper_test_op([()], lambda x: x.sum(0))
helper_test_op([()], lambda x: x.sum(-1))
helper_test_op([()], lambda x: x.sum(()))
self.helper_test_exception([(3,4,5,6)], lambda x: x.sum(5), lambda x: x.sum(5), expected=IndexError)
self.helper_test_exception([()], lambda x: x.sum(1), lambda x: x.sum(1), expected=IndexError)
self.helper_test_exception([()], lambda x: x.sum((1,)), lambda x: x.sum((1,)), expected=IndexError)
self.helper_test_exception([(3,4,5,6)], lambda x: x.sum(5), expected=IndexError)
self.helper_test_exception([()], lambda x: x.sum(1), expected=IndexError)
self.helper_test_exception([()], lambda x: x.sum((1,)), expected=IndexError)
def test_sum_dtype_arg(self):
helper_test_op([(45,3)], lambda x: x.sum(), lambda x: x.sum(dtype=dtypes.float32))
@@ -1851,9 +1850,9 @@ class TestOps(unittest.TestCase):
helper_test_op([(3,4,5,6)], lambda x: x.permute((3,2,1,0)))
helper_test_op([(3,4,5,6)], lambda x: x.permute((-2,-1,1,0)))
helper_test_op([()], lambda x: x.permute(()))
self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,2)), lambda x: x.permute((0,2)), expected=RuntimeError)
self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,1,2,3,3,3)), lambda x: x.permute((0,1,2,3,3,3)), expected=RuntimeError)
self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,0,1,2,3)), lambda x: x.permute((0,0,1,2,3)), expected=RuntimeError)
self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,2)), expected=RuntimeError)
self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,1,2,3,3,3)), expected=RuntimeError)
self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,0,1,2,3)), expected=RuntimeError)
def test_reshape(self):
helper_test_op([(4,3,6,6)], lambda x: x.reshape((12,6,6)))
@@ -1864,8 +1863,8 @@ class TestOps(unittest.TestCase):
helper_test_op([(1,)], lambda x: x.reshape(()))
helper_test_op([()], lambda x: x.reshape((1,)))
helper_test_op([()], lambda x: x.reshape((1,1,1)))
self.helper_test_exception([(3,4)], lambda x: x.reshape((-1,-1,2)), lambda x: x.reshape((-1,-1,2)), expected=RuntimeError)
self.helper_test_exception([(3,4)], lambda x: x.reshape((-1,-1,-1,2)), lambda x: x.reshape((-1,-1,-1,2)), expected=RuntimeError)
self.helper_test_exception([(3,4)], lambda x: x.reshape((-1,-1,2)), expected=RuntimeError)
self.helper_test_exception([(3,4)], lambda x: x.reshape((-1,-1,-1,2)), expected=RuntimeError)
with self.assertRaises(ValueError):
x = Tensor.ones((4,3,6,6))
@@ -1890,16 +1889,16 @@ class TestOps(unittest.TestCase):
helper_test_op([()], lambda x: x.flip(()))
helper_test_op([(1,)], lambda x: x.flip(()))
helper_test_op([(4,3,6,6)], lambda x: x.flip(()))
self.helper_test_exception([(3,4)], lambda x: x.flip((0,0)), lambda x: x.flip((0,0)), expected=RuntimeError)
self.helper_test_exception([(3,4)], lambda x: x.flip((1,1)), lambda x: x.flip((1,1)), expected=RuntimeError)
self.helper_test_exception([(3,4)], lambda x: x.flip((1,-1)), lambda x: x.flip((1,-1)), expected=RuntimeError)
self.helper_test_exception([(3,4)], lambda x: x.flip((0,0)), expected=RuntimeError)
self.helper_test_exception([(3,4)], lambda x: x.flip((1,1)), expected=RuntimeError)
self.helper_test_exception([(3,4)], lambda x: x.flip((1,-1)), expected=RuntimeError)
def test_squeeze(self):
helper_test_op([(1,3,6,6)], lambda x: x.squeeze(0))
helper_test_op([(4,3,1,6)], lambda x: x.squeeze(1))
helper_test_op([(4,3,6,6)], lambda x: x.squeeze(3))
self.helper_test_exception([(4,3,6,6)], lambda x: torch.squeeze(x, 50), lambda x: x.squeeze(dim=50), expected=IndexError)
self.helper_test_exception([(4,3,6,6)], lambda x: torch.squeeze(x, -50), lambda x: x.squeeze(dim=-50), expected=IndexError)
self.helper_test_exception([(4,3,6,6)], lambda x: x.squeeze(50), expected=IndexError)
self.helper_test_exception([(4,3,6,6)], lambda x: x.squeeze(50), expected=IndexError)
helper_test_op([(4,3,6,1)], lambda x: x.squeeze(-1))
helper_test_op([(4,3,6,6)], lambda x: x.squeeze())
helper_test_op([(1,3,6,6)], lambda x: x.squeeze())
@@ -1907,9 +1906,9 @@ class TestOps(unittest.TestCase):
helper_test_op([()], lambda x: x.squeeze(-1))
helper_test_op([()], lambda x: x.squeeze(0))
helper_test_op([()], lambda x: x.squeeze())
self.helper_test_exception([()], lambda x: torch.squeeze(x, 10), lambda x: x.squeeze(dim=10), expected=IndexError)
self.helper_test_exception([()], lambda x: torch.squeeze(x, 1), lambda x: x.squeeze(dim=1), expected=IndexError)
self.helper_test_exception([()], lambda x: torch.squeeze(x, -2), lambda x: x.squeeze(dim=-2), expected=IndexError)
self.helper_test_exception([()], lambda x: x.squeeze(10), expected=IndexError)
self.helper_test_exception([()], lambda x: x.squeeze(1), expected=IndexError)
self.helper_test_exception([()], lambda x: x.squeeze(-2), expected=IndexError)
def test_unsqueeze(self):
helper_test_op([(4,3,6,6)], lambda x: x.unsqueeze(0))
@@ -2655,7 +2654,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], lambda x: x.clip(3, 0)) # min > max
helper_test_op([(45,65)], lambda x: x.clip(None, 0))
helper_test_op([(45,65)], lambda x: x.clip(0, None))
self.helper_test_exception([(45,65)], lambda x: x.clip(None, None), lambda x: x.clip(None, None), RuntimeError)
self.helper_test_exception([(45,65)], lambda x: x.clip(None, None), expected=RuntimeError)
def test_matvecmat(self):
helper_test_op([(1,128), (128,128), (128,128)], lambda x,y,z: (x@y).relu()@z)