Use generators instead of lists in any() and all() (#1111)

* Use generators in any(...) instead of lists for better best-case performance (a short sketch below shows why)

* Use generators in all(...) instead of lists

* Enable R1729 (use-a-generator) in .pylintrc

* Revert import sorting
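
A minimal sketch of why the generator form wins in the best case (the data and timing harness are illustrative, not from this commit): any() and all() short-circuit, so a generator expression lets them stop at the first deciding element, while a list comprehension always materializes the entire list first.

```python
import timeit

# Illustrative input: one early True among a million False values.
flags = [True] + [False] * 999_999

# List form: builds the full million-element list before any() sees it.
list_s = timeit.timeit(lambda: any([x for x in flags]), number=100)

# Generator form: any() pulls items lazily and returns at the first True.
gen_s = timeit.timeit(lambda: any(x for x in flags), number=100)

print(f"list: {list_s:.4f}s  generator: {gen_s:.4f}s")
```

In the worst case (no element decides early) both forms scan everything, which is why the first note claims a better best case rather than a blanket speedup.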

---------

Co-authored-by: Anselm Coogan <anselm@scandit.com>
Authored by Anselm Coogan on 2023-07-04 01:06:06 +02:00, committed by GitHub
parent fd98f6cffa, commit a22aad7d32
9 changed files with 23 additions and 23 deletions

@@ -64,7 +64,7 @@ disable=C,R,W0613,W0511,W0212,W0201,W0106,W0603,W0621,W0703,W1201,W1203,E1136,W1
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
-enable=c-extension-no-member
+enable=c-extension-no-member,use-a-generator
[REPORTS]
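
For reference, a minimal example of the pattern the newly enabled check reports (assuming pylint >= 2.6, where R1729 is named use-a-generator; the variable here is made up for illustration):

```python
values = [0, 1, 2, -3, 4]

# pylint flags the next line as R1729 (use-a-generator): the list is built
# in full even though all() could stop at the first falsy element (-3 here).
ok = all([x >= 0 for x in values])

# the suggested rewrite, as applied throughout this commit:
ok = all(x >= 0 for x in values)
```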

@@ -144,7 +144,7 @@ def get_run_onnx(onnx_model: ModelProto):
elif n.op_type == "Resize":
# TODO: this is handcoded for YOLOv8
scales = safe_numpy(inp[2])
-assert all([int(x) == x and x >= 1 for x in scales])
+assert all(int(x) == x and x >= 1 for x in scales)
ret = inp[0].reshape([val for pair in zip(inp[0].shape, [1] * len(scales)) for val in pair])
ret = ret.expand([val for pair in zip(inp[0].shape, [int(x) for x in scales]) for val in pair])
ret = ret.reshape([x*y for x,y in zip(inp[0].shape, [int(x) for x in scales])])
@@ -157,7 +157,7 @@ def get_run_onnx(onnx_model: ModelProto):
ret = inp[0].slice(arg=args[0]).cat(*[inp[0].slice(arg=arg) for arg in args[1:]], dim=axis)
ret = ret.reshape([s for i,s in enumerate(shape) if i != axis]) if len(indices) == 1 else ret # squeeze if needed
elif n.op_type in ["Add", "Sub", "Mul", "Pow"]:
-if all([isinstance(x, Tensor) for x in inp]) and (len(inp[0].shape) != len(inp[1].shape)) and (prod(inp[0].shape) == prod(inp[1].shape)):
+if all(isinstance(x, Tensor) for x in inp) and (len(inp[0].shape) != len(inp[1].shape)) and (prod(inp[0].shape) == prod(inp[1].shape)):
inp[1] = inp[1].reshape(inp[0].shape)
# TODO: is this right?
if 'broadcast' in opt: inp[1] = inp[1].reshape([-1 if i == opt['broadcast'] else 1 for i in range(len(inp[0].shape))])

@@ -55,7 +55,7 @@ class RetinaNet:
scales = tuple((i, int(i*2**(1/3)), int(i*2**(2/3))) for i in 2**np.arange(5, 10)) if scales is None else scales
aspect_ratios = ((0.5, 1.0, 2.0),) * len(scales) if aspect_ratios is None else aspect_ratios
self.num_anchors, self.num_classes = num_anchors, num_classes
-assert len(scales) == len(aspect_ratios) and all([self.num_anchors == len(s) * len(ar) for s, ar in zip(scales, aspect_ratios)])
+assert len(scales) == len(aspect_ratios) and all(self.num_anchors == len(s) * len(ar) for s, ar in zip(scales, aspect_ratios))
self.backbone = ResNetFPN(backbone)
self.head = RetinaHead(self.backbone.out_channels, num_anchors=num_anchors, num_classes=num_classes)

@@ -55,7 +55,7 @@ def get_grouped_float4_idxs(acc:List[Token]) -> Optional[List[int]]:
def to_float4(x:List[Token]) -> Optional[Token]:
if all_same(x): return x[0]
-if all_same([y.name for y in x]) and all([y.dtype == dtypes._float4 and y.offset == i for i,y in enumerate(x)]):
+if all_same([y.name for y in x]) and all(y.dtype == dtypes._float4 and y.offset == i for i,y in enumerate(x)):
return Token(x[0].name, dtypes._float4)
return None
@@ -68,7 +68,7 @@ def get_grouped_maybe_float4(*values:List[Token], grouping_allowed=True):
new_values = []
for i in range(0, len(idxs), 4):
nv = [to_float4([v[j] for j in idxs[i:i+4]]) for v in values]
-if any([x is None for x in nv]): break
+if any(x is None for x in nv): break
new_idxs.append(idxs[i:i+4])
new_values.append(nv)
if len(new_values) == len(idxs)//4:
@@ -437,7 +437,7 @@ class Linearizer:
# remove places where the shape is all ones
# TODO: this should be factored in to multi shape stride
if self.shape_len == 0: return
-all_ones = [all([st.shape[i]==1 for st in self.sts]) for i in range(self.shape_len)]
+all_ones = [all(st.shape[i]==1 for st in self.sts) for i in range(self.shape_len)]
# keep at least 1 one
if all(all_ones): all_ones[-1] = False
self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
@@ -495,7 +495,7 @@ class Linearizer:
if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
# TODO: use 1024 if it's allowed in a smarter way
for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
-if all([st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts]):
+if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
self.shift_to(self.first_reduce, sz, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
self.group_for_reduce.append(sz)
break

@@ -17,7 +17,7 @@ def argfix(*x):
except IndexError: return tuple()
return tuple(x)
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
-def all_same(items): return all([x == items[0] for x in items]) if len(items) > 1 else True
+def all_same(items): return all(x == items[0] for x in items)
def colored(st, color, background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line
def ansilen(s): return len(re.sub('\x1b\\[(K|.*?m)', '', s))
def partition(lst, fxn): return [x for x in lst if fxn(x)], [x for x in lst if not fxn(x)]

@@ -117,7 +117,7 @@ class LazyBuffer:
for x in self.op.buffers: x.realize()
# HACK: image shape can be wrong, hot cast it back to a normal float
-if self.dtype.__class__ is ImageDType and self.optype != MovementOps and (prod(self.shape) != prod(cast(ImageDType, self.dtype).shape) or not any([self.shape[x]%4 == 0 for x in self.st.unit_stride_axes()])):
+if self.dtype.__class__ is ImageDType and self.optype != MovementOps and (prod(self.shape) != prod(cast(ImageDType, self.dtype).shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())):
if self.op.op == MovementOps.RESHAPE:
# put CAST before the final RESHAPE
self.op = LazyOp(MovementOps.RESHAPE, (LazyOp(UnaryOps.CAST, self.op.src, dtypes.float32),), self.op.arg)
@@ -190,7 +190,7 @@ class LazyBuffer:
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).reshape(arg), MovementOps.RESHAPE, arg)
def pad(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
-if all([b == 0 and e == 0 for b,e in arg]): return self
+if all(b == 0 and e == 0 for b,e in arg): return self
if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).pad(arg), MovementOps.PAD, arg)
@@ -224,7 +224,7 @@ class LazyBuffer:
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).permute(arg), MovementOps.PERMUTE, arg)
def shrink(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
-if all([b - a == s for s, (a, b) in zip(self.shape, arg)]): return self
+if all(b - a == s for s, (a, b) in zip(self.shape, arg)): return self
if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).shrink(arg), MovementOps.SHRINK, arg)
@@ -254,7 +254,7 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
mops.append((bx.op.op, bx.op.arg))
bx = cast(LazyBuffer, bx.op.src[0])
# NOTE: can't push pads with a div
-if not bx.realized and bx.optype == BinaryOps and len(bx.children) <= 1 and len(mops) and (all([x[0] != MovementOps.PAD for x in mops]) or all([x.op != BinaryOps.DIV for x in bx.op.get_lazyops()])):
+if not bx.realized and bx.optype == BinaryOps and len(bx.children) <= 1 and len(mops) and (all(x[0] != MovementOps.PAD for x in mops) or all(x.op != BinaryOps.DIV for x in bx.op.get_lazyops())):
new_srcs.append(bx.op.replace_with_movement_ops(mops[::-1]))
else:
new_srcs.append(x)
@@ -268,7 +268,7 @@ def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer, arg:Optional
out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max([x.dtype for x in srcs]) if op != UnaryOps.CAST else cast(DType, arg)
# push all contiguous to the end of BinaryOps. kernels 198 -> 196
-if PUSH_CONTIGUOUS and any([not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs]):
+if PUSH_CONTIGUOUS and any(not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs):
new_srcs: List[LazyBuffer] = []
for x in srcs:
if not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1:

@@ -21,7 +21,7 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup
return ret
@functools.lru_cache(maxsize=None)
-def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool: return all([s1 == s2 or s == 1 for s,s1,s2 in zip(shape, strides, strides_for_shape(shape))])
+def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool: return all(s1 == s2 or s == 1 for s,s1,s2 in zip(shape, strides, strides_for_shape(shape)))
@functools.lru_cache(maxsize=None)
def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
@@ -196,7 +196,7 @@ class ShapeTracker:
return self._expr_idx(self.views[-1].expr_node(idx), self.views[-1].expr_node_mask(idx))
def needs_valid(self) -> bool:
-return any([v.mask is not None for v in self.views])
+return any(v.mask is not None for v in self.views)
# *** under this line are the movement ops ***
@@ -211,7 +211,7 @@ class ShapeTracker:
def pad(self, arg: Tuple[Tuple[int, int], ...]):
assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
-if any([b or e for b, e in arg]):
+if any(b or e for b, e in arg):
zvarg, mask = get_pad_args(self.shape, arg)
self.__unsafe_resize(zvarg, mask=mask)
return self

@@ -108,7 +108,7 @@ class Node:
def ands(nodes:List[Node]) -> Node:
if not nodes: return NumNode(1)
if len(nodes) == 1: return nodes[0]
-if any([x.min == x.max == 0 for x in nodes]): return NumNode(0)
+if any(x.min == x.max == 0 for x in nodes): return NumNode(0)
# filter 1s
nodes = [x for x in nodes if x.min != x.max]
@@ -192,7 +192,7 @@ class SumNode(RedNode):
factor_term = [x.a * x.b//b if isinstance(x, MulNode) else NumNode(x.b//b) for x in factors]
if nofactor_mul and not nofactor_nonmul:
gcds = [gcd(x.b, b) for x in nofactor_mul]
-if (t := min(gcds)) > 1 and all([x.b%t == 0 for x in nofactor_mul]):
+if (t := min(gcds)) > 1 and all(x.b%t == 0 for x in nofactor_mul):
nofactor_term = [Node.sum([x.a * x.b//t for x in nofactor_mul if isinstance(x, MulNode)])//(b//t)] # mypy wants the isinstance
else:
nofactor_term = [Node.sum(nofactor_mul)//b] if nofactor_mul else []

@@ -234,8 +234,8 @@ class Tensor:
def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
-def pad(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Pad.apply(self, arg=arg) if any([x != (0,0) for x in arg]) else self
-def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any([x != (0,s) for x,s in zip(arg, self.shape)]) else self
+def pad(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Pad.apply(self, arg=arg) if any(x != (0,0) for x in arg) else self
+def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
# ***** movement hlops *****
@@ -314,7 +314,7 @@ class Tensor:
def cat(self, *args, dim=0):
dim = (dim + len(self.shape)) if dim < 0 else dim
-assert all([len(y.shape) == len(self.shape) and all([y.shape[i] == s for i,s in enumerate(self.shape) if i != dim]) for y in args])
+assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
catargs = [self] + list(args)
assert all(len(t.shape) != 0 for t in catargs), "zero-dimensional tensor cannot be concatenated"
shape_cumsum = [0, *accumulate([y.shape[dim] for y in catargs])]
@@ -438,7 +438,7 @@ class Tensor:
HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
x, w = self, weight.reshape(groups, weight.shape[0]//groups, weight.shape[1], *weight.shape[2:]).permute(0,2,1,*trailing).flip(trailing)
stride = make_pair(stride, len(HW))
-if any([s>1 for s in stride]):
+if any(s>1 for s in stride):
x = x.reshape(*x.shape[:2], *flatten((k,1) for k in x.shape[2:]))
x = x.pad(((0,0), (0,0), *flatten(((0,0),(0,s-1)) for s in stride)))
x = x.reshape(*x.shape[:2], *[k*s for k,s in zip(x.shape[2::2], stride)])