fix Tensor.meshgrid for 1D input and check indexing (#7740)

This commit is contained in:
chenyu
2024-11-16 23:39:30 -05:00
committed by GitHub
parent 72a41095bc
commit a15a900415
2 changed files with 15 additions and 4 deletions

View File

@@ -221,14 +221,23 @@ class TestOps(unittest.TestCase):
y, yt = torch.tensor([3.,4.,5.,6.], requires_grad=True), Tensor([3.,4.,5.,6.], requires_grad=True)
z, zt = torch.tensor([7.,8.,9.], requires_grad=True), Tensor([7.,8.,9.], requires_grad=True)
for indexing in ("ij", "xy"):
tor = torch.meshgrid(x, indexing=indexing)
ten = xt.meshgrid(indexing=indexing)
self.assertEqual(len(tor), len(ten))
for tor_i, ten_i in zip(tor, ten):
helper_test_op([], lambda: tor_i, lambda: ten_i)
tor = torch.meshgrid(x, y, indexing=indexing)
ten = xt.meshgrid(yt, indexing=indexing)
for i in range(len(tor)):
helper_test_op([], lambda: tor[i], lambda: ten[i])
self.assertEqual(len(tor), len(ten))
for tor_i, ten_i in zip(tor, ten):
helper_test_op([], lambda: tor_i, lambda: ten_i)
tor = torch.meshgrid(x, torch.tensor(10., requires_grad=True), y, z, indexing=indexing)
ten = xt.meshgrid(Tensor(10., requires_grad=True), yt, zt, indexing=indexing)
for i in range(len(tor)):
helper_test_op([], lambda: tor[i], lambda: ten[i])
self.assertEqual(len(tor), len(ten))
for tor_i, ten_i in zip(tor, ten):
helper_test_op([], lambda: tor_i, lambda: ten_i)
self.helper_test_exception([], lambda: torch.meshgrid(x, indexing="bad"), lambda: xt.meshgrid(indexing="bad"), expected=RuntimeError)
def test_arange(self):
helper_test_op([], lambda: torch.arange(10, dtype=torch.int32), lambda: Tensor.arange(10), forward_only=True)

View File

@@ -1377,6 +1377,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
print(grid_y.numpy())
```
"""
if indexing not in ("ij", "xy"): raise RuntimeError(f'indexing must be in ("ij", "xy"), got {indexing}')
if not args: return (self,)
tensors = (self,) + args if indexing == "ij" else (args[0],) + (self,) + args[1:]
tensors = tuple(t.reshape((-1,) + (1,)*(len(args) - i)) for i,t in enumerate(tensors))
tensors, out_shape = (tensors if indexing == "ij" else (tensors[1],) + (tensors[0],) + tensors[2:]), _broadcast_shape(*(t.shape for t in tensors))