conceptually simpler fancy index (#3335)

* init

* add failed case

* fix: temp comment out MULACC cast

* is this right?

* add test case

* oops, forgot to get rid of temp test

* WOOOOOO TOOK OUT 2 TRANSPOSES IN GATHER YAY

* cleaner

* comment cleanup

* update docs

* resolve conflict

* oops

* SUPA FAST

* comment out a test

* del some print statements

* use new broadcast stuff

* more clean up

* move try except

* skip fancy indexing for python backend test_ops
geohotstan authored on 2024-04-09 23:18:04 +08:00, committed by GitHub
parent 980124a605
commit 15f2f39658
2 changed files with 23 additions and 18 deletions


@@ -51,7 +51,7 @@ jobs:
     - name: Test dtype with Python emulator
       run: DEBUG=1 PYTHONPATH=. PYTHON=1 python3 test/test_dtype.py
     - name: Test ops with Python emulator
-      run: DEBUG=2 PYTHON=1 python3 -m pytest test/test_ops.py -k "not (test_split or test_simple_cumsum or test_cumsum or test_einsum or test_dot or test_dot_1d or test_big_gemm or test_broadcastdot or test_multidot or test_var_axis or test_std_axis or test_broadcast_full or test_broadcast_partial or test_simple_conv3d or test_dilated_conv_transpose2d or test_simple_conv_transpose3d or test_large_input_conv2d or test_maxpool2d or test_maxpool2d_simple or test_maxpool2d_bigger_stride or test_avgpool2d or test_cat or test_scaled_product_attention or test_scaled_product_attention_causal)" --durations=20
+      run: DEBUG=2 PYTHON=1 python3 -m pytest test/test_ops.py -k "not (test_split or test_simple_cumsum or test_cumsum or test_einsum or test_dot or test_dot_1d or test_big_gemm or test_broadcastdot or test_multidot or test_var_axis or test_std_axis or test_broadcast_full or test_broadcast_partial or test_simple_conv3d or test_dilated_conv_transpose2d or test_simple_conv_transpose3d or test_large_input_conv2d or test_maxpool2d or test_maxpool2d_simple or test_maxpool2d_bigger_stride or test_avgpool2d or test_cat or test_scaled_product_attention or test_scaled_product_attention_causal or test_slice_fancy_indexing_dim_inject_none or test_slice_fancy_indexing_list_indices or test_slice_fancy_indexing_no_dim_collapse or test_slice_fancy_indexing_tuple_indices or test_slice_fancy_indexing_list_with_tensors or test_slice_fancy_indexing_dim_collapse_int)" --durations=20
     - name: Test uops with Python emulator
       run: PYTHON=1 python3 -m pytest test/test_uops.py --durations=20
     - name: Test symbolic with Python emulator
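
For context, the six newly excluded tests cover tensor and list fancy indexing at the user level. A minimal sketch of what that indexing looks like through tinygrad's public Tensor API (variable names here are illustrative, not taken from the tests):

from tinygrad import Tensor

x = Tensor.arange(12).reshape(4, 3)
rows, cols = Tensor([2, 0]), Tensor([1, 1])
print(x[rows].numpy())        # Tensor index on dim 0: gathers rows 2 and 0, shape (2, 3)
print(x[rows, cols].numpy())  # Tensor indices on both dims, shape (2,)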


@@ -405,9 +405,10 @@ class Tensor:
     # 3. None indexing (no copy)
     #   - reshape (inject) a dim at the dim where there's None
     # 4. Tensor indexing (copy)
-    #   - use Tensor.arange == tensor_index to create a mask
-    #   - apply mask to self by mask * self for dims where index is a tensor
-    #   - (mask * self).sum(dim) to reduce to correct shape
+    #   - use Tensor.arange == tensor_index to create masks for dims with Tensors (adds a dim for each mask)
+    #   - combine masks together with mul
+    #   - apply mask to self by mask * self
+    #   - sum reduce away the extra dims added from creating masks
     # Tiny Things:
     # 1. Supported indices: Union[int, slice, Tensor, None, List, Tuple, Ellipsis]
     #   - for any list, List[Union[List, Tuple, int]], must have homogeneous shape
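
To make the updated step 4 concrete, here is a minimal NumPy sketch of the eq -> mul -> sum idea for a single tensor index on dim 0 (NumPy and these variable names are purely illustrative; the real implementation is the tinygrad code in the next hunk):

import numpy as np

x = np.arange(12, dtype=np.float32).reshape(4, 3)
idx = np.array([2, 0])                       # tensor index for dim 0

mask = np.arange(4).reshape(4, 1) == idx     # (4, 2): arange == index adds a dim for the mask
prod = mask[:, :, None] * x[:, None, :]      # (4, 2, 3): apply the mask to x by broadcasting
out = prod.sum(axis=0)                       # (2, 3): sum-reduce the original indexed dim away

assert np.array_equal(out, x[idx])           # same result as fancy indexing x[[2, 0]]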
@@ -489,25 +490,29 @@ class Tensor:
       # track tensor_dim and tensor_index using a dict
       # calc_dim to get dim and use that to normalize the negative tensor indices
-      idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(type_dim[Tensor],tensor_index)}
+      idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(type_dim[Tensor], tensor_index)}
-      # compute sum_dim, arange, and idx
-      max_idx_dim, first_dim, last_dim = max(i.ndim for i in idx.values()), min(idx.keys()), max(idx.keys())
-      sum_dim = tuple(d if n==0 else d+max_idx_dim-n for n,d in enumerate(idx.keys()))
-      arange = [Tensor.arange(ret.shape[d], requires_grad=False, device=self.device).reshape(ret.shape[d], *[1]*(ret.ndim+max_idx_dim-n-sd-1)) \
-        for n,(sd,d) in enumerate(zip(sum_dim, idx.keys()))]
-      reshaped_idx = [i.reshape(i.shape + (1,)*(ret.ndim - first_dim - (n or 1))) for n,i in enumerate(idx.values())]
-      ret = ret.reshape(ret.shape[:first_dim+1] + (1,)*max_idx_dim + ret.shape[first_dim+1:])
+      masks, first_dim, last_dim = [], min(idx.keys()), max(idx.keys())
+      pre_reduce_shape = ret.shape[:first_dim] + (big_shape := broadcast_shape(*(t.shape for t in idx.values()))) + ret.shape[first_dim:]
-      # iteratively eq -> mul -> sum fancy index
-      try:
-        for a,i,sd in zip(arange, reshaped_idx, sum_dim): ret = (a==i).mul(ret).sum(sd)
-      except ValueError as exc: raise IndexError("cannot broadcast indices") from exc
+      # create masks
+      for dim, i in idx.items():
+        try: i = i.reshape(i.shape + (1,)*(ret.ndim - first_dim)).expand(pre_reduce_shape)
+        except ValueError as exc: raise IndexError("cannot broadcast indices") from exc
+        a = Tensor.arange(ret.shape[dim], device=self.device, requires_grad=False).reshape((ret.shape[dim],) + (1,)*(ret.ndim - dim - 1))
+        masks.append(i == a)
+      # reduce masks to 1 mask
+      mask = functools.reduce(lambda x,y: x.mul(y), masks)
+      # inject 1's for the extra dims added in create masks
+      sh = ret.shape[:first_dim] + (1,) * len(big_shape) + ret.shape[first_dim:]
+      # sum reduce the extra dims introduced in create masks
+      ret = (ret.reshape(sh) * mask).sum(tuple(i + len(big_shape) for i in idx.keys()))
       # special permute case
       if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
-        ret_dims = list(range(ret.ndim))
-        ret = ret.permute(ret_dims[first_dim:first_dim+max_idx_dim] + ret_dims[:first_dim] + ret_dims[first_dim+max_idx_dim:])
+        ret = ret.permute(*range(first_dim, first_dim+len(big_shape)), *range(0, first_dim), *range(first_dim+len(big_shape), ret.ndim))
     return ret
   def __setitem__(self, indices, v:Union[Tensor, ConstType]):
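
As a sanity check of the mask-and-reduce path above, a hedged NumPy sketch of the two-index case: np.broadcast_shapes stands in for the broadcast_shape helper, functools.reduce combines the masks as in the hunk, and all variable names are illustrative.

import numpy as np
from functools import reduce

x = np.arange(20, dtype=np.float32).reshape(4, 5)
i0, i1 = np.array([1, 3]), np.array([0, 2])          # tensor indices for dims 0 and 1

big_shape = np.broadcast_shapes(i0.shape, i1.shape)  # shared index shape, here (2,)

# one mask per indexed dim: the index (reshaped to sit in front) == arange over that dim
m0 = i0.reshape(big_shape + (1, 1)) == np.arange(4).reshape(4, 1)  # (2, 4, 1)
m1 = i1.reshape(big_shape + (1, 1)) == np.arange(5)                # (2, 1, 5)

# combine masks with mul, apply to x (1's injected in front), sum-reduce the original dims
mask = reduce(lambda a, b: a * b, [m0, m1])                        # (2, 4, 5)
out = (x.reshape((1,) * len(big_shape) + x.shape) * mask).sum(axis=(1, 2))

assert np.array_equal(out, x[i0, i1])                              # matches NumPy fancy indexing

The special permute case follows the standard advanced-indexing convention: when the tensor indices do not sit on contiguous dims, the broadcast index dims are moved to the front of the result, which is what the final permute reproduces.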