mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
minor cleanup for View strides (#2404)
This commit is contained in:
@@ -12,8 +12,8 @@ def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int,
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
|
||||
strides = [1] if shape else []
|
||||
for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
|
||||
return filter_strides(shape, tuple(strides))
|
||||
for d in reversed(shape[1:]): strides.append(d*strides[-1])
|
||||
return filter_strides(shape, tuple(reversed(strides)))
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class View:
|
||||
@@ -27,7 +27,7 @@ class View:
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
|
||||
strides = filter_strides(shape, strides) if strides else strides_for_shape(shape)
|
||||
contiguous = offset == 0 and mask is None and all(s1 == s2 for s1,s2 in zip(strides, strides_for_shape(shape)))
|
||||
contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
|
||||
return View(shape, strides, offset, mask, contiguous)
|
||||
|
||||
def vars(self) -> Set[Variable]:
|
||||
|
||||
Reference in New Issue
Block a user