mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
merge_views is very powerful
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
from typing import Tuple, Union, List, Optional
|
||||
from typing import Tuple, Union, List, Optional, cast
|
||||
from tinygrad.helpers import prod, DEBUG
|
||||
from tinygrad.shape.symbolic import Variable, MulNode, NumNode
|
||||
|
||||
@@ -17,8 +17,6 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup
|
||||
return ret
|
||||
|
||||
class View:
|
||||
__slots__ = ('shape', 'strides', 'offset', 'shape_strides', 'contiguous')
|
||||
|
||||
def __init__(self, shape:Tuple[int, ...], strides:Tuple[int, ...], offset:int=0):
|
||||
self.shape, self.strides, self.offset = shape, tuple(stride if shp != 1 else 0 for stride,shp in zip(strides, shape)), offset
|
||||
self.shape_strides = to_shape_strides(self.shape, self.strides)
|
||||
@@ -41,8 +39,6 @@ class View:
|
||||
return Variable.sum([Variable.num(self.offset+offset)] + [Variable(idx, 0, sh-1)*st for idx,sh,st in zip(idxs, self.shape, self.strides) if sh != 1 and st != 0])
|
||||
|
||||
class ZeroView:
|
||||
__slots__ = ('old_shape', 'arg', 'shape', 'contiguous', 'strides', 'offset')
|
||||
|
||||
def __init__(self, old_shape:Tuple[int, ...], arg):
|
||||
self.old_shape, self.arg = old_shape, arg
|
||||
self.shape : Tuple[int, ...] = tuple([y-x for x,y in self.arg])
|
||||
@@ -98,8 +94,6 @@ def merge_views(vm2:View, vm1:View) -> Optional[View]:
|
||||
return View(vm1.shape, tuple(new_strides), new_offset.b) if len(new_strides) == len(vm1.strides) else None
|
||||
|
||||
class ShapeTracker:
|
||||
__slots__ = ('views')
|
||||
|
||||
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], views:Optional[List[ViewTypes]]=None):
|
||||
self.views : List[ViewTypes] = views if views is not None else (shape.views[:] if isinstance(shape, ShapeTracker) else [view_from_shape(shape)])
|
||||
def __repr__(self): return f"ShapeTracker(shape={self.shape}, views={self.views})"
|
||||
@@ -150,42 +144,13 @@ class ShapeTracker:
|
||||
assert all(isinstance(x, int) and x != 0 for x in new_shape), f"shape must be ints and can't contain 0 {new_shape}"
|
||||
assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"
|
||||
|
||||
# check if this is adding or removing 1s (only)
|
||||
if tuple(x for x in self.shape if x != 1) == tuple(x for x in new_shape if x != 1):
|
||||
old_strides = [y for x,y in zip(self.shape, self.strides) if x != 1]
|
||||
new_strides_tuple = tuple(0 if x == 1 else old_strides.pop(0) for x in new_shape)
|
||||
self.views[-1] = View(new_shape, new_strides_tuple, self.offset)
|
||||
return self
|
||||
|
||||
# check if the new dimensions factorize from the old ones
|
||||
# NOTE: if you don't make a copy here, the list is popped in the lrucache
|
||||
min_shape_strides = to_shape_strides(self.shape, self.strides)[:]
|
||||
curr_dim, curr_stride = min_shape_strides.pop(0)
|
||||
new_strides : List[int] = []
|
||||
for s in new_shape:
|
||||
if curr_dim%s == 0:
|
||||
curr_dim //= s
|
||||
new_strides.append(curr_stride * curr_dim)
|
||||
if curr_dim == 1:
|
||||
if len(min_shape_strides) == 0:
|
||||
# there might still be 1s in the shape
|
||||
while len(new_strides) != len(new_shape):
|
||||
assert new_shape[len(new_strides)] == 1
|
||||
new_strides.append(1)
|
||||
break
|
||||
curr_dim, curr_stride = min_shape_strides.pop(0)
|
||||
else:
|
||||
break # didn't factorize
|
||||
|
||||
if len(new_shape) == len(new_strides):
|
||||
self.views[-1] = View(new_shape, tuple(new_strides), self.offset)
|
||||
return self
|
||||
|
||||
view = View(new_shape, strides_for_shape(new_shape))
|
||||
if self.contiguous:
|
||||
self.views[-1] = view # NOTE: if it's contiguous it can't have an offset
|
||||
else:
|
||||
self.views.append(view)
|
||||
# NOTE: the last view in self.views is never a ZeroView
|
||||
if (merged_view := merge_views(cast(View, self.views[-1]), view)) is not None: self.views[-1] = merged_view
|
||||
else: self.views.append(view)
|
||||
return self
|
||||
|
||||
def permute(self, axis : Tuple[int, ...]) -> ShapeTracker:
|
||||
|
||||
Reference in New Issue
Block a user