minor cleanup for View strides (#2404)

This commit is contained in:
chenyu
2023-11-23 13:40:01 -05:00
committed by GitHub
parent 64aa2f4156
commit b27c845531

View File

@@ -12,8 +12,8 @@ def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int,
@functools.lru_cache(maxsize=None)
def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
strides = [1] if shape else []
for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
return filter_strides(shape, tuple(strides))
for d in reversed(shape[1:]): strides.append(d*strides[-1])
return filter_strides(shape, tuple(reversed(strides)))
@dataclass(frozen=True)
class View:
@@ -27,7 +27,7 @@ class View:
@functools.lru_cache(maxsize=None)
def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
strides = filter_strides(shape, strides) if strides else strides_for_shape(shape)
contiguous = offset == 0 and mask is None and all(s1 == s2 for s1,s2 in zip(strides, strides_for_shape(shape)))
contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
return View(shape, strides, offset, mask, contiguous)
def vars(self) -> Set[Variable]: