simplify View.permute arg check [run_process_replay] (#5218)

it checks if `axis` is a valid permutation, which is the same as `sorted(axis) == list(range(len(self.shape)))`
This commit is contained in:
chenyu
2024-06-28 16:18:46 -04:00
committed by GitHub
parent 80ac21200b
commit 7ba4938510

View File

@@ -245,8 +245,7 @@ class View:
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def permute(self, axis: Tuple[int, ...]) -> View:
assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
assert sorted(axis) == list(range(len(self.shape))), f"invalid permutation {axis} of len {len(self.shape)}"
return View.create(tuple(self.shape[a] for a in axis), tuple(self.strides[a] for a in axis), self.offset,
tuple(self.mask[a] for a in axis) if self.mask is not None else None)