mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Getitem pin None dimension (#4960)

* fix
* remove torch out of bounds test
* 1 more test case
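The title refers to numpy/torch indexing semantics, which the tests below compare against: a None in an index tuple injects a size-1 axis at exactly the position where it is written, and it must stay there ("pinned") even when tensor (advanced) indices appear alongside it. A minimal numpy sketch with illustrative values:

import numpy as np

x = np.arange(12).reshape(3, 4)
idx = np.array([2, 0])

# The injected axis lands where the None is written, relative to the
# advanced index, not at a fixed position in the result.
print(x[None, idx].shape)  # (1, 2, 4)
print(x[idx, None].shape)  # (2, 1, 4)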
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1685,6 +1685,10 @@ class TestOps(unittest.TestCase):
     helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,:,None,d,e], lambda x: x[i,:,None,o,p])
     helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,None,None,None,None], lambda x: x[None,None,None,None,None])
     helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,None,b,c,d,e], lambda x: x[None,None,j,k,o,p])
+    helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,None,b,c,None,None], lambda x: x[None,None,j,k,None,None])
+    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,None,None,c,d,e], lambda x: x[i,None,None,k,o,p])
+    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,None,None,c,None,None], lambda x: x[i,None,None,k,None,None])
+    helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,None,b,None,d,e], lambda x: x[None,None,j,None,o,p])
 
   def test_slice_fancy_indexing_dim_inject_and_collapse(self):
     a,b,c,d,e,i,j,k,o,p = self._get_index_randoms() # noqa
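In these tests the first lambda runs under torch and the second under tinygrad on the same random index tensors. For one of the newly added cases, a hedged numpy check of the expected result shape (the index arrays below are stand-ins for the test's random ones):

import numpy as np

x = np.zeros((2, 5, 6, 5, 3, 4))
b, c = np.array([[0], [1]]), np.array([2, 3])  # stand-ins; b and c broadcast to (2, 2)

# The Nones on both sides of the tensor indices must stay in place,
# giving a size-1 axis on each side of the broadcast index dims.
print(x[None, None, b, c, None, None].shape)  # (1, 1, 2, 2, 1, 1, 6, 5, 3, 4)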
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -873,10 +873,6 @@ class Tensor:
       indices = [Tensor(list(i), self.device, requires_grad=False) if isinstance(i, (tuple, list)) else i for i in indices]
     else: indices = [indices]
 
-    # handling leading Nones
-    # TODO: unify how None is handled
-    if indices and indices[0] is None: return (self[indices[1:]] if indices[1:] else self).unsqueeze(0)
-
     # turn scalar Tensors into const val for int indexing if possible
     indices = [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
     # move Tensor indices to the same device as self
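The removed early-return peeled leading Nones one at a time by recursing on the rest of the index and unsqueezing the result, i.e. it relied on the identity sketched below (numpy, illustrative values). After this change that special case is unnecessary: Nones are recorded in type_dim[None] and positions are tracked through the unfiltered index list instead (next hunk).

import numpy as np

x = np.arange(6).reshape(2, 3)
idx = np.array([1, 0])

# leading None: x[None, *rest] equals x[rest] with an axis prepended
assert (x[None, idx] == x[idx][None]).all()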
@@ -893,6 +889,7 @@ class Tensor:
 
     # record None for dimension injection later and filter None and record rest of indices
     type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]
+    tensor_dims = [dim for dim, i in enumerate(indices) if isinstance(i, Tensor)]
    indices_filtered = [i for i in indices if i is not None]
     for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)
 
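A small plain-Python sketch (string stand-ins for Tensor indices) of the two bookkeepings this hunk distinguishes: the new tensor_dims records tensor positions before Nones are filtered out, while the old type_dim[Tensor] recorded positions in the None-filtered list:

indices = [None, None, "tb", "tc", None, None]  # "tb", "tc" stand in for Tensor indices

none_dims = [d for d, i in enumerate(indices) if i is None]        # [0, 1, 4, 5]
tensor_dims = [d for d, i in enumerate(indices) if i is not None]  # [2, 3]  (new bookkeeping)

indices_filtered = [i for i in indices if i is not None]
old_tensor_dims = list(range(len(indices_filtered)))               # [0, 1]  (old bookkeeping)
print(none_dims, tensor_dims, old_tensor_dims)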
@@ -941,12 +938,12 @@ class Tensor:
     if type_dim[Tensor]:
       # calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
       def calc_dim(tensor_dim:int) -> int:
-        return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d) + sum(1 for d in type_dim[None] if tensor_dim >= d)
+        return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d)
 
       assert all_int(ret.shape), f"does not support symbolic shape {ret.shape}"
       # track tensor_dim and tensor_index using a dict
       # calc_dim to get dim and use that to normalize the negative tensor indices
-      idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(type_dim[Tensor], tensor_index)}
+      idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(tensor_dims, tensor_index)}
 
       masks, first_dim, last_dim = [], min(idx.keys()), max(idx.keys())
       pre_reduce_shape = ret.shape[:first_dim] + (big_shape := _broadcast_shape(*(t.shape for t in idx.values()))) + ret.shape[first_dim:]
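A worked example of the new calc_dim with made-up inputs (dims_collapsed is assumed, as in the surrounding code, to hold the positions of dimensions already collapsed by int indices):

dims_collapsed = (1,)  # assume an int index collapsed dimension 1
tensor_dims = [2, 4]   # Tensor indices at positions 2 and 4 of the original index list

def calc_dim(tensor_dim: int) -> int:
  # shift each tensor position left by the number of collapses at or before it
  return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d)

print([calc_dim(td) for td in tensor_dims])  # [1, 3]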