fix Tensor.meshgrid for 1D input and check indexing (#7740)
@@ -221,14 +221,23 @@ class TestOps(unittest.TestCase):
     y, yt = torch.tensor([3.,4.,5.,6.], requires_grad=True), Tensor([3.,4.,5.,6.], requires_grad=True)
     z, zt = torch.tensor([7.,8.,9.], requires_grad=True), Tensor([7.,8.,9.], requires_grad=True)
     for indexing in ("ij", "xy"):
+      tor = torch.meshgrid(x, indexing=indexing)
+      ten = xt.meshgrid(indexing=indexing)
+      self.assertEqual(len(tor), len(ten))
+      for tor_i, ten_i in zip(tor, ten):
+        helper_test_op([], lambda: tor_i, lambda: ten_i)
       tor = torch.meshgrid(x, y, indexing=indexing)
       ten = xt.meshgrid(yt, indexing=indexing)
-      for i in range(len(tor)):
-        helper_test_op([], lambda: tor[i], lambda: ten[i])
+      self.assertEqual(len(tor), len(ten))
+      for tor_i, ten_i in zip(tor, ten):
+        helper_test_op([], lambda: tor_i, lambda: ten_i)
       tor = torch.meshgrid(x, torch.tensor(10., requires_grad=True), y, z, indexing=indexing)
       ten = xt.meshgrid(Tensor(10., requires_grad=True), yt, zt, indexing=indexing)
-      for i in range(len(tor)):
-        helper_test_op([], lambda: tor[i], lambda: ten[i])
+      self.assertEqual(len(tor), len(ten))
+      for tor_i, ten_i in zip(tor, ten):
+        helper_test_op([], lambda: tor_i, lambda: ten_i)
+
+    self.helper_test_exception([], lambda: torch.meshgrid(x, indexing="bad"), lambda: xt.meshgrid(indexing="bad"), expected=RuntimeError)
 
   def test_arange(self):
     helper_test_op([], lambda: torch.arange(10, dtype=torch.int32), lambda: Tensor.arange(10), forward_only=True)
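For readers skimming the test changes: here is a minimal standalone sketch of the behavior the new assertions cover. It is not part of the diff; it assumes the usual `from tinygrad import Tensor` import and that the rest of `meshgrid` (not shown in this hunk) expands every grid to the broadcast shape, matching the torch reference used above.

```python
from tinygrad import Tensor

x, y = Tensor([1., 2.]), Tensor([3., 4., 5.])

# "ij" (matrix) indexing: both grids take the broadcast shape (len(x), len(y))
gx, gy = x.meshgrid(y, indexing="ij")
print(gx.shape, gy.shape)  # expected: (2, 3) (2, 3)

# "xy" (cartesian) indexing swaps the first two axes, as in torch.meshgrid
gx, gy = x.meshgrid(y, indexing="xy")
print(gx.shape, gy.shape)  # expected: (3, 2) (3, 2)

# the 1D fix: a lone tensor comes back unchanged as a 1-tuple
(gx,) = x.meshgrid()
print(gx.shape)            # expected: (2,)

# the indexing check: anything outside ("ij", "xy") raises RuntimeError
try: x.meshgrid(y, indexing="bad")
except RuntimeError as e: print(e)
```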
@@ -1377,6 +1377,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     print(grid_y.numpy())
     ```
     """
+    if indexing not in ("ij", "xy"): raise RuntimeError(f'indexing must be in ("ij", "xy"), got {indexing}')
+    if not args: return (self,)
     tensors = (self,) + args if indexing == "ij" else (args[0],) + (self,) + args[1:]
     tensors = tuple(t.reshape((-1,) + (1,)*(len(args) - i)) for i,t in enumerate(tensors))
     tensors, out_shape = (tensors if indexing == "ij" else (tensors[1],) + (tensors[0],) + tensors[2:]), _broadcast_shape(*(t.shape for t in tensors))
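As a shape-only illustration of the reshape plus `_broadcast_shape` step above: the sketch below is hypothetical, not tinygrad code; `reshaped_shapes` and `broadcast_shape` are made-up helpers that only mimic the shape arithmetic and assume the dimensions are broadcast-compatible.

```python
from itertools import zip_longest

def reshaped_shapes(lens, indexing="ij"):
  # mirrors the reshape above: tensor i becomes (-1,) + (1,)*(len(args) - i)
  n_args = len(lens) - 1
  order = lens if indexing == "ij" else [lens[1], lens[0], *lens[2:]]
  return [(l,) + (1,)*(n_args - i) for i, l in enumerate(order)]

def broadcast_shape(*shapes):
  # right-aligned broadcasting (each dim equal or 1), like _broadcast_shape
  return tuple(max(dims) for dims in zip_longest(*(s[::-1] for s in shapes), fillvalue=1))[::-1]

print(reshaped_shapes([2, 4, 3]))                          # [(2, 1, 1), (4, 1), (3,)]
print(broadcast_shape(*reshaped_shapes([2, 4, 3])))        # (2, 4, 3)  "ij" grids
print(broadcast_shape(*reshaped_shapes([2, 4, 3], "xy")))  # (4, 2, 3)  "xy" swaps the first two dims
```

The swapped leading dimensions for `"xy"` match torch.meshgrid's cartesian indexing, which the tests in the first hunk compare against.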