update Torch.gather api (#4692)

* update Torch.gather api

gather(self, dim, index) to match torch

* fix that
This commit is contained in:
chenyu
2024-05-22 21:54:06 -04:00
committed by GitHub
parent 792a494eb8
commit 47aba47f64
5 changed files with 26 additions and 24 deletions

View File

@@ -1251,7 +1251,7 @@ class TestIndexing(unittest.TestCase):
def test_take_along_dim(self):
def _test_against_numpy(t: Tensor, indices: Tensor, dim):
actual = t.gather(indices, dim=dim)
actual = t.gather(dim, indices)
t_np = t.numpy()
indices_np = indices.numpy()
expected = np.take_along_axis(t_np, indices_np, axis=dim)
@@ -1295,24 +1295,24 @@ class TestIndexing(unittest.TestCase):
# dim of `t` and `indices` does not match
with self.assertRaises(RuntimeError, "input and indices should have the same number of dimensions"):
t.gather(indices[0], dim=0)
t.gather(0, indices[0])
# invalid `indices` dtype
with self.assertRaises(RuntimeError):
t.gather(indices.cast(dtypes.bool), dim=0)
t.gather(0, indices.cast(dtypes.bool))
with self.assertRaises(RuntimeError):
t.gather(indices.cast(dtypes.float32), dim=0)
t.gather(0, indices.cast(dtypes.float32))
with self.assertRaises(RuntimeError):
t.gather(indices.cast(dtypes.int32), dim=0)
t.gather(0, indices.cast(dtypes.int32))
# invalid axis
with self.assertRaises(IndexError):
t.gather(indices, dim=-7)
t.gather(-7, indices)
with self.assertRaises(IndexError):
t.gather(t, indices, dim=7)
t.gather(7, indices)
'''
class TestNumpy(unittest.TestCase):

View File

@@ -443,8 +443,8 @@ class TestTypeSpec(unittest.TestCase):
def test_gather_returns_same_dtype(self, data_dtype, indices_dtype):
X_data = Tensor([[1, 0], [0, 1]], dtype=data_dtype)
indices = Tensor([[0, 0], [1, 0]], dtype=indices_dtype)
assert X_data.gather(indices, 0).dtype == X_data.dtype
assert X_data.gather(indices, 1).dtype == X_data.dtype
assert X_data.gather(0, indices).dtype == X_data.dtype
assert X_data.gather(1, indices).dtype == X_data.dtype
class TestTypePromotion(unittest.TestCase):
@given(strat.sampled_from(core_dtypes))

View File

@@ -1707,14 +1707,14 @@ class TestOps(unittest.TestCase):
# indices cannot be negative (torch gather)
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0))
helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=1), lambda x: x.gather(idx=a, dim=1))
helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=2), lambda x: x.gather(idx=a, dim=2))
helper_test_op([(3,4,5)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0))
self.helper_test_exception([(4,5,6)], lambda x: x.gather(index=torch.tensor([1], dtype=torch.int64), dim=0),
lambda x: x.gather(idx=Tensor([1], dtype=dtypes.int32), dim=0), expected=(RuntimeError, AssertionError))
self.helper_test_exception([(2,1,1)], lambda x: x.gather(index=b, dim=0),
lambda x: x.gather(idx=a, dim=0), expected=(RuntimeError, AssertionError))
helper_test_op([(4,5,6)], lambda x: x.gather(dim=0, index=b), lambda x: x.gather(dim=0, index=a))
helper_test_op([(4,5,6)], lambda x: x.gather(dim=1, index=b), lambda x: x.gather(dim=1, index=a))
helper_test_op([(4,5,6)], lambda x: x.gather(dim=2, index=b), lambda x: x.gather(dim=2, index=a))
helper_test_op([(3,4,5)], lambda x: x.gather(dim=0, index=b), lambda x: x.gather(dim=0, index=a))
self.helper_test_exception([(4,5,6)], lambda x: x.gather(dim=0, index=torch.tensor([1], dtype=torch.int64)),
lambda x: x.gather(dim=0, index=Tensor([1], dtype=dtypes.int32)), expected=(RuntimeError, AssertionError))
self.helper_test_exception([(2,1,1)], lambda x: x.gather(dim=0, index=b),
lambda x: x.gather(dim=0, index=a), expected=(RuntimeError, AssertionError))
def test_scaled_product_attention(self):
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], torch.nn.functional.scaled_dot_product_attention, Tensor.scaled_dot_product_attention)