mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
conceptually simpler fancy index (#3335)
* init
* add failed case
* fix: temp comment out MULACC cast
* is this right?
* add test case
* oops, forgot to get rid of temp test
* WOOOOOO TOOK OUT 2 TRANSPOSES IN GATHER YAY
* cleaner
* comment cleanup
* update docs
* resolve conflict
* oops
* SUPA FAST
* comment out a test
* del some print statements
* use new broadcast stuff
* more clean up
* move try except
* skip fancy indexing for python backend test_ops
.github/workflows/test.yml (vendored, 2 changes)
@@ -51,7 +51,7 @@ jobs:
     - name: Test dtype with Python emulator
       run: DEBUG=1 PYTHONPATH=. PYTHON=1 python3 test/test_dtype.py
     - name: Test ops with Python emulator
-      run: DEBUG=2 PYTHON=1 python3 -m pytest test/test_ops.py -k "not (test_split or test_simple_cumsum or test_cumsum or test_einsum or test_dot or test_dot_1d or test_big_gemm or test_broadcastdot or test_multidot or test_var_axis or test_std_axis or test_broadcast_full or test_broadcast_partial or test_simple_conv3d or test_dilated_conv_transpose2d or test_simple_conv_transpose3d or test_large_input_conv2d or test_maxpool2d or test_maxpool2d_simple or test_maxpool2d_bigger_stride or test_avgpool2d or test_cat or test_scaled_product_attention or test_scaled_product_attention_causal)" --durations=20
+      run: DEBUG=2 PYTHON=1 python3 -m pytest test/test_ops.py -k "not (test_split or test_simple_cumsum or test_cumsum or test_einsum or test_dot or test_dot_1d or test_big_gemm or test_broadcastdot or test_multidot or test_var_axis or test_std_axis or test_broadcast_full or test_broadcast_partial or test_simple_conv3d or test_dilated_conv_transpose2d or test_simple_conv_transpose3d or test_large_input_conv2d or test_maxpool2d or test_maxpool2d_simple or test_maxpool2d_bigger_stride or test_avgpool2d or test_cat or test_scaled_product_attention or test_scaled_product_attention_causal or test_slice_fancy_indexing_dim_inject_none or test_slice_fancy_indexing_list_indices or test_slice_fancy_indexing_no_dim_collapse or test_slice_fancy_indexing_tuple_indices or test_slice_fancy_indexing_list_with_tensors or test_slice_fancy_indexing_dim_collapse_int)" --durations=20
     - name: Test uops with Python emulator
       run: PYTHON=1 python3 -m pytest test/test_uops.py --durations=20
     - name: Test symbolic with Python emulator
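
For context, the tests newly deselected from the Python emulator run are the fancy-indexing tests this commit adds. Their names suggest the indexing patterns below; this is an illustrative NumPy sketch of those patterns inferred from the test names, not the actual test bodies:

import numpy as np

# inferred indexing patterns matching the deselected test names (illustrative only)
x = np.arange(2 * 3 * 4).reshape(2, 3, 4)
print(x[[0, 1]].shape)                    # list indices            -> (2, 3, 4)
print(x[:, None, [0, 2]].shape)           # dim inject with None    -> (2, 1, 2, 4)
print(x[1, [0, 2]].shape)                 # dim collapse via int    -> (2, 4)
print(x[np.array([0, 1]), [2, 0]].shape)  # list mixed with tensors -> (2, 4)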
tinygrad/tensor.py
@@ -405,9 +405,10 @@ class Tensor:
     # 3. None indexing (no copy)
     #   - reshape (inject) a dim at the dim where there's None
     # 4. Tensor indexing (copy)
-    #   - use Tensor.arange == tensor_index to create a mask
-    #   - apply mask to self by mask * self for dims where index is a tensor
-    #   - (mask * self).sum(dim) to reduce to correct shape
+    #   - use Tensor.arange == tensor_index to create masks for dims with Tensors (adds a dim for each mask)
+    #   - combine masks together with mul
+    #   - apply mask to self by mask * self
+    #   - sum reduce away the extra dims added from creating masks
     # Tiny Things:
     # 1. Supported indices: Union[int, slice, Tensor, None, List, Tuple, Ellipsis]
     #   - for any list, List[Union[List, Tuple, int]], must have homogeneous shape
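
To make the updated comment concrete: the trick is to turn a gather into an elementwise compare, multiply, and reduce. A minimal NumPy sketch of the arange == index masking for a single indexed dim (illustrative only; tinygrad builds the same masks with Tensor.arange):

import numpy as np

# gather rows [2, 0] from a (3, 4) array with no explicit gather op:
# build a mask from index == arange, multiply, then sum-reduce the indexed dim
x = np.arange(12).reshape(3, 4)
index = np.array([2, 0])

mask = index[:, None] == np.arange(x.shape[0])  # (2, 3), one one-hot row per index
out = (mask[:, :, None] * x).sum(axis=1)        # (2, 3, 4) -> (2, 4)
assert (out == x[index]).all()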
@@ -489,25 +490,29 @@ class Tensor:
       # track tensor_dim and tensor_index using a dict
       # calc_dim to get dim and use that to normalize the negative tensor indices
-      idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(type_dim[Tensor],tensor_index)}
+      idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(type_dim[Tensor], tensor_index)}

-      # compute sum_dim, arange, and idx
-      max_idx_dim, first_dim, last_dim = max(i.ndim for i in idx.values()), min(idx.keys()), max(idx.keys())
-      sum_dim = tuple(d if n==0 else d+max_idx_dim-n for n,d in enumerate(idx.keys()))
-      arange = [Tensor.arange(ret.shape[d], requires_grad=False, device=self.device).reshape(ret.shape[d], *[1]*(ret.ndim+max_idx_dim-n-sd-1)) \
-        for n,(sd,d) in enumerate(zip(sum_dim, idx.keys()))]
-      reshaped_idx = [i.reshape(i.shape + (1,)*(ret.ndim - first_dim - (n or 1))) for n,i in enumerate(idx.values())]
-      ret = ret.reshape(ret.shape[:first_dim+1] + (1,)*max_idx_dim + ret.shape[first_dim+1:])
+      masks, first_dim, last_dim = [], min(idx.keys()), max(idx.keys())
+      pre_reduce_shape = ret.shape[:first_dim] + (big_shape := broadcast_shape(*(t.shape for t in idx.values()))) + ret.shape[first_dim:]

-      # iteratively eq -> mul -> sum fancy index
-      try:
-        for a,i,sd in zip(arange, reshaped_idx, sum_dim): ret = (a==i).mul(ret).sum(sd)
-      except ValueError as exc: raise IndexError("cannot broadcast indices") from exc
+      # create masks
+      for dim, i in idx.items():
+        try: i = i.reshape(i.shape + (1,)*(ret.ndim - first_dim)).expand(pre_reduce_shape)
+        except ValueError as exc: raise IndexError("cannot broadcast indices") from exc
+        a = Tensor.arange(ret.shape[dim], device=self.device, requires_grad=False).reshape((ret.shape[dim],) + (1,)*(ret.ndim - dim - 1))
+        masks.append(i == a)
+
+      # reduce masks to 1 mask
+      mask = functools.reduce(lambda x,y: x.mul(y), masks)
+
+      # inject 1's for the extra dims added in create masks
+      sh = ret.shape[:first_dim] + (1,) * len(big_shape) + ret.shape[first_dim:]
+      # sum reduce the extra dims introduced in create masks
+      ret = (ret.reshape(sh) * mask).sum(tuple(i + len(big_shape) for i in idx.keys()))

       # special permute case
       if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
-        ret_dims = list(range(ret.ndim))
-        ret = ret.permute(ret_dims[first_dim:first_dim+max_idx_dim] + ret_dims[:first_dim] + ret_dims[first_dim+max_idx_dim:])
+        ret = ret.permute(*range(first_dim, first_dim+len(big_shape)), *range(0, first_dim), *range(first_dim+len(big_shape), ret.ndim))
     return ret

   def __setitem__(self, indices, v:Union[Tensor, ConstType]):
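
The new path above builds one mask per Tensor-indexed dim, broadcasts the index tensors to big_shape, reduces all masks into one with mul, and then does a single sum over every indexed dim (instead of the old iterative eq -> mul -> sum per index). A hedged NumPy sketch of that flow for x[idx0, idx1]; the names mirror the diff, but this is not the actual implementation:

import numpy as np

x = np.arange(12).reshape(3, 4)
idx0, idx1 = np.array([2, 0]), np.array([1, 3])
big_shape = np.broadcast_shapes(idx0.shape, idx1.shape)  # (2,), like broadcast_shape in the diff

# one mask per indexed dim: broadcast index vs an arange over that dim
mask0 = idx0.reshape(big_shape + (1, 1)) == np.arange(3).reshape(3, 1)  # (2, 3, 1)
mask1 = idx1.reshape(big_shape + (1, 1)) == np.arange(4)                # (2, 1, 4)
mask = mask0 * mask1                                                    # combine masks with mul -> (2, 3, 4)

# inject 1's for the big_shape dims, multiply, then one sum over the indexed dims
out = (x.reshape((1,) + x.shape) * mask).sum(axis=(1, 2))               # -> (2,)
assert (out == x[idx0, idx1]).all()

Combining the masks first means the data tensor is multiplied and reduced exactly once, which is where the removed transposes and the speedup mentioned in the commit message come from.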