mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
MultiLazyBuffer.reshape new_axis without real_strides (#3272)
similar to contraction, but this is one is for finding the mapped single axis
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
from __future__ import annotations
|
||||
from typing import Optional, Union, Any, Tuple, List
|
||||
import functools
|
||||
from tinygrad.helpers import all_same, dedup, round_up, DEBUG
|
||||
import functools, itertools, operator
|
||||
from tinygrad.helpers import all_same, dedup, round_up, prod, DEBUG
|
||||
from tinygrad.dtype import DType, Scalar
|
||||
from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
|
||||
from tinygrad.lazy import LazyBuffer, create_schedule
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, sint
|
||||
from tinygrad.shape.shapetracker import sint
|
||||
|
||||
def all_reduce(op:ReduceOps, lbs):
|
||||
# TODO: replace this with ring reduce
|
||||
@@ -82,10 +82,9 @@ class MultiLazyBuffer:
|
||||
|
||||
def reshape(self, arg:Tuple[sint, ...]):
|
||||
if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None)
|
||||
# TODO: this can be wrong
|
||||
st = ShapeTracker.from_shape(self.shape)
|
||||
rs = st.real_strides()[self.axis]
|
||||
new_axis = st.reshape(arg).real_strides().index(rs)
|
||||
arg_acc:List[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
|
||||
# new_axis is the one that preserves prod(prior to new_axis) and prod(post to new_axis)
|
||||
new_axis = [tuple(p) for p in zip(arg_acc, arg_acc[1:])].index((prod(self.shape[:self.axis]), prod(self.shape[:self.axis+1])))
|
||||
return MultiLazyBuffer([x.reshape(tuple(x.shape[self.axis] if a == new_axis else s for a,s in enumerate(arg))) for x in self.lbs], new_axis)
|
||||
|
||||
def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
|
||||
|
||||
Reference in New Issue
Block a user