Use generators instead of lists in anys and alls (#1111)

* Use generators in any(..) instead of lists for better best-case

* Use generators in all(...) instead of lists

* enable R1729 in .pylintrc

* revert import sorting

---------

Co-authored-by: Anselm Coogan <anselm@scandit.com>
This commit is contained in:
Anselm Coogan
2023-07-04 01:06:06 +02:00
committed by GitHub
parent fd98f6cffa
commit a22aad7d32
9 changed files with 23 additions and 23 deletions

View File

@@ -234,8 +234,8 @@ class Tensor:
def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
def pad(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Pad.apply(self, arg=arg) if any([x != (0,0) for x in arg]) else self
def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any([x != (0,s) for x,s in zip(arg, self.shape)]) else self
def pad(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Pad.apply(self, arg=arg) if any(x != (0,0) for x in arg) else self
def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
# ***** movement hlops *****
@@ -314,7 +314,7 @@ class Tensor:
def cat(self, *args, dim=0):
dim = (dim + len(self.shape)) if dim < 0 else dim
assert all([len(y.shape) == len(self.shape) and all([y.shape[i] == s for i,s in enumerate(self.shape) if i != dim]) for y in args])
assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
catargs = [self] + list(args)
assert all(len(t.shape) != 0 for t in catargs), "zero-dimensional tensor cannot be concatenated"
shape_cumsum = [0, *accumulate([y.shape[dim] for y in catargs])]
@@ -438,7 +438,7 @@ class Tensor:
HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
x, w = self, weight.reshape(groups, weight.shape[0]//groups, weight.shape[1], *weight.shape[2:]).permute(0,2,1,*trailing).flip(trailing)
stride = make_pair(stride, len(HW))
if any([s>1 for s in stride]):
if any(s>1 for s in stride):
x = x.reshape(*x.shape[:2], *flatten((k,1) for k in x.shape[2:]))
x = x.pad(((0,0), (0,0), *flatten(((0,0),(0,s-1)) for s in stride)))
x = x.reshape(*x.shape[:2], *[k*s for k,s in zip(x.shape[2::2], stride)])