diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index 396ec7f8b5..de92b2f4d9 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -12,8 +12,8 @@ def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, @functools.lru_cache(maxsize=None) def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]: strides = [1] if shape else [] - for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides - return filter_strides(shape, tuple(strides)) + for d in reversed(shape[1:]): strides.append(d*strides[-1]) + return filter_strides(shape, tuple(reversed(strides))) @dataclass(frozen=True) class View: @@ -27,7 +27,7 @@ class View: @functools.lru_cache(maxsize=None) def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None): strides = filter_strides(shape, strides) if strides else strides_for_shape(shape) - contiguous = offset == 0 and mask is None and all(s1 == s2 for s1,s2 in zip(strides, strides_for_shape(shape))) + contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape) return View(shape, strides, offset, mask, contiguous) def vars(self) -> Set[Variable]: