simpler idxs_to_idx (#3071)

This commit is contained in:
chenyu
2024-01-10 00:30:10 -05:00
committed by GitHub
parent 2495ca95c7
commit 023f5df0e9

View File

@@ -49,11 +49,9 @@ def simplify(views:Tuple[View, ...]) -> Tuple[View, ...]:
@functools.lru_cache(maxsize=None)
def idxs_to_idx(shape:Tuple[int, ...], idxs:Tuple[Node, ...]) -> Node:
assert len(idxs) == len(shape), "need an idx for all dimensions"
acc, ret = 1, []
for tidx,d in zip(reversed(idxs), reversed(shape)):
ret.append(tidx * acc)
acc *= d
return Node.sum(ret)
# idxs[-1] * 1 + idxs[-2] * shape[-1] + idxs[-3] * shape[-1] * shape[-2] + ...
accs = itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1)
return Node.sum([idx * acc for idx, acc in zip(reversed(idxs), accs)])
@dataclass(frozen=True)
class ShapeTracker: