mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
simplify View.permute arg check [run_process_replay] (#5218)
it checks if `axis` is a valid permutation, which is the same as `sorted(axis) == list(range(len(self.shape)))`
This commit is contained in:
@@ -245,8 +245,7 @@ class View:
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def permute(self, axis: Tuple[int, ...]) -> View:
|
||||
assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
|
||||
assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
|
||||
assert sorted(axis) == list(range(len(self.shape))), f"invalid permutation {axis} of len {len(self.shape)}"
|
||||
return View.create(tuple(self.shape[a] for a in axis), tuple(self.strides[a] for a in axis), self.offset,
|
||||
tuple(self.mask[a] for a in axis) if self.mask is not None else None)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user