merge_views is very powerful

This commit is contained in:
George Hotz
2023-03-03 22:53:59 -08:00
parent b5b4edf59b
commit aef336c079

View File

@@ -1,7 +1,7 @@
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
from __future__ import annotations
import functools
from typing import Tuple, Union, List, Optional
from typing import Tuple, Union, List, Optional, cast
from tinygrad.helpers import prod, DEBUG
from tinygrad.shape.symbolic import Variable, MulNode, NumNode
@@ -17,8 +17,6 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup
return ret
class View:
__slots__ = ('shape', 'strides', 'offset', 'shape_strides', 'contiguous')
def __init__(self, shape:Tuple[int, ...], strides:Tuple[int, ...], offset:int=0):
self.shape, self.strides, self.offset = shape, tuple(stride if shp != 1 else 0 for stride,shp in zip(strides, shape)), offset
self.shape_strides = to_shape_strides(self.shape, self.strides)
@@ -41,8 +39,6 @@ class View:
return Variable.sum([Variable.num(self.offset+offset)] + [Variable(idx, 0, sh-1)*st for idx,sh,st in zip(idxs, self.shape, self.strides) if sh != 1 and st != 0])
class ZeroView:
__slots__ = ('old_shape', 'arg', 'shape', 'contiguous', 'strides', 'offset')
def __init__(self, old_shape:Tuple[int, ...], arg):
self.old_shape, self.arg = old_shape, arg
self.shape : Tuple[int, ...] = tuple([y-x for x,y in self.arg])
@@ -98,8 +94,6 @@ def merge_views(vm2:View, vm1:View) -> Optional[View]:
return View(vm1.shape, tuple(new_strides), new_offset.b) if len(new_strides) == len(vm1.strides) else None
class ShapeTracker:
__slots__ = ('views')
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], views:Optional[List[ViewTypes]]=None):
self.views : List[ViewTypes] = views if views is not None else (shape.views[:] if isinstance(shape, ShapeTracker) else [view_from_shape(shape)])
def __repr__(self): return f"ShapeTracker(shape={self.shape}, views={self.views})"
@@ -150,42 +144,13 @@ class ShapeTracker:
assert all(isinstance(x, int) and x != 0 for x in new_shape), f"shape must be ints and can't contain 0 {new_shape}"
assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"
# check if this is adding or removing 1s (only)
if tuple(x for x in self.shape if x != 1) == tuple(x for x in new_shape if x != 1):
old_strides = [y for x,y in zip(self.shape, self.strides) if x != 1]
new_strides_tuple = tuple(0 if x == 1 else old_strides.pop(0) for x in new_shape)
self.views[-1] = View(new_shape, new_strides_tuple, self.offset)
return self
# check if the new dimensions factorize from the old ones
# NOTE: if you don't make a copy here, the list is popped in the lrucache
min_shape_strides = to_shape_strides(self.shape, self.strides)[:]
curr_dim, curr_stride = min_shape_strides.pop(0)
new_strides : List[int] = []
for s in new_shape:
if curr_dim%s == 0:
curr_dim //= s
new_strides.append(curr_stride * curr_dim)
if curr_dim == 1:
if len(min_shape_strides) == 0:
# there might still be 1s in the shape
while len(new_strides) != len(new_shape):
assert new_shape[len(new_strides)] == 1
new_strides.append(1)
break
curr_dim, curr_stride = min_shape_strides.pop(0)
else:
break # didn't factorize
if len(new_shape) == len(new_strides):
self.views[-1] = View(new_shape, tuple(new_strides), self.offset)
return self
view = View(new_shape, strides_for_shape(new_shape))
if self.contiguous:
self.views[-1] = view # NOTE: if it's contiguous it can't have an offset
else:
self.views.append(view)
# NOTE: the last view in self.views is never a ZeroView
if (merged_view := merge_views(cast(View, self.views[-1]), view)) is not None: self.views[-1] = merged_view
else: self.views.append(view)
return self
def permute(self, axis : Tuple[int, ...]) -> ShapeTracker: