MultiLazyBuffer.reshape new_axis without real_strides (#3272)

Similar to contraction, but this one is for finding the single axis that the sharded axis maps to.
This commit is contained in:
chenyu
2024-01-28 23:53:52 -05:00
committed by GitHub
parent 34c7621556
commit af4ca85594

View File

@@ -1,11 +1,11 @@
from __future__ import annotations
from typing import Optional, Union, Any, Tuple, List
import functools
from tinygrad.helpers import all_same, dedup, round_up, DEBUG
import functools, itertools, operator
from tinygrad.helpers import all_same, dedup, round_up, prod, DEBUG
from tinygrad.dtype import DType, Scalar
from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
from tinygrad.lazy import LazyBuffer, create_schedule
from tinygrad.shape.shapetracker import ShapeTracker, sint
from tinygrad.shape.shapetracker import sint
def all_reduce(op:ReduceOps, lbs):
# TODO: replace this with ring reduce
@@ -82,10 +82,9 @@ class MultiLazyBuffer:
def reshape(self, arg:Tuple[sint, ...]):
  """Reshape every per-device buffer to `arg`, tracking where the sharded axis ends up.

  When `self.axis` is None the buffers are reshaped to `arg` as-is and the result
  has no axis either. Otherwise the new axis is the position in `arg` whose
  cumulative products (before it, and including it) equal those of `self.axis`
  in `self.shape`.

  Args:
    arg: the target shape.

  Returns:
    A new MultiLazyBuffer over the reshaped buffers, carrying the mapped axis.

  Raises:
    ValueError: from `list.index` when no position in `arg` has the same
      cumulative products as `self.axis`.
  """
  if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None)
  # arg_acc[i] is the product of arg[:i]; the leading 1 is the empty product
  arg_acc:List[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
  # cumulative products just before, and including, the current axis
  target = (prod(self.shape[:self.axis]), prod(self.shape[:self.axis+1]))
  # new_axis is the one whose (product before, product including) pair matches target
  new_axis = list(zip(arg_acc, arg_acc[1:])).index(target)
  # each buffer keeps its own extent along the mapped axis; every other dim comes from arg
  return MultiLazyBuffer([x.reshape(tuple(x.shape[self.axis] if a == new_axis else s for a,s in enumerate(arg))) for x in self.lbs], new_axis)
def pad(self, arg:Tuple[Tuple[sint, sint], ...]):