From 15f2f396580c2317af936f6abfa931c50e40795e Mon Sep 17 00:00:00 2001
From: geohotstan <135171913+geohotstan@users.noreply.github.com>
Date: Tue, 9 Apr 2024 23:18:04 +0800
Subject: [PATCH] conceptually simpler fancy index (#3335)

* init

* add failed case

* fix: temp comment out MULACC cast

* is this right?

* add test case

* oops, forgot to get rid of temp test

* WOOOOOO TOOK OUT 2 TRANSPOSES IN GATHER YAY

* cleaner

* comment cleanup

* update docs

* resolve conflict

* oops

* SUPA FAST

* comment out a test

* del some print statements

* use new broadcast stuff

* more clean up

* move try except

* skip fancy indexing for python backend test_ops
---
 .github/workflows/test.yml |  2 +-
 tinygrad/tensor.py         | 39 +++++++++++++++++++++-----------------
 2 files changed, 23 insertions(+), 18 deletions(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 081523006a..c1fa2cb1a2 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -51,7 +51,7 @@ jobs:
       - name: Test dtype with Python emulator
         run: DEBUG=1 PYTHONPATH=. PYTHON=1 python3 test/test_dtype.py
       - name: Test ops with Python emulator
-        run: DEBUG=2 PYTHON=1 python3 -m pytest test/test_ops.py -k "not (test_split or test_simple_cumsum or test_cumsum or test_einsum or test_dot or test_dot_1d or test_big_gemm or test_broadcastdot or test_multidot or test_var_axis or test_std_axis or test_broadcast_full or test_broadcast_partial or test_simple_conv3d or test_dilated_conv_transpose2d or test_simple_conv_transpose3d or test_large_input_conv2d or test_maxpool2d or test_maxpool2d_simple or test_maxpool2d_bigger_stride or test_avgpool2d or test_cat or test_scaled_product_attention or test_scaled_product_attention_causal)" --durations=20
+        run: DEBUG=2 PYTHON=1 python3 -m pytest test/test_ops.py -k "not (test_split or test_simple_cumsum or test_cumsum or test_einsum or test_dot or test_dot_1d or test_big_gemm or test_broadcastdot or test_multidot or test_var_axis or test_std_axis or test_broadcast_full or test_broadcast_partial or test_simple_conv3d or test_dilated_conv_transpose2d or test_simple_conv_transpose3d or test_large_input_conv2d or test_maxpool2d or test_maxpool2d_simple or test_maxpool2d_bigger_stride or test_avgpool2d or test_cat or test_scaled_product_attention or test_scaled_product_attention_causal or test_slice_fancy_indexing_dim_inject_none or test_slice_fancy_indexing_list_indices or test_slice_fancy_indexing_no_dim_collapse or test_slice_fancy_indexing_tuple_indices or test_slice_fancy_indexing_list_with_tensors or test_slice_fancy_indexing_dim_collapse_int)" --durations=20
       - name: Test uops with Python emulator
         run: PYTHON=1 python3 -m pytest test/test_uops.py --durations=20
       - name: Test symbolic with Python emulator
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 2541b779ce..d29bfd7cf1 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -405,9 +405,10 @@ class Tensor:
     #   3. None indexing (no copy)
    #     - reshape (inject) a dim at the dim where there's None
     #   4. Tensor indexing (copy)
-    #     - use Tensor.arange == tensor_index to create a mask
-    #     - apply mask to self by mask * self for dims where index is a tensor
-    #     - (mask * self).sum(dim) to reduce to correct shape
+    #     - use Tensor.arange == tensor_index to create masks for dims with Tensors (adds a dim for each mask)
+    #     - combine the masks with mul
+    #     - apply the combined mask to self by mask * self
+    #     - sum-reduce away the extra dims added when creating the masks
     # Tiny Things:
     #   1. Supported indices: Union[int, slice, Tensor, None, List, Tuple, Ellipsis]
     #     - for any list, List[Union[List, Tuple, int]], must have homogeneous shape
@@ -489,25 +490,29 @@ class Tensor:
     # track tensor_dim and tensor_index using a dict
     # calc_dim to get dim and use that to normalize the negative tensor indices
-    idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(type_dim[Tensor],tensor_index)}
+    idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(type_dim[Tensor], tensor_index)}
 
-    # compute sum_dim, arange, and idx
-    max_idx_dim, first_dim, last_dim = max(i.ndim for i in idx.values()), min(idx.keys()), max(idx.keys())
-    sum_dim = tuple(d if n==0 else d+max_idx_dim-n for n,d in enumerate(idx.keys()))
-    arange = [Tensor.arange(ret.shape[d], requires_grad=False, device=self.device).reshape(ret.shape[d], *[1]*(ret.ndim+max_idx_dim-n-sd-1)) \
-      for n,(sd,d) in enumerate(zip(sum_dim, idx.keys()))]
-    reshaped_idx = [i.reshape(i.shape + (1,)*(ret.ndim - first_dim - (n or 1))) for n,i in enumerate(idx.values())]
-    ret = ret.reshape(ret.shape[:first_dim+1] + (1,)*max_idx_dim + ret.shape[first_dim+1:])
+    masks, first_dim, last_dim = [], min(idx.keys()), max(idx.keys())
+    pre_reduce_shape = ret.shape[:first_dim] + (big_shape := broadcast_shape(*(t.shape for t in idx.values()))) + ret.shape[first_dim:]
 
-    # iteratively eq -> mul -> sum fancy index
-    try:
-      for a,i,sd in zip(arange, reshaped_idx, sum_dim): ret = (a==i).mul(ret).sum(sd)
-    except ValueError as exc: raise IndexError("cannot broadcast indices") from exc
+    # create a mask for each dim indexed by a Tensor
+    for dim, i in idx.items():
+      try: i = i.reshape(i.shape + (1,)*(ret.ndim - first_dim)).expand(pre_reduce_shape)
+      except ValueError as exc: raise IndexError("cannot broadcast indices") from exc
+      a = Tensor.arange(ret.shape[dim], device=self.device, requires_grad=False).reshape((ret.shape[dim],) + (1,)*(ret.ndim - dim - 1))
+      masks.append(i == a)
+
+    # combine the masks into a single mask
+    mask = functools.reduce(lambda x,y: x.mul(y), masks)
+
+    # inject 1s for the extra dims added when creating the masks
+    sh = ret.shape[:first_dim] + (1,) * len(big_shape) + ret.shape[first_dim:]
+    # sum-reduce away the extra dims introduced when creating the masks
+    ret = (ret.reshape(sh) * mask).sum(tuple(i + len(big_shape) for i in idx.keys()))
 
     # special permute case
     if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
-      ret_dims = list(range(ret.ndim))
-      ret = ret.permute(ret_dims[first_dim:first_dim+max_idx_dim] + ret_dims[:first_dim] + ret_dims[first_dim+max_idx_dim:])
+      ret = ret.permute(*range(first_dim, first_dim+len(big_shape)), *range(0, first_dim), *range(first_dim+len(big_shape), ret.ndim))
 
     return ret
 
   def __setitem__(self, indices, v:Union[Tensor, ConstType]):
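
A minimal standalone sketch of the mask-based gather this patch implements, restricted for clarity to a single Tensor index at dim 0 (the helper fancy_index_dim0 is a hypothetical name for illustration, not part of the patch): build a one-hot mask with arange == index, broadcast-multiply it against the source, then sum-reduce the indexed dim away.

from tinygrad.tensor import Tensor

def fancy_index_dim0(x: Tensor, index: Tensor) -> Tensor:
  # hypothetical helper, not in tinygrad: gathers the rows of x selected by index
  # normalize negative indices, as the patch does with (tensor<0).where(ret.shape[dim],0) + tensor
  index = (index < 0).where(x.shape[0], 0) + index
  # one-hot mask via arange == index: mask[i, j] is 1 iff index[i] == j
  mask = index.reshape(index.shape[0], 1) == Tensor.arange(x.shape[0], requires_grad=False)
  # inject trailing 1-dims so the mask broadcasts against x's remaining dims
  mask = mask.reshape(mask.shape + (1,) * (x.ndim - 1))
  # multiply, then sum-reduce the indexed dim away
  return (mask * x).sum(axis=1)

t = Tensor.arange(12).reshape(4, 3)
print(fancy_index_dim0(t, Tensor([2, 0, -1])).numpy())  # rows 2, 0, 3: same as t[[2, 0, -1]]

The patched __getitem__ generalizes this to multiple indexed dims: it builds one mask per Tensor-indexed dim, multiplies the masks into a single mask, and performs one sum reduction over all the injected dims, replacing the old per-index eq -> mul -> sum loop.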