Getitem round3 .... (#2760)

* refactor round 3

* comment

* oops

* oops

* oops2

* factored out multiple condition

* add a comment for type

* wooaah roundup is cool, thanks chenyu lol

* add another walrus for symmetry and some spaces

* lol wtf useless listcompre
This commit is contained in:
geohotstan
2023-12-15 01:22:37 +08:00
committed by GitHub
parent 0ae22b0f81
commit 0398288b79

View File

@@ -335,11 +335,9 @@ class Tensor:
# use Dict[type, List[dimension]] to track elements in indices
type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)
# record None for dimension injection later
# record None for dimension injection later and filter None and record rest of indices
type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]
# filter None and record rest of indices
indices_filtered = tuple(v for v in indices if v is not None)
indices_filtered = [v for v in indices if v is not None]
for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)
# validation! raise Errors
@@ -348,34 +346,35 @@ class Tensor:
if float in type_dim: raise IndexError("float type is not valid index")
if any(isinstance(i, slice) and i.step == 0 for i in indices): raise ValueError('slice step cannot be 0')
if num_slices > len(self.shape): raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
for dim in type_dim[int]:
if indices_filtered[dim] >= self.shape[dim] or indices_filtered[dim] < -self.shape[dim]:
raise IndexError(f"index {indices_filtered[dim]} is out of bounds for dimension {dim} with size {self.shape[dim]}")
# normalize! indices -> start, stop, strides
# TODO: this line is completely unreadable
start, stop, strides = zip(*y) if (y := [i.indices(sh) if isinstance(i, slice) else slice(normalized:= i if i != -1 else sh-1, normalized+1, 1).indices(sh) if isinstance(i, int) else (0, sh, 1) for i, sh in zip(indices_filtered, self.shape)]) else ((), (), ()) # type: ignore[arg-type] # noqa: E501
# 2. basic indexing (no copy)
# apply slices and flip where strides are negative
new_slice = tuple(((0, 0) if e < s else (s, e)) if st > 0 else ((0, 0) if e > s else (e+1, s+1)) for s, e, st in zip(start, stop, strides))
sliced_tensor = self.shrink(new_slice).flip(axis=[i for i, s in enumerate(strides) if s < 0])
new_shape = list(sliced_tensor.shape)
# currently indices_filtered: Tuple[Union[slice, int, Tensor], ...]
# turn indices in indices_filtered to Tuple[shrink_arg, strides]
for dim in type_dim[int]:
if (i := indices_filtered[dim]) >= (sh := self.shape[dim]) or i < -sh:
raise IndexError(f"index {i} is out of bounds for dimension {dim} with size {sh}")
indices_filtered[dim] = ((i, i+1), 1) if i >= 0 else ((sh+i, sh+i+1), 1)
for dim in type_dim[slice]:
s, e, st = indices_filtered[dim].indices(self.shape[dim])
indices_filtered[dim] = ((0, 0) if (st > 0 and e < s) or (st <= 0 and e > s) else (s, e) if st > 0 else (e+1, s+1), st)
for dim in type_dim[Tensor]: indices_filtered[dim] = ((0, self.shape[dim]), 1)
new_slice, strides = ((),()) if not indices_filtered else zip(*indices_filtered)
ret = self.shrink(new_slice).flip(axis=[i for i, s in enumerate(strides) if s < 0])
# add strides by pad -> reshape -> shrink
if any(abs(s) != 1 for s in strides):
strides = tuple(abs(s) for s in strides)
padded_tensor = sliced_tensor.pad(tuple((0, s-(dim_sz % s) if dim_sz % s != 0 else 0) for s, dim_sz in zip(strides, sliced_tensor.shape)))
reshaped_tensor = padded_tensor.reshape(flatten([sh // s, s] for sh, s in zip(padded_tensor.shape, strides)))
new_shape = list(reshaped_tensor.shape[::2])
sliced_tensor = reshaped_tensor.shrink(tuple(flatten(((0, sh), (0, 1)) for sh in new_shape)))
ret = ret.pad(tuple((0, round_up(sh, s) - sh) for s, sh in zip(strides, ret.shape)))
ret = ret.reshape(flatten([sh // s, s] for s, sh in zip(strides, ret.shape)))
ret = ret.shrink(tuple(flatten(((0, sh), (0, 1)) for sh in ret.shape[::2]))).reshape(ret.shape[::2])
# inject dim=1 for None and collapse dim for int
# inject 1 for dim where it's None and collapse dim for int
new_shape = list(ret.shape)
for dim in type_dim[None]: new_shape.insert(dim, 1)
for dim in (dims_collapsed := [dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int])]): new_shape.pop(dim)
for dim_sh in new_shape: assert isinstance(dim_sh, int), f"does not support symbolic shape {dim_sh}"
ret = sliced_tensor.reshape(tuple(new_shape))
ret = ret.reshape(tuple(new_shape))
# 3. advanced indexing (copy)
if type_dim[Tensor]: