diff --git a/tinygrad/features/multi.py b/tinygrad/features/multi.py
index fdc5beec81..407def782e 100644
--- a/tinygrad/features/multi.py
+++ b/tinygrad/features/multi.py
@@ -1,11 +1,11 @@
 from __future__ import annotations
 from typing import Optional, Union, Any, Tuple, List
-import functools
-from tinygrad.helpers import all_same, dedup, round_up, DEBUG
+import functools, itertools, operator
+from tinygrad.helpers import all_same, dedup, round_up, prod, DEBUG
 from tinygrad.dtype import DType, Scalar
 from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
 from tinygrad.lazy import LazyBuffer, create_schedule
-from tinygrad.shape.shapetracker import ShapeTracker, sint
+from tinygrad.shape.shapetracker import sint
 
 def all_reduce(op:ReduceOps, lbs):
   # TODO: replace this with ring reduce
@@ -82,10 +82,9 @@ class MultiLazyBuffer:
 
   def reshape(self, arg:Tuple[sint, ...]):
     if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None)
-    # TODO: this can be wrong
-    st = ShapeTracker.from_shape(self.shape)
-    rs = st.real_strides()[self.axis]
-    new_axis = st.reshape(arg).real_strides().index(rs)
+    arg_acc:List[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
+    # new_axis is the one that preserves prod(prior to new_axis) and prod(post to new_axis)
+    new_axis = [tuple(p) for p in zip(arg_acc, arg_acc[1:])].index((prod(self.shape[:self.axis]), prod(self.shape[:self.axis+1])))
     return MultiLazyBuffer([x.reshape(tuple(x.shape[self.axis] if a == new_axis else s for a,s in enumerate(arg))) for x in self.lbs], new_axis)
 
   def pad(self, arg:Tuple[Tuple[sint, sint], ...]):