mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
simpler idxs_to_idx (#3071)
This commit is contained in:
@@ -49,11 +49,9 @@ def simplify(views:Tuple[View, ...]) -> Tuple[View, ...]:
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def idxs_to_idx(shape:Tuple[int, ...], idxs:Tuple[Node, ...]) -> Node:
|
||||
assert len(idxs) == len(shape), "need an idx for all dimensions"
|
||||
acc, ret = 1, []
|
||||
for tidx,d in zip(reversed(idxs), reversed(shape)):
|
||||
ret.append(tidx * acc)
|
||||
acc *= d
|
||||
return Node.sum(ret)
|
||||
# idxs[-1] * 1 + idxs[-2] * shape[-1] + idxs[-3] * shape[-1] * shape[-2] + ...
|
||||
accs = itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1)
|
||||
return Node.sum([idx * acc for idx, acc in zip(reversed(idxs), accs)])
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ShapeTracker:
|
||||
|
||||
Reference in New Issue
Block a user