mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
remove view_from_shape (#1448)
This commit is contained in:
@@ -85,12 +85,7 @@ def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node:
|
||||
def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
|
||||
strides = [1] if shape else []
|
||||
for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
|
||||
return tuple([st if s != 1 else 0 for st, s in zip(strides, shape)])
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def view_from_shape(shape:Tuple[Union[Node, int], ...]) -> View:
|
||||
assert all(is_sym_int(x) for x in shape)
|
||||
return View(tuple(shape), strides_for_shape(shape))
|
||||
return filter_strides(shape, tuple(strides))
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def merge_views(vm2:View, vm1:View) -> Optional[View]:
|
||||
@@ -119,7 +114,7 @@ def _reshape(view: View, new_shape:Tuple[int, ...]) -> Tuple[View, bool]:
|
||||
new_mask_tuple = tuple([(0,1) if x == 1 else new_mask.pop(0) for x in new_shape])
|
||||
return View(new_shape, new_strides_tuple, offset, new_mask_tuple), False
|
||||
|
||||
new_view = View(new_shape, strides_for_shape(new_shape))
|
||||
new_view = View(new_shape)
|
||||
if view.contiguous: return new_view, False # NOTE: if it's contiguous it can't have an offset
|
||||
if (merged_view := merge_views(view, new_view)) is not None: return merged_view, False
|
||||
if DEBUG >= 4: print(f"WARNING: creating new view with reshape {view} -> {new_shape}")
|
||||
@@ -136,7 +131,7 @@ def get_unsafe_resize_offset(strides, arg):
|
||||
class ShapeTracker:
|
||||
__slots__ = "views"
|
||||
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], views:Optional[List[View]]=None):
|
||||
self.views: List[View] = views if views is not None else ([*cast(ShapeTracker, shape).views] if shape.__class__ is ShapeTracker else [view_from_shape(shape)])
|
||||
self.views: List[View] = views if views is not None else ([*cast(ShapeTracker, shape).views] if shape.__class__ is ShapeTracker else [View(shape)])
|
||||
def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views})"
|
||||
def copy(self) -> ShapeTracker: return ShapeTracker(self.views[-1].shape, [*self.views])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user