mirror of https://github.com/tinygrad/tinygrad.git
update Torch.gather api (#4692)
* update Torch.gather api gather(self, dim, index) to match torch
* fix that
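In short, Tensor.gather now takes the dimension first and the index tensor second, matching torch.gather. A minimal before/after sketch (assuming a tinygrad build that includes this commit; the sample values are illustrative):

from tinygrad import Tensor

t = Tensor([[1, 2], [3, 4]])
idx = Tensor([[0, 0], [1, 0]])

# before this commit: index first, dim as a keyword
# out = t.gather(idx, dim=1)

# after this commit: dim first, then index, as in torch.gather(input, dim, index)
out = t.gather(1, idx)
print(out.numpy())  # [[1 1], [4 3]]: out[i, j] = t[i, idx[i, j]] along dim 1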
@@ -1251,7 +1251,7 @@ class TestIndexing(unittest.TestCase):
 
   def test_take_along_dim(self):
     def _test_against_numpy(t: Tensor, indices: Tensor, dim):
-      actual = t.gather(indices, dim=dim)
+      actual = t.gather(dim, indices)
       t_np = t.numpy()
       indices_np = indices.numpy()
       expected = np.take_along_axis(t_np, indices_np, axis=dim)
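The hunk above only flips the argument order in the numpy cross-check. A standalone sketch of the same comparison (the helper name and sample values here are mine, not from the test file):

import numpy as np
from tinygrad import Tensor

def check_against_numpy(t: Tensor, indices: Tensor, dim: int):
  # new argument order: dim first, then the index tensor
  actual = t.gather(dim, indices)
  expected = np.take_along_axis(t.numpy(), indices.numpy(), axis=dim)
  np.testing.assert_equal(actual.numpy(), expected)

check_against_numpy(Tensor([[10, 20, 30]]), Tensor([[2, 0, 1]]), dim=1)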
@@ -1295,24 +1295,24 @@ class TestIndexing(unittest.TestCase):
 
     # dim of `t` and `indices` does not match
     with self.assertRaises(RuntimeError, "input and indices should have the same number of dimensions"):
-      t.gather(indices[0], dim=0)
+      t.gather(0, indices[0])
 
     # invalid `indices` dtype
     with self.assertRaises(RuntimeError):
-      t.gather(indices.cast(dtypes.bool), dim=0)
+      t.gather(0, indices.cast(dtypes.bool))
 
     with self.assertRaises(RuntimeError):
-      t.gather(indices.cast(dtypes.float32), dim=0)
+      t.gather(0, indices.cast(dtypes.float32))
 
     with self.assertRaises(RuntimeError):
-      t.gather(indices.cast(dtypes.int32), dim=0)
+      t.gather(0, indices.cast(dtypes.int32))
 
     # invalid axis
     with self.assertRaises(IndexError):
-      t.gather(indices, dim=-7)
+      t.gather(-7, indices)
 
     with self.assertRaises(IndexError):
-      t.gather(t, indices, dim=7)
+      t.gather(7, indices)
     '''
 
 class TestNumpy(unittest.TestCase):
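Note that the block above sits inside a '''-disabled chunk of the imported torch tests, so only the argument order changes here. The dimensionality check it describes is exercised live in the TestOps hunk further down; a hedged sketch of that case, with the error types taken from that test's expected=(RuntimeError, AssertionError):

from tinygrad import Tensor, dtypes

t = Tensor.ones(4, 5, 6)
bad_idx = Tensor([1], dtype=dtypes.int32)  # 1-D index against a 3-D input
try:
  t.gather(dim=0, index=bad_idx)
  raise SystemExit("expected a dimensionality error")
except (RuntimeError, AssertionError):
  pass  # input and index should have the same number of dimensions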
@@ -443,8 +443,8 @@ class TestTypeSpec(unittest.TestCase):
   def test_gather_returns_same_dtype(self, data_dtype, indices_dtype):
     X_data = Tensor([[1, 0], [0, 1]], dtype=data_dtype)
     indices = Tensor([[0, 0], [1, 0]], dtype=indices_dtype)
-    assert X_data.gather(indices, 0).dtype == X_data.dtype
-    assert X_data.gather(indices, 1).dtype == X_data.dtype
+    assert X_data.gather(0, indices).dtype == X_data.dtype
+    assert X_data.gather(1, indices).dtype == X_data.dtype
 
 class TestTypePromotion(unittest.TestCase):
   @given(strat.sampled_from(core_dtypes))
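As the TestTypeSpec change above asserts, gather's output keeps the data tensor's dtype regardless of the index dtype. A small usage sketch with concrete dtypes (chosen here for illustration, not taken from the test):

from tinygrad import Tensor, dtypes

X_data = Tensor([[1, 0], [0, 1]], dtype=dtypes.float16)
indices = Tensor([[0, 0], [1, 0]], dtype=dtypes.int32)
out = X_data.gather(0, indices)
assert out.dtype == dtypes.float16  # gather preserves the data dtype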
@@ -1707,14 +1707,14 @@ class TestOps(unittest.TestCase):
     # indices cannot be negative (torch gather)
     b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
     a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
-    helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0))
-    helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=1), lambda x: x.gather(idx=a, dim=1))
-    helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=2), lambda x: x.gather(idx=a, dim=2))
-    helper_test_op([(3,4,5)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0))
-    self.helper_test_exception([(4,5,6)], lambda x: x.gather(index=torch.tensor([1], dtype=torch.int64), dim=0),
-                                          lambda x: x.gather(idx=Tensor([1], dtype=dtypes.int32), dim=0), expected=(RuntimeError, AssertionError))
-    self.helper_test_exception([(2,1,1)], lambda x: x.gather(index=b, dim=0),
-                                          lambda x: x.gather(idx=a, dim=0), expected=(RuntimeError, AssertionError))
+    helper_test_op([(4,5,6)], lambda x: x.gather(dim=0, index=b), lambda x: x.gather(dim=0, index=a))
+    helper_test_op([(4,5,6)], lambda x: x.gather(dim=1, index=b), lambda x: x.gather(dim=1, index=a))
+    helper_test_op([(4,5,6)], lambda x: x.gather(dim=2, index=b), lambda x: x.gather(dim=2, index=a))
+    helper_test_op([(3,4,5)], lambda x: x.gather(dim=0, index=b), lambda x: x.gather(dim=0, index=a))
+    self.helper_test_exception([(4,5,6)], lambda x: x.gather(dim=0, index=torch.tensor([1], dtype=torch.int64)),
+                                          lambda x: x.gather(dim=0, index=Tensor([1], dtype=dtypes.int32)), expected=(RuntimeError, AssertionError))
+    self.helper_test_exception([(2,1,1)], lambda x: x.gather(dim=0, index=b),
+                                          lambda x: x.gather(dim=0, index=a), expected=(RuntimeError, AssertionError))
 
   def test_scaled_product_attention(self):
     helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], torch.nn.functional.scaled_dot_product_attention, Tensor.scaled_dot_product_attention)
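After this hunk, the TestOps comparison drives torch and tinygrad with the same dim=/index= keywords. A hedged standalone parity check along the same lines (shapes and unseeded random data are illustrative only):

import numpy as np
import torch
from tinygrad import Tensor, dtypes

b = torch.randint(3, size=[3, 4, 5], dtype=torch.int64)
a = Tensor(b.numpy().astype(np.int32), dtype=dtypes.int32)
x_np = np.random.rand(3, 4, 5).astype(np.float32)

# same keywords on both sides after this change
torch_out = torch.tensor(x_np).gather(dim=0, index=b)
tiny_out = Tensor(x_np).gather(dim=0, index=a)
np.testing.assert_allclose(tiny_out.numpy(), torch_out.numpy())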