smaller test_ops inputs (#12007)

commit 52166fd7eb
parent dc8501af30
Author: chenyu
Date:   2025-09-04 16:22:33 -04:00
Committed by: GitHub
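The three hunks below only shrink the input shapes handed to helper_test_op; the ops under test and their torch references are unchanged. As a rough mental model of what that helper checks (an assumption inferred from how it is called here, not the actual implementation in test/test_ops.py), the hypothetical compare_forward below builds random inputs of the given shapes, runs the torch and tinygrad callables on them, and requires the forward outputs to agree; the real helper also takes forward_only and atol, and presumably checks gradients when forward_only is not set.

```python
# Hypothetical stand-in for helper_test_op, forward pass only.
import numpy as np
import torch
from tinygrad import Tensor

def compare_forward(shapes, torch_fn, tinygrad_fn, atol=1e-6, rtol=1e-3):
  # an int shape like (8) means a 1-D tensor of that length
  dims = [s if isinstance(s, tuple) else (s,) for s in shapes]
  data = [np.random.randn(*d).astype(np.float32) for d in dims]
  torch_out = torch_fn(*[torch.tensor(x) for x in data])
  tiny_out = tinygrad_fn(*[Tensor(x) for x in data])
  # forward outputs should agree within tolerance
  np.testing.assert_allclose(tiny_out.numpy(), torch_out.detach().numpy(), atol=atol, rtol=rtol)

# e.g. the shrunken topk case below:
# compare_forward([(8)], lambda x: x.topk(3).values, lambda x: x.topk(3)[0])
```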

@@ -1128,12 +1128,12 @@ class TestOps(unittest.TestCase):
         lambda x: x.argsort(dim, descending), forward_only=True)
 
   def test_topk(self):
-    helper_test_op([(10)], lambda x: x.topk(3).values, lambda x: x.topk(3)[0], forward_only=True)
-    helper_test_op([(10)], lambda x: x.topk(3).indices.type(torch.int32), lambda x: x.topk(3)[1], forward_only=True)
+    helper_test_op([(8)], lambda x: x.topk(3).values, lambda x: x.topk(3)[0], forward_only=True)
+    helper_test_op([(8)], lambda x: x.topk(3).indices.type(torch.int32), lambda x: x.topk(3)[1], forward_only=True)
     for dim in [0, 1, -1]:
       for largest in [True, False]:
         for sorted_ in [True]: # TODO support False
-          helper_test_op([(6,5,4)],
+          helper_test_op([(5,5,4)],
             lambda x: x.topk(4, dim, largest, sorted_).values,
             lambda x: x.topk(4, dim, largest, sorted_)[0], forward_only=True)
           helper_test_op([(5,5,4)],
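The shrunken topk test still covers the same behaviour: the three largest values and their positions in an 8-element vector, with torch as the oracle. A standalone check relying only on the call pattern visible in the test (torch exposes .values/.indices, the tinygrad result is indexed with [0]/[1]):

```python
import torch
from tinygrad import Tensor

data = [3., 1., 4., 1., 5., 9., 2., 6.]   # 8 elements, matching the new test size
t = torch.tensor(data).topk(3)            # torch reference
g = Tensor(data).topk(3)                  # tinygrad, same call pattern as the test
assert t.values.tolist() == g[0].numpy().tolist()         # [9.0, 6.0, 5.0]
assert t.indices.int().tolist() == g[1].numpy().tolist()  # [5, 7, 4]
```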
@@ -1150,47 +1150,47 @@ class TestOps(unittest.TestCase):
 
   def test_einsum(self):
     # matrix transpose
-    helper_test_op([(150,150)], lambda a: torch.einsum('ij->ji', a), lambda a: Tensor.einsum('ij->ji', a))
-    helper_test_op([(150,150)], lambda a: torch.einsum('ij -> ji', a), lambda a: Tensor.einsum('ij -> ji', a))
-    helper_test_op([(150,150)], lambda a: torch.einsum('ji', a), lambda a: Tensor.einsum('ji', a))
-    helper_test_op([(20,30,40)], lambda a: torch.einsum('jki', a), lambda a: Tensor.einsum('jki', a))
-    helper_test_op([(20,30,40)], lambda a: torch.einsum('dog', a), lambda a: Tensor.einsum('dog', a))
+    helper_test_op([(10,10)], lambda a: torch.einsum('ij->ji', a), lambda a: Tensor.einsum('ij->ji', a))
+    helper_test_op([(10,10)], lambda a: torch.einsum('ij -> ji', a), lambda a: Tensor.einsum('ij -> ji', a))
+    helper_test_op([(10,10)], lambda a: torch.einsum('ji', a), lambda a: Tensor.einsum('ji', a))
+    helper_test_op([(4,6,8)], lambda a: torch.einsum('jki', a), lambda a: Tensor.einsum('jki', a))
+    helper_test_op([(4,6,8)], lambda a: torch.einsum('dog', a), lambda a: Tensor.einsum('dog', a))
     # no -> and empty rhs
-    helper_test_op([(20,30),(30,40)], lambda a, b: torch.einsum('ij,jk', a, b), lambda a, b: Tensor.einsum('ij,jk', a, b))
+    helper_test_op([(4,6),(6,8)], lambda a, b: torch.einsum('ij,jk', a, b), lambda a, b: Tensor.einsum('ij,jk', a, b))
     # sum all elements
-    helper_test_op([(20,30,40)], lambda a: torch.einsum('ijk->', a), lambda a: Tensor.einsum('ijk->', a))
+    helper_test_op([(4,6,8)], lambda a: torch.einsum('ijk->', a), lambda a: Tensor.einsum('ijk->', a))
     # column sum
-    helper_test_op([(50,50)], lambda a: torch.einsum('ij->j', a), lambda a: Tensor.einsum('ij->j', a))
+    helper_test_op([(5,5)], lambda a: torch.einsum('ij->j', a), lambda a: Tensor.einsum('ij->j', a))
     # row sum
-    helper_test_op([(15,15)], lambda a: torch.einsum('ij->i', a), lambda a: Tensor.einsum('ij->i', a))
+    helper_test_op([(5,5)], lambda a: torch.einsum('ij->i', a), lambda a: Tensor.einsum('ij->i', a))
     # matrix-vector multiplication
-    helper_test_op([(15,20), (20,)], lambda a,b: torch.einsum('ik,k->i', a,b), lambda a,b: Tensor.einsum('ik,k->i', a, b))
+    helper_test_op([(3,4), (4,)], lambda a,b: torch.einsum('ik,k->i', a,b), lambda a,b: Tensor.einsum('ik,k->i', a, b))
     # matrix-matrix multiplication
-    helper_test_op([(15,20), (20,30)], lambda a,b: torch.einsum('ik,kj->ij', a,b), lambda a,b: Tensor.einsum('ik,kj->ij', a, b))
+    helper_test_op([(3,4), (4,5)], lambda a,b: torch.einsum('ik,kj->ij', a,b), lambda a,b: Tensor.einsum('ik,kj->ij', a, b))
     # matrix-matrix multiplication, different letter order
-    helper_test_op([(15,20), (20,30)], lambda a,b: torch.einsum('jk,ki->ji', a,b), lambda a,b: Tensor.einsum('jk,ki->ji', a, b))
+    helper_test_op([(3,4), (4,5)], lambda a,b: torch.einsum('jk,ki->ji', a,b), lambda a,b: Tensor.einsum('jk,ki->ji', a, b))
     # dot product
-    helper_test_op([(30),(30)], lambda a,b: torch.einsum('i,i->i', [a,b]), lambda a,b: Tensor.einsum('i,i->i', [a,b]))
+    helper_test_op([(5),(5)], lambda a,b: torch.einsum('i,i->i', [a,b]), lambda a,b: Tensor.einsum('i,i->i', [a,b]))
     # hadamard product
-    helper_test_op([(30,40),(30,40)], lambda a,b: torch.einsum('ij,ij->ij', a,b), lambda a,b: Tensor.einsum('ij,ij->ij', a,b))
+    helper_test_op([(5,6),(5,6)], lambda a,b: torch.einsum('ij,ij->ij', a,b), lambda a,b: Tensor.einsum('ij,ij->ij', a,b))
     # outer product
-    helper_test_op([(15,), (15,)], lambda a,b: torch.einsum('i,j->ij', a,b), lambda a,b: Tensor.einsum('i,j->ij',a,b))
+    helper_test_op([(5,), (5,)], lambda a,b: torch.einsum('i,j->ij', a,b), lambda a,b: Tensor.einsum('i,j->ij',a,b))
     # batch matrix multiplication
-    helper_test_op([(10,20,30),(10,30,40)], lambda a,b: torch.einsum('ijk,ikl->ijl', [a, b]), lambda a,b: Tensor.einsum('ijk,ikl->ijl', [a, b]))
+    helper_test_op([(2,4,6),(2,6,8)], lambda a,b: torch.einsum('ijk,ikl->ijl', [a, b]), lambda a,b: Tensor.einsum('ijk,ikl->ijl', [a, b]))
     # batch matrix multiplication, result permuted
-    helper_test_op([(10,20,25),(10,25,32)], lambda a,b: torch.einsum('ijk,ikl->jil', [a, b]), lambda a,b: Tensor.einsum('ijk,ikl->jil', [a, b]))
+    helper_test_op([(2,4,5),(2,5,7)], lambda a,b: torch.einsum('ijk,ikl->jil', [a, b]), lambda a,b: Tensor.einsum('ijk,ikl->jil', [a, b]))
     # batch matrix multiplication, result & input permuted
-    helper_test_op([(20,10,25),(10,25,32)], lambda a,b: torch.einsum('jik,ikl->jil', [a, b]), lambda a,b: Tensor.einsum('jik,ikl->jil', [a, b]))
+    helper_test_op([(4,2,5),(2,5,7)], lambda a,b: torch.einsum('jik,ikl->jil', [a, b]), lambda a,b: Tensor.einsum('jik,ikl->jil', [a, b]))
     # batch matrix multiplication, result with different letters
-    helper_test_op([(10,20,30),(10,30,40)], lambda a,b: torch.einsum('ijk,ika->ija', [a, b]), lambda a,b: Tensor.einsum('ijk,ika->ija', [a, b]))
+    helper_test_op([(2,4,6),(2,6,8)], lambda a,b: torch.einsum('ijk,ika->ija', [a, b]), lambda a,b: Tensor.einsum('ijk,ika->ija', [a, b]))
     # tensor contraction
-    helper_test_op([(3,5,8,10),(11,13,5,16,8)], lambda a,b: torch.einsum('pqrs,tuqvr->pstuv', a,b),
+    helper_test_op([(3,5,8,10),(11,7,5,13,8)], lambda a,b: torch.einsum('pqrs,tuqvr->pstuv', a,b),
       lambda a,b: Tensor.einsum('pqrs,tuqvr->pstuv', a,b), atol=1e-5)
     # tensor contraction, input permuted
-    helper_test_op([(3,8,10,5),(11,5,13,16,8)], lambda a,b: torch.einsum('prsq,tquvr->pstuv', a,b),
+    helper_test_op([(3,8,10,5),(11,5,7,13,8)], lambda a,b: torch.einsum('prsq,tquvr->pstuv', a,b),
       lambda a,b: Tensor.einsum('prsq,tquvr->pstuv', a,b), atol=1e-5)
     # tensor contraction, result with different letters
-    helper_test_op([(3,5,8,10),(11,13,5,16,8)], lambda a,b: torch.einsum('zqrs,tuqvr->zstuv', a,b),
+    helper_test_op([(3,5,8,10),(11,7,5,13,8)], lambda a,b: torch.einsum('zqrs,tuqvr->zstuv', a,b),
       lambda a,b: Tensor.einsum('zqrs,tuqvr->zstuv', a,b), atol=1e-5)
     # bilinear transformation
     helper_test_op([(2,3),(5,3,7),(2,7)], lambda a,b,c: torch.einsum('ik,jkl,il->ij', [a,b,c]), lambda a,b,c: Tensor.einsum('ik,jkl,il->ij', [a,b,c]))
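None of the einsum equations change, only the operand sizes. The matrix-multiplication case, for instance, now runs on (3,4) and (4,5) operands; a standalone comparison of the two implementations used above (tolerances chosen loosely, purely as an illustration):

```python
import numpy as np
import torch
from tinygrad import Tensor

a = np.random.randn(3, 4).astype(np.float32)
b = np.random.randn(4, 5).astype(np.float32)
# 'ik,kj->ij' is a plain matrix multiply; shrinking the shapes does not change the equation
ref = torch.einsum('ik,kj->ij', torch.tensor(a), torch.tensor(b))
out = Tensor.einsum('ik,kj->ij', Tensor(a), Tensor(b))
np.testing.assert_allclose(out.numpy(), ref.numpy(), atol=1e-5, rtol=1e-5)
```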
@@ -2340,37 +2340,36 @@ class TestOps(unittest.TestCase):
     for ksz in [(2,2), (3,3), 2, 3, (3,2)]:
       for p in [1, (1,0), (0,1)]:
         with self.subTest(kernel_size=ksz, padding=p):
-          helper_test_op([(32,2,11,28)],
+          helper_test_op([(4,2,11,28)],
             lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz, padding=p),
             lambda x: Tensor.max_pool2d(x, kernel_size=ksz, padding=p))
-    self.helper_test_exception([(32,2,110,28)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), padding=(1,1,1)),
+    self.helper_test_exception([(4,2,110,28)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), padding=(1,1,1)),
       lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), padding=(1,1,1)), expected=(RuntimeError, ValueError))
 
   def test_max_pool2d_asymmetric_padding(self):
-    shape = (32,2,111,28)
     for p in [(0,1,0,1), (2,1,2,1), (2,0,2,1)]:
       with self.subTest(padding=p):
-        helper_test_op([shape],
+        helper_test_op([(4,2,111,28)],
           lambda x: torch.nn.functional.max_pool2d(torch.nn.functional.pad(x, p, value=float("-inf")), kernel_size=(5,5)),
           lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), padding=p))
 
   def test_max_pool2d_padding_int(self):
     ksz = (2,2)
-    helper_test_op([(32,2,11,28)],
+    helper_test_op([(4,2,11,28)],
       lambda x: torch.nn.functional.max_pool2d(x.int(), kernel_size=ksz, padding=1),
       lambda x: Tensor.max_pool2d(x.int(), kernel_size=ksz, padding=1), forward_only=True)
 
   def test_max_pool2d_bigger_stride(self):
     for stride in [(2,3), (3,2), 2, 3]:
       with self.subTest(stride=stride):
-        helper_test_op([(32,2,11,28)],
+        helper_test_op([(4,2,11,28)],
           lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), stride=stride),
           lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=stride))
 
   def test_max_pool2d_bigger_stride_dilation(self):
     for stride, dilation in zip([(2,3), (3,2), 2, 3, 4], [(3,2), (2,3), 2, 3, 6]):
       with self.subTest(stride=stride):
-        helper_test_op([(32,2,11,28)],
+        helper_test_op([(4,2,11,28)],
           lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), stride=stride, dilation=dilation),
           lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=stride, dilation=dilation))
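The max_pool2d tests keep their kernel, stride, padding, and dilation coverage and only drop the batch dimension from 32 to 4. A quick sanity check of the padded default-stride case on the new shape, including the expected output size (a sketch assuming stride defaults to the kernel size on both sides, which the padding test above already relies on):

```python
import numpy as np
import torch
from tinygrad import Tensor

x = np.random.randn(4, 2, 11, 28).astype(np.float32)   # the new, smaller batch of 4
ref = torch.nn.functional.max_pool2d(torch.tensor(x), kernel_size=(2,2), padding=1)
out = Tensor.max_pool2d(Tensor(x), kernel_size=(2,2), padding=1)
# per spatial dim: floor((d + 2*pad - k) / stride) + 1, so 11 -> 6 and 28 -> 15
assert out.shape == tuple(ref.shape) == (4, 2, 6, 15)
np.testing.assert_allclose(out.numpy(), ref.numpy())
```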