diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index d5fa8c4e2a..7428f453aa 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -85,12 +85,7 @@ def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node: def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]: strides = [1] if shape else [] for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides - return tuple([st if s != 1 else 0 for st, s in zip(strides, shape)]) - -@functools.lru_cache(maxsize=None) -def view_from_shape(shape:Tuple[Union[Node, int], ...]) -> View: - assert all(is_sym_int(x) for x in shape) - return View(tuple(shape), strides_for_shape(shape)) + return filter_strides(shape, tuple(strides)) @functools.lru_cache(maxsize=None) def merge_views(vm2:View, vm1:View) -> Optional[View]: @@ -119,7 +114,7 @@ def _reshape(view: View, new_shape:Tuple[int, ...]) -> Tuple[View, bool]: new_mask_tuple = tuple([(0,1) if x == 1 else new_mask.pop(0) for x in new_shape]) return View(new_shape, new_strides_tuple, offset, new_mask_tuple), False - new_view = View(new_shape, strides_for_shape(new_shape)) + new_view = View(new_shape) if view.contiguous: return new_view, False # NOTE: if it's contiguous it can't have an offset if (merged_view := merge_views(view, new_view)) is not None: return merged_view, False if DEBUG >= 4: print(f"WARNING: creating new view with reshape {view} -> {new_shape}") @@ -136,7 +131,7 @@ def get_unsafe_resize_offset(strides, arg): class ShapeTracker: __slots__ = "views" def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], views:Optional[List[View]]=None): - self.views: List[View] = views if views is not None else ([*cast(ShapeTracker, shape).views] if shape.__class__ is ShapeTracker else [view_from_shape(shape)]) + self.views: List[View] = views if views is not None else ([*cast(ShapeTracker, shape).views] if shape.__class__ is ShapeTracker else [View(shape)]) def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views})" def copy(self) -> ShapeTracker: return ShapeTracker(self.views[-1].shape, [*self.views])