diff --git a/test/test_ops.py b/test/test_ops.py
index bd139ef719..e5b2b72c0b 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1685,6 +1685,10 @@ class TestOps(unittest.TestCase):
     helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,:,None,d,e], lambda x: x[i,:,None,o,p])
     helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,None,None,None,None], lambda x: x[None,None,None,None,None])
     helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,None,b,c,d,e], lambda x: x[None,None,j,k,o,p])
+    helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,None,b,c,None,None], lambda x: x[None,None,j,k,None,None])
+    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,None,None,c,d,e], lambda x: x[i,None,None,k,o,p])
+    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,None,None,c,None,None], lambda x: x[i,None,None,k,None,None])
+    helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,None,b,None,d,e], lambda x: x[None,None,j,None,o,p])

   def test_slice_fancy_indexing_dim_inject_and_collapse(self):
     a,b,c,d,e,i,j,k,o,p = self._get_index_randoms() # noqa
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index ee1bb7c14e..9e69aef254 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -873,10 +873,6 @@ class Tensor:
       indices = [Tensor(list(i), self.device, requires_grad=False) if isinstance(i, (tuple, list)) else i for i in indices]
     else: indices = [indices]
-    # handling leading Nones
-    # TODO: unify how None is handled
-    if indices and indices[0] is None: return (self[indices[1:]] if indices[1:] else self).unsqueeze(0)
-
     # turn scalar Tensors into const val for int indexing if possible
     indices = [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
     # move Tensor indices to the same device as self
@@ -893,6 +889,7 @@ class Tensor:
     # record None for dimension injection later and filter None and record rest of indices
     type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]
+    tensor_dims = [dim for dim, i in enumerate(indices) if isinstance(i, Tensor)]
     indices_filtered = [i for i in indices if i is not None]
     for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)
@@ -941,12 +938,12 @@ class Tensor:
     if type_dim[Tensor]:
-      # calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
+      # calculate dim of current ret by subtracting dims collapsed up until tensor_dim
       def calc_dim(tensor_dim:int) -> int:
-        return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d) + sum(1 for d in type_dim[None] if tensor_dim >= d)
+        return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d)
       assert all_int(ret.shape), f"does not support symbolic shape {ret.shape}"
       # track tensor_dim and tensor_index using a dict
       # calc_dim to get dim and use that to normalize the negative tensor indices
-      idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(type_dim[Tensor], tensor_index)}
+      idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(tensor_dims, tensor_index)}
       masks, first_dim, last_dim = [], min(idx.keys()), max(idx.keys())
       pre_reduce_shape = ret.shape[:first_dim] + (big_shape := _broadcast_shape(*(t.shape for t in idx.values()))) + ret.shape[first_dim:]
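Context for the `calc_dim` change: the old code took tensor-index positions from `indices_filtered` (Nones already stripped), so `calc_dim` had to add the injected dims back; the new `tensor_dims` is recorded before filtering, so the positions already account for the Nones and the correction term drops out. Below is a minimal sketch (not part of the PR) of the first new test pattern, checked against numpy rather than the torch lambdas that `helper_test_op` compares against; the concrete index values are made-up stand-ins for the random ones from `_get_index_randoms`.

```python
import numpy as np
from tinygrad.tensor import Tensor

# same base shape as the new tests
x_np = np.arange(2*5*6*5*3*4, dtype=np.float32).reshape(2,5,6,5,3,4)
x = Tensor(x_np)

# hypothetical stand-ins for the random b/c (tinygrad) and j/k (reference) indices
b, c = Tensor([0, 1]), Tensor([1, 3])
j, k = np.array([0, 1]), np.array([1, 3])

# Nones on both sides of the tensor indices: with the leading-None early
# return removed, this goes through the same path as any other None mix
out = x[None, None, b, c, None, None]
ref = x_np[None, None, j, k, None, None]

# each None injects a size-1 dim; the adjacent tensor indices collapse dims
# 0 and 1 into one broadcast dim, giving (1, 1, 2, 1, 1, 6, 5, 3, 4)
assert out.shape == ref.shape
np.testing.assert_allclose(out.numpy(), ref)
```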