Getitem pin None dimension (#4960)

* fix

* remove torch out of bounds test

* 1 more test case
This commit is contained in:
geohotstan
2024-06-14 22:48:59 +08:00
committed by GitHub
parent 2eeddf1a46
commit 90332eb529
2 changed files with 7 additions and 6 deletions

View File

@@ -1685,6 +1685,10 @@ class TestOps(unittest.TestCase):
helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,:,None,d,e], lambda x: x[i,:,None,o,p])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,None,None,None,None], lambda x: x[None,None,None,None,None])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,None,b,c,d,e], lambda x: x[None,None,j,k,o,p])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,None,b,c,None,None], lambda x: x[None,None,j,k,None,None])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,None,None,c,d,e], lambda x: x[i,None,None,k,o,p])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,None,None,c,None,None], lambda x: x[i,None,None,k,None,None])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,None,b,None,d,e], lambda x: x[None,None,j,None,o,p])
def test_slice_fancy_indexing_dim_inject_and_collapse(self):
a,b,c,d,e,i,j,k,o,p = self._get_index_randoms() # noqa

View File

@@ -873,10 +873,6 @@ class Tensor:
indices = [Tensor(list(i), self.device, requires_grad=False) if isinstance(i, (tuple, list)) else i for i in indices]
else: indices = [indices]
# handling leading Nones
# TODO: unify how None is handled
if indices and indices[0] is None: return (self[indices[1:]] if indices[1:] else self).unsqueeze(0)
# turn scalar Tensors into const val for int indexing if possible
indices = [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
# move Tensor indices to the same device as self
@@ -893,6 +889,7 @@ class Tensor:
# record None for dimension injection later and filter None and record rest of indices
type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]
tensor_dims = [dim for dim, i in enumerate(indices) if isinstance(i, Tensor)]
indices_filtered = [i for i in indices if i is not None]
for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)
@@ -941,12 +938,12 @@ class Tensor:
if type_dim[Tensor]:
# calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
def calc_dim(tensor_dim:int) -> int:
return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d) + sum(1 for d in type_dim[None] if tensor_dim >= d)
return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d)
assert all_int(ret.shape), f"does not support symbolic shape {ret.shape}"
# track tensor_dim and tensor_index using a dict
# calc_dim to get dim and use that to normalize the negative tensor indices
idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(type_dim[Tensor], tensor_index)}
idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(tensor_dims, tensor_index)}
masks, first_dim, last_dim = [], min(idx.keys()), max(idx.keys())
pre_reduce_shape = ret.shape[:first_dim] + (big_shape := _broadcast_shape(*(t.shape for t in idx.values()))) + ret.shape[first_dim:]