fix Tensor.meshgrid for 1D input and check indexing (#7740)

This commit is contained in:
chenyu
2024-11-16 23:39:30 -05:00
committed by GitHub
parent 72a41095bc
commit a15a900415
2 changed files with 15 additions and 4 deletions

View File

@@ -221,14 +221,23 @@ class TestOps(unittest.TestCase):
y, yt = torch.tensor([3.,4.,5.,6.], requires_grad=True), Tensor([3.,4.,5.,6.], requires_grad=True)
z, zt = torch.tensor([7.,8.,9.], requires_grad=True), Tensor([7.,8.,9.], requires_grad=True)
for indexing in ("ij", "xy"):
tor = torch.meshgrid(x, indexing=indexing)
ten = xt.meshgrid(indexing=indexing)
self.assertEqual(len(tor), len(ten))
for tor_i, ten_i in zip(tor, ten):
helper_test_op([], lambda: tor_i, lambda: ten_i)
tor = torch.meshgrid(x, y, indexing=indexing)
ten = xt.meshgrid(yt, indexing=indexing)
for i in range(len(tor)):
helper_test_op([], lambda: tor[i], lambda: ten[i])
self.assertEqual(len(tor), len(ten))
for tor_i, ten_i in zip(tor, ten):
helper_test_op([], lambda: tor_i, lambda: ten_i)
tor = torch.meshgrid(x, torch.tensor(10., requires_grad=True), y, z, indexing=indexing)
ten = xt.meshgrid(Tensor(10., requires_grad=True), yt, zt, indexing=indexing)
for i in range(len(tor)):
helper_test_op([], lambda: tor[i], lambda: ten[i])
self.assertEqual(len(tor), len(ten))
for tor_i, ten_i in zip(tor, ten):
helper_test_op([], lambda: tor_i, lambda: ten_i)
self.helper_test_exception([], lambda: torch.meshgrid(x, indexing="bad"), lambda: xt.meshgrid(indexing="bad"), expected=RuntimeError)
def test_arange(self):
helper_test_op([], lambda: torch.arange(10, dtype=torch.int32), lambda: Tensor.arange(10), forward_only=True)