diff --git a/test/imported/test_indexing.py b/test/imported/test_indexing.py
index 8f455cd9e8..7ccb8f5366 100644
--- a/test/imported/test_indexing.py
+++ b/test/imported/test_indexing.py
@@ -1320,8 +1320,9 @@ class TestNumpy(unittest.TestCase):
         self.assertRaises(IndexError, lambda: a[0, 0, -1.4])
         self.assertRaises(IndexError, lambda: a[-1.4, 0, 0])
         self.assertRaises(IndexError, lambda: a[0, -1.4, 0])
-        self.assertRaises(IndexError, lambda: a[0.0:, 0.0])
-        self.assertRaises(IndexError, lambda: a[0.0:, 0.0,:])
+        # these two trigger slice internal type verification first
+        self.assertRaises(TypeError, lambda: a[0.0:, 0.0])
+        self.assertRaises(TypeError, lambda: a[0.0:, 0.0,:])

     def test_none_index(self):
         # `None` index adds newaxis
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 711c095f24..30a1f5f60c 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -2,9 +2,7 @@ from __future__ import annotations
 import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib
 from contextlib import ContextDecorator
-from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Literal, TYPE_CHECKING
-from collections import defaultdict
-
+from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
 from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN
@@ -1099,119 +1097,95 @@ class Tensor(SimpleMathTrait):
   # 3. Out of bounds Tensor indexing results in 0
   #    - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] index 4 and 3 are out of bounds
   def _getitem(self, indices, v: Optional[Tensor] = None) -> Tensor:
-    # 1. indices normalization and validation
-    # treat internal tuples and lists as Tensors and standardize indices to list type
-    if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, self.device, requires_grad=False)]
-    elif isinstance(indices, (tuple, list)):
-      indices = [Tensor(i, self.device, requires_grad=False) if isinstance(i, (tuple, list)) else i for i in indices]
-    else: indices = [indices]
-
+    # wrap single index into a list
+    if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices]
     # turn scalar Tensors into const val for int indexing if possible
-    indices = [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
-    # move Tensor indices to the same device as self
-    indices = [i.to(self.device) if isinstance(i, Tensor) else i for i in indices]
+    x, indices = self, [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
     # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
-    ellipsis_idx = [dim for dim, i in enumerate(indices) if i is Ellipsis]
+    if len(ellipsis_idx := [dim for dim, i in enumerate(indices) if i is Ellipsis]) > 1: raise IndexError("indices can only have a single ellipsis")
     fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
     num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
+    if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}")
     indices[fill_idx:fill_idx+1] = [slice(None)] * (self.ndim - num_indices)

-    # use Dict[type, List[dimension]] to track elements in indices
-    type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)
+    indices_parsed, dim = [], 0
+    for index in indices:
+      size = 1 if index is None else self.shape[dim]
+      boundary, stride = [0, size], 1 # defaults
+      match index:
+        case list() | tuple() | Tensor():
+          if not isinstance(index, Tensor): index = Tensor(index, self.device, requires_grad=False)
+          if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported")
+          index = (index < 0).where(size, 0) + index # treat negative index values
+        case int() | UOp(): # sint
+          if index >= size or index < -size: raise IndexError(f"{index=} is out of bounds with {size=}")
+          boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]
+        case slice():
+          if index.step == 0: raise ValueError(f"{index=} cannot have 0 as step")
+          if not all(isinstance(s,int) or s is None for s in (index.start,index.stop,index.step)): raise TypeError("only int slicing is supported")
+          # handle int slicing
+          *boundary, stride = index.indices(cast(SupportsIndex, size))
+          if stride * (boundary[1] - boundary[0]) < 0: boundary = [0, 0]
+          elif stride < 0: boundary = [boundary[1] + 1, boundary[0] + 1]
+          # update size for slice
+          size = ceildiv((boundary[1] - boundary[0]), abs(stride))
+        case None: pass # do nothing
+        case _: raise IndexError(f"{type(index).__name__} indexing is not supported")
+      indices_parsed.append({"index":index, "size":size, "boundary":tuple(boundary), "stride":stride})
+      if index is not None: dim += 1

-    # record None for dimension injection later and filter None and record rest of indices
-    type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]
-    indices_filtered = [i for i in indices if i is not None]
-    for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)
+    # movement op indexing
+    if mops := [i for i in indices_parsed if i['index'] is not None]:
+      # flip negative strides
+      shrinks, strides = zip(*((i['boundary'], i['stride']) for i in mops))
+      x = x.shrink(shrinks).flip(tuple(i for i,st in enumerate(strides) if st < 0))
+      # handle stride != 1 or -1
+      if any(abs(st) != 1 for st in strides):
+        strides = tuple(abs(s) for s in strides)
+        # pad shape to multiple of stride
+        if not all_int(x.shape): raise RuntimeError("symbolic shape not supported")
+        x = x.pad(tuple((0, round_up(s, st) - s) for s, st in zip(x.shape, strides)))
+        x = x.reshape(tuple(flatten((s // st, st) for s, st in zip(x.shape, strides))))
+        x = x.shrink(tuple(flatten(((0, s), (0, 1)) for s in x.shape[::2]))).reshape(x.shape[::2])

-    if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')")
-    for index_type in type_dim:
-      if index_type not in [None, int, slice, Tensor]: raise IndexError(f"{index_type=} not supported")
-    if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}")
+    # dim injection from None by including None dim size (which is 1) and dim collapse by skipping int dim size
+    x = x.reshape(tuple(index['size'] for index in indices_parsed if not isinstance(index['index'], int)))

-    # 2. basic indexing, uses only movement ops (no copy)
-    # currently indices_filtered: Tuple[Union[int, slice, Tensor], ...]
-    # turn indices in indices_filtered to Tuple[new_slice, strides]
-    for dim in type_dim[int]:
-      if (index := indices_filtered[dim]) >= (size := self.shape[dim]) or index < -size:
-        raise IndexError(f"{index=} is out of bounds on {dim=} with {size=}")
-      indices_filtered[dim] = ((index, index+1), 1) if index >= 0 else ((size+index, size+index+1), 1)
-    for dim in type_dim[slice]:
-      if (index := indices_filtered[dim]).step == 0: raise ValueError(f"{index=} on {dim=} cannot have 0 as step")
-      if not all(isinstance(x, (int, type(None))) for x in (index.start, index.stop, index.step)):
-        raise TypeError(f"Unsupported slice for dimension {dim}. Expected slice with integers or None, got slice("
-                        f"{', '.join(type(x).__name__ for x in (index.start, index.stop, index.step))}).")
-      s, e, st = index.indices(self.shape[dim])
-      indices_filtered[dim] = ((0, 0) if (st * (e - s)) < 0 else (s, e) if st > 0 else (e+1, s+1), st)
-    # skip all Tensor dims for basic indexing
-    for dim in type_dim[Tensor]:
-      dtype = indices_filtered[dim].dtype
-      if not dtypes.is_int(dtype): raise IndexError(f"{dtype=} on {dim=} is not supported, only int tensor indexing is supported")
-      indices_filtered[dim] = ((0, self.shape[dim]), 1)
+    # tensor indexing
+    if tops := [(d,i) for d,i in enumerate(i_ for i_ in indices_parsed if not isinstance(i_['index'], int)) if isinstance(i['index'], Tensor)]:
+      # unload the tensor object into actual tensors
+      dims, tensors, masks = [d for d,_ in tops], cast(list[Tensor], [i['index'] for _,i in tops]), []
+      pre_reduce_shape = x.shape[:dims[0]] + (big_shape := _broadcast_shape(*(t.shape for t in tensors))) + x.shape[dims[0]:]

-    new_slice, strides = ((), ()) if not indices_filtered else zip(*indices_filtered)
-    # flip negative strides
-    ret = self.shrink(new_slice).flip(tuple(i for i, st in enumerate(strides) if st < 0))
-    # handle stride != 1 or -1
-    if any(abs(st) != 1 for st in strides):
-      strides = tuple(abs(s) for s in strides)
-      # pad shape to multiple of stride
-      if not all_int(ret.shape): raise RuntimeError("symbolic shape not supprted")
-      ret = ret.pad(tuple((0, round_up(s, st) - s) for s, st in zip(ret.shape, strides)))
-      ret = ret.reshape(tuple(flatten((s // st, st) for s, st in zip(ret.shape, strides))))
-      ret = ret.shrink(tuple(flatten(((0, s), (0, 1)) for s in ret.shape[::2]))).reshape(ret.shape[::2])
-
-    # inject 1 for dim where it's None and collapse dim for int
-    new_shape = list(ret.shape)
-    for dim in type_dim[None]: new_shape.insert(dim, 1)
-    for dim in (dims_collapsed := tuple(dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int]))): new_shape.pop(dim)
-
-    ret = ret.reshape(new_shape)
-
-    # 3. advanced indexing (copy)
-    if type_dim[Tensor]:
-      dim_tensors = [(dim, i) for dim, i in enumerate(indices) if isinstance(i, Tensor)]
-      # calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim
-      def calc_dim(tensor_dim:int) -> int:
-        return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d)
-
-      assert all_int(ret.shape), f"does not support symbolic shape {ret.shape}"
-      # track tensor_dim and tensor_index using a dict
-      # calc_dim to get dim and use that to normalize the negative tensor indices
-      idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in dim_tensors}
-
-      masks, first_dim, last_dim = [], min(idx.keys()), max(idx.keys())
-      pre_reduce_shape = ret.shape[:first_dim] + (big_shape := _broadcast_shape(*(t.shape for t in idx.values()))) + ret.shape[first_dim:]
-
-      # create masks
-      for dim, i in idx.items():
-        try: i = i.reshape(i.shape + (1,)*(ret.ndim - first_dim)).expand(pre_reduce_shape)
+      # create index masks
+      for dim, tensor in zip(dims, tensors):
+        try: i = tensor.reshape(tensor.shape + (1,)*(x.ndim - dims[0])).expand(pre_reduce_shape)
         except ValueError as e: raise IndexError(f"cannot broadcast indices: {e}") from e
-        a = Tensor.arange(ret.shape[dim], device=self.device, requires_grad=False).reshape((ret.shape[dim],) + (1,)*(ret.ndim - dim - 1))
+        a = Tensor.arange(x.shape[dim], device=self.device, requires_grad=False).reshape((x.shape[dim],) + (1,)*(x.ndim - dim - 1))
         masks.append(i == a)

       # reduce masks to 1 mask
       mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks)

       # inject 1's for the extra dims added in create masks
-      reshape_arg = ret.shape[:first_dim] + (1,) * len(big_shape) + ret.shape[first_dim:]
+      reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:]
       # sum reduce the extra dims introduced in create masks
-      ret = (ret.reshape(reshape_arg) * mask).sum(sum_axis:=tuple(i + len(big_shape) for i in idx.keys()), acc_dtype=ret.dtype)
+      x = (x.reshape(reshape_arg) * mask).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), acc_dtype=x.dtype)

       # special permute case
-      if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
-        ret = ret.permute(*range(first_dim, first_dim+len(big_shape)), *range(0, first_dim), *range(first_dim+len(big_shape), ret.ndim))
+      if dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1)):
+        x = x.permute(*range(dims[0], dims[0]+len(big_shape)), *range(0, dims[0]), *range(dims[0]+len(big_shape), x.ndim))

       # for advanced setitem, returns whole tensor with indices replaced
       if v is not None:
-        vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(ret.shape, v.shape))
+        vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(x.shape, v.shape))
         # add back reduced dims from sum
         for dim in sum_axis: vb = vb.unsqueeze(dim)
         # run _masked_setitem on tuple of axis that is to be reduced to match self.shape
-        ret = _masked_setitem(self, vb, mask, tuple(range(first_dim, first_dim + len(big_shape))))
+        x = _masked_setitem(self, vb, mask, tuple(range(dims[0], dims[0] + len(big_shape))))

-    return ret
+    return x

   def __getitem__(self, indices) -> Tensor: return self._getitem(indices)
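
Reviewer notes follow; the snippets below are illustrations only, not part of the patch.

The test change pins an error-type change rather than a behavior change: a float inside a slice is now rejected by the explicit component check in the slice() branch ("only int slicing is supported"), which runs before any per-dimension float-index handling, so it surfaces as TypeError instead of IndexError. A sketch of the expected behavior (hypothetical shape):

    from tinygrad import Tensor

    a = Tensor.ones(3, 4, 5)
    # a bare float index still raises IndexError (falls through to `case _`):
    #   a[0, -1.4, 0]   -> IndexError
    # a float inside a slice now hits the slice() type check first:
    #   a[0.0:, 0.0]    -> TypeError("only int slicing is supported")
    #   a[0.0:, 0.0,:]  -> TypeError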
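
In the new match-based parser, the slice() branch leans on Python's own slice.indices() for clamping, then normalizes a negative stride into an ascending boundary (flip() handles the direction later). A worked example of what that branch computes for t[::-2] on a dimension of size 5, in plain Python with ceildiv written inline:

    *boundary, stride = slice(None, None, -2).indices(5)  # boundary=[4, -1], stride=-2
    # stride < 0 and the slice is non-empty, so flip into an ascending boundary
    boundary = [boundary[1] + 1, boundary[0] + 1]          # [0, 5]
    size = -((boundary[0] - boundary[1]) // abs(stride))   # ceildiv(5, 2) = 3
    assert size == len(list(range(5))[::-2])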
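
The stride != ±1 path (pad to a multiple of the stride, reshape into (size//stride, stride) blocks, shrink each block to its first element) can be sanity-checked in isolation. A minimal sketch of the same idea on a plain Python list; `strided` is my own helper name, not something in the PR:

    def strided(xs: list, stride: int) -> list:
      padded = xs + [0] * (-len(xs) % stride)                             # pad, like x.pad(...)
      rows = [padded[i:i+stride] for i in range(0, len(padded), stride)]  # like the reshape
      return [row[0] for row in rows]                                     # like the shrink to (0, 1)

    assert strided(list(range(10)), 3) == list(range(10))[::3]  # [0, 3, 6, 9]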
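
For tensor indexing, `big_shape` is the broadcast of all index-tensor shapes, and those broadcast dims land at the position of the first tensor index. A small example of the shape bookkeeping, assuming the default device and NumPy-style advanced-indexing semantics:

    from tinygrad import Tensor

    x = Tensor.arange(12).reshape(4, 3)
    r, c = Tensor([[0], [2]]), Tensor([0, 2])  # shapes (2, 1) and (2,) broadcast to big_shape (2, 2)
    print(x[r, c].shape)     # (2, 2)
    print(x[r, c].tolist())  # [[0, 2], [6, 8]]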
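
The gather itself is still mask-and-reduce: each index tensor is compared against an arange over its dimension, the per-dim masks are multiplied together, and the extra dims are sum-reduced away. A sketch of that equivalence for the 1-D case (assumption: multiplying by the bool mask promotes to the value dtype, as the patched code itself relies on):

    from tinygrad import Tensor

    x = Tensor([10, 20, 30])
    idx = Tensor([2, 1])
    a = Tensor.arange(3).reshape(1, 3)           # arange over the indexed dim
    mask = idx.reshape(2, 1) == a                # one-hot rows, shape (2, 3)
    gathered = (x.reshape(1, 3) * mask).sum(1)   # [30, 20]
    assert gathered.tolist() == x[idx].tolist()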