Mirror of https://github.com/tinygrad/tinygrad.git
Getitem round3 .... (#2760)
* refactor round 3
* comment
* oops
* oops
* oops2
* factored out multiple condition
* add a comment for type
* wooaah roundup is cool, thanks chenyu lol
* add another walrus for symmetry and some spaces
* lol wtf useless listcompre
@@ -335,11 +335,9 @@ class Tensor:
 # use Dict[type, List[dimension]] to track elements in indices
 type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)
-# record None for dimension injection later
+# record None for dimension injection later and filter None and record rest of indices
 type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]
-# filter None and record rest of indices
-indices_filtered = tuple(v for v in indices if v is not None)
+indices_filtered = [v for v in indices if v is not None]
 for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)

 # validation! raise Errors
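Editor's note: a standalone sketch of the bookkeeping above; the `classify` wrapper is hypothetical, but its body mirrors the three lines in the hunk. The key subtlety is that the None dims are counted in the raw index tuple, while every other dim is re-counted after the Nones are filtered out.

```python
from collections import defaultdict
from typing import DefaultDict, List, Union

def classify(indices: tuple):  # hypothetical wrapper around the hunk's three lines
  type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)
  type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]  # dims in the raw tuple
  indices_filtered = [v for v in indices if v is not None]              # dims below are post-filter
  for dim, i in enumerate(indices_filtered): type_dim[type(i)].append(dim)
  return type_dim, indices_filtered

type_dim, filtered = classify((None, 3, slice(0, 2), None))
print(dict(type_dim))  # {None: [0, 3], <class 'int'>: [0], <class 'slice'>: [1]}
print(filtered)        # [3, slice(0, 2, None)] -- the int lands at dim 0, not 1
```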
@@ -348,34 +346,35 @@ class Tensor:
 if float in type_dim: raise IndexError("float type is not valid index")
 if any(isinstance(i, slice) and i.step == 0 for i in indices): raise ValueError('slice step cannot be 0')
 if num_slices > len(self.shape): raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
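Editor's note: assuming these guards behave as written, they surface to users roughly like this (a sketch against the public Tensor API, not part of the diff):

```python
from tinygrad.tensor import Tensor

t = Tensor.rand(3, 4)
for bad in (lambda: t[1.5],       # float index
            lambda: t[::0],       # zero-step slice
            lambda: t[0, 0, 0]):  # three indices into a 2-D tensor
  try: bad()
  except (IndexError, ValueError) as e: print(type(e).__name__, e)
# IndexError float type is not valid index
# ValueError slice step cannot be 0
# IndexError too many indices for tensor of dimension 2
```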
-for dim in type_dim[int]:
-  if indices_filtered[dim] >= self.shape[dim] or indices_filtered[dim] < -self.shape[dim]:
-    raise IndexError(f"index {indices_filtered[dim]} is out of bounds for dimension {dim} with size {self.shape[dim]}")
-# normalize! indices -> start, stop, strides
-# TODO: this line is completely unreadable
-start, stop, strides = zip(*y) if (y := [i.indices(sh) if isinstance(i, slice) else slice(normalized:= i if i != -1 else sh-1, normalized+1, 1).indices(sh) if isinstance(i, int) else (0, sh, 1) for i, sh in zip(indices_filtered, self.shape)]) else ((), (), ())  # type: ignore[arg-type] # noqa: E501
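Editor's note: the deleted one-liner leaned on Python's built-in slice.indices(length), which fills in defaults and clamps to the dimension size; an int index i was rewritten as slice(i, i+1, 1), special-casing -1 (since -1+1 == 0 would make the slice empty). A quick standalone check:

```python
print(slice(None, None, None).indices(5))  # (0, 5, 1)   full dim
print(slice(1, 100).indices(5))            # (1, 5, 1)   stop clamped to size
print(slice(None, None, -2).indices(5))    # (4, -1, -2) reversed, step 2
print(slice(3, 4, 1).indices(5))           # (3, 4, 1)   int index 3 as a slice
print(slice(4, 5, 1).indices(5))           # (4, 5, 1)   int index -1, via sh-1
```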
 # 2. basic indexing (no copy)
-# apply slices and flip where strides are negative
-new_slice = tuple(((0, 0) if e < s else (s, e)) if st > 0 else ((0, 0) if e > s else (e+1, s+1)) for s, e, st in zip(start, stop, strides))
-sliced_tensor = self.shrink(new_slice).flip(axis=[i for i, s in enumerate(strides) if s < 0])
-new_shape = list(sliced_tensor.shape)
+# currently indices_filtered: Tuple[Union[slice, int, Tensor], ...]
+# turn indices in indices_filtered to Tuple[shrink_arg, strides]
+for dim in type_dim[int]:
+  if (i := indices_filtered[dim]) >= (sh := self.shape[dim]) or i < -sh:
+    raise IndexError(f"index {i} is out of bounds for dimension {dim} with size {sh}")
+  indices_filtered[dim] = ((i, i+1), 1) if i >= 0 else ((sh+i, sh+i+1), 1)
+for dim in type_dim[slice]:
+  s, e, st = indices_filtered[dim].indices(self.shape[dim])
+  indices_filtered[dim] = ((0, 0) if (st > 0 and e < s) or (st <= 0 and e > s) else (s, e) if st > 0 else (e+1, s+1), st)
+for dim in type_dim[Tensor]: indices_filtered[dim] = ((0, self.shape[dim]), 1)
+
+new_slice, strides = ((),()) if not indices_filtered else zip(*indices_filtered)
+ret = self.shrink(new_slice).flip(axis=[i for i, s in enumerate(strides) if s < 0])
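Editor's note: a trace of the new (shrink_arg, stride) encoding for a hypothetical self.shape == (5, 6) and t[-2, 5:1:-2], replaying the int and slice branches above standalone:

```python
shape = (5, 6)
indices_filtered = [-2, slice(5, 1, -2)]

i, sh = indices_filtered[0], shape[0]                 # int branch: -2 -> rows 3:4
indices_filtered[0] = ((i, i+1), 1) if i >= 0 else ((sh+i, sh+i+1), 1)

s, e, st = indices_filtered[1].indices(shape[1])      # slice branch: (5, 1, -2)
indices_filtered[1] = ((0, 0) if (st > 0 and e < s) or (st <= 0 and e > s)
                       else (s, e) if st > 0 else (e+1, s+1), st)

new_slice, strides = zip(*indices_filtered)
print(new_slice, strides)  # ((3, 4), (2, 6)) (1, -2)
# shrink keeps rows 3:4 / cols 2:6; flip reverses axis 1 because its stride is negative
```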
 # add strides by pad -> reshape -> shrink
 if any(abs(s) != 1 for s in strides):
   strides = tuple(abs(s) for s in strides)
-  padded_tensor = sliced_tensor.pad(tuple((0, s-(dim_sz % s) if dim_sz % s != 0 else 0) for s, dim_sz in zip(strides, sliced_tensor.shape)))
-  reshaped_tensor = padded_tensor.reshape(flatten([sh // s, s] for sh, s in zip(padded_tensor.shape, strides)))
-  new_shape = list(reshaped_tensor.shape[::2])
-  sliced_tensor = reshaped_tensor.shrink(tuple(flatten(((0, sh), (0, 1)) for sh in new_shape)))
+  ret = ret.pad(tuple((0, round_up(sh, s) - sh) for s, sh in zip(strides, ret.shape)))
+  ret = ret.reshape(flatten([sh // s, s] for s, sh in zip(strides, ret.shape)))
+  ret = ret.shrink(tuple(flatten(((0, sh), (0, 1)) for sh in ret.shape[::2]))).reshape(ret.shape[::2])
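Editor's note: the pad -> reshape -> shrink stride trick, replayed standalone. round_up and flatten are redefined here on the assumption they match tinygrad/helpers: each dim is padded up to a multiple of its stride, split into (groups, stride), and only element 0 of each group survives the shrink.

```python
def round_up(num, amt): return (num + amt - 1) // amt * amt  # assumed to match tinygrad/helpers
def flatten(l): return [item for sub in l for item in sub]

shape, strides = (5, 4), (2, 3)  # post-shrink/flip shape and |stride| per dim
padded = tuple(round_up(sh, s) for s, sh in zip(strides, shape))   # (6, 6)
split  = flatten([sh // s, s] for s, sh in zip(strides, padded))   # [3, 2, 2, 3]
shrink = tuple(flatten(((0, sh), (0, 1)) for sh in split[::2]))    # ((0, 3), (0, 1), (0, 2), (0, 1))
print(padded, split, shrink, tuple(split[::2]))
# final shape (3, 2): dim 0 keeps elements 0, 2, 4; dim 1 keeps 0, 3 -- exactly strides 2 and 3
```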
-# inject dim=1 for None and collapse dim for int
+# inject 1 for dim where it's None and collapse dim for int
+new_shape = list(ret.shape)
 for dim in type_dim[None]: new_shape.insert(dim, 1)
 for dim in (dims_collapsed := [dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int])]): new_shape.pop(dim)
 for dim_sh in new_shape: assert isinstance(dim_sh, int), f"does not support symbolic shape {dim_sh}"

-ret = sliced_tensor.reshape(tuple(new_shape))
+ret = ret.reshape(tuple(new_shape))
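Editor's note: the shape fixup, replayed for t.shape == (3, 4) and t[None, 1, :, None], so type_dim[None] == [0, 3] and type_dim[int] == [0] (the int dim counted after Nones were filtered out):

```python
none_dims, int_dims = [0, 3], [0]
new_shape = [1, 4]                              # ret.shape after basic indexing: int dim shrunk to size 1
for dim in none_dims: new_shape.insert(dim, 1)  # [1, 1, 4] then [1, 1, 4, 1]
# shift each int dim right by the Nones inserted at or before it, then drop it
dims_collapsed = [dim + sum(1 for d in none_dims if dim >= d) for dim in reversed(int_dims)]
for dim in dims_collapsed: new_shape.pop(dim)   # pop dim 1 -> [1, 4, 1]
print(new_shape)  # [1, 4, 1], matching numpy/torch for t[None, 1, :, None]
```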
 # 3. advanced indexing (copy)
 if type_dim[Tensor]:
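Editor's note: the Tensor-index branch is truncated here. Conceptually it is a copying gather; one way to picture it (plain Python, illustrative only, not the code in this commit) is one-hot masks multiplied and summed against the source:

```python
t = [[1, 2], [3, 4], [5, 6]]   # stand-in for a (3, 2) tensor
idx = [2, 0]                   # Tensor index selecting rows 2 then 0
onehot = [[1 if j == i else 0 for j in range(len(t))] for i in idx]
gathered = [[sum(oh[r] * t[r][c] for r in range(len(t))) for c in range(len(t[0]))]
            for oh in onehot]
print(gathered)  # [[5, 6], [1, 2]] -- a copy, unlike the movement-op-only basic indexing above
```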