mirror of https://github.com/tinygrad/tinygrad.git
Use generators instead of lists in anys and alls (#1111)
* Use generators in any(...) instead of lists for better best-case performance
* Use generators in all(...) instead of lists
* enable R1729 in .pylintrc
* revert import sorting

Co-authored-by: Anselm Coogan <anselm@scandit.com>
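Why generators help here: any() and all() short-circuit, but a list comprehension materializes every element before the call, so the best case costs as much as the worst case. A minimal sketch of the difference (the expensive() predicate and the timings are illustrative, not from the codebase):

import time

def expensive(x):
  # hypothetical stand-in for a costly predicate
  time.sleep(0.01)
  return x == 0

xs = list(range(100))

# list comprehension: expensive() runs for all 100 elements before any() sees them
t0 = time.time(); any([expensive(x) for x in xs]); print(f"list: {time.time()-t0:.2f}s")      # ~1.00s

# generator expression: any() stops at the first truthy result, here x == 0
t0 = time.time(); any(expensive(x) for x in xs); print(f"generator: {time.time()-t0:.2f}s")  # ~0.01s

The worst case (no truthy element for any(), no falsy element for all()) is unchanged; the rewrite only avoids building a throwaway list and unlocks early exit.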
@@ -64,7 +64,7 @@ disable=C,R,W0613,W0511,W0212,W0201,W0106,W0603,W0621,W0703,W1201,W1203,E1136,W1
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
-enable=c-extension-no-member
+enable=c-extension-no-member,use-a-generator

[REPORTS]
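R1729 is the pylint check named use-a-generator: with it enabled, a list comprehension passed directly to any() or all() is reported. A small illustrative snippet (not from the repo):

# Before: pylint flags this line with R1729, suggesting the generator form.
def all_positive(xs): return all([x > 0 for x in xs])

# After: no warning, and all() can short-circuit on the first non-positive x.
def all_positive(xs): return all(x > 0 for x in xs)

Enabling the check in .pylintrc keeps future code from reintroducing the pattern.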
@@ -144,7 +144,7 @@ def get_run_onnx(onnx_model: ModelProto):
    elif n.op_type == "Resize":
      # TODO: this is handcoded for YOLOv8
      scales = safe_numpy(inp[2])
-      assert all([int(x) == x and x >= 1 for x in scales])
+      assert all(int(x) == x and x >= 1 for x in scales)
      ret = inp[0].reshape([val for pair in zip(inp[0].shape, [1] * len(scales)) for val in pair])
      ret = ret.expand([val for pair in zip(inp[0].shape, [int(x) for x in scales]) for val in pair])
      ret = ret.reshape([x*y for x,y in zip(inp[0].shape, [int(x) for x in scales])])
@@ -157,7 +157,7 @@ def get_run_onnx(onnx_model: ModelProto):
      ret = inp[0].slice(arg=args[0]).cat(*[inp[0].slice(arg=arg) for arg in args[1:]], dim=axis)
      ret = ret.reshape([s for i,s in enumerate(shape) if i != axis]) if len(indices) == 1 else ret # squeeze if needed
    elif n.op_type in ["Add", "Sub", "Mul", "Pow"]:
-      if all([isinstance(x, Tensor) for x in inp]) and (len(inp[0].shape) != len(inp[1].shape)) and (prod(inp[0].shape) == prod(inp[1].shape)):
+      if all(isinstance(x, Tensor) for x in inp) and (len(inp[0].shape) != len(inp[1].shape)) and (prod(inp[0].shape) == prod(inp[1].shape)):
        inp[1] = inp[1].reshape(inp[0].shape)
      # TODO: is this right?
      if 'broadcast' in opt: inp[1] = inp[1].reshape([-1 if i == opt['broadcast'] else 1 for i in range(len(inp[0].shape))])
@@ -55,7 +55,7 @@ class RetinaNet:
    scales = tuple((i, int(i*2**(1/3)), int(i*2**(2/3))) for i in 2**np.arange(5, 10)) if scales is None else scales
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(scales) if aspect_ratios is None else aspect_ratios
    self.num_anchors, self.num_classes = num_anchors, num_classes
-    assert len(scales) == len(aspect_ratios) and all([self.num_anchors == len(s) * len(ar) for s, ar in zip(scales, aspect_ratios)])
+    assert len(scales) == len(aspect_ratios) and all(self.num_anchors == len(s) * len(ar) for s, ar in zip(scales, aspect_ratios))

    self.backbone = ResNetFPN(backbone)
    self.head = RetinaHead(self.backbone.out_channels, num_anchors=num_anchors, num_classes=num_classes)
@@ -55,7 +55,7 @@ def get_grouped_float4_idxs(acc:List[Token]) -> Optional[List[int]]:

def to_float4(x:List[Token]) -> Optional[Token]:
  if all_same(x): return x[0]
-  if all_same([y.name for y in x]) and all([y.dtype == dtypes._float4 and y.offset == i for i,y in enumerate(x)]):
+  if all_same([y.name for y in x]) and all(y.dtype == dtypes._float4 and y.offset == i for i,y in enumerate(x)):
    return Token(x[0].name, dtypes._float4)
  return None
@@ -68,7 +68,7 @@ def get_grouped_maybe_float4(*values:List[Token], grouping_allowed=True):
  new_values = []
  for i in range(0, len(idxs), 4):
    nv = [to_float4([v[j] for j in idxs[i:i+4]]) for v in values]
-    if any([x is None for x in nv]): break
+    if any(x is None for x in nv): break
    new_idxs.append(idxs[i:i+4])
    new_values.append(nv)
  if len(new_values) == len(idxs)//4:
@@ -437,7 +437,7 @@ class Linearizer:
    # remove places where the shape is all ones
    # TODO: this should be factored in to multi shape stride
    if self.shape_len == 0: return
-    all_ones = [all([st.shape[i]==1 for st in self.sts]) for i in range(self.shape_len)]
+    all_ones = [all(st.shape[i]==1 for st in self.sts) for i in range(self.shape_len)]
    # keep at least 1 one
    if all(all_ones): all_ones[-1] = False
    self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
@@ -495,7 +495,7 @@ class Linearizer:
    if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
      # TODO: use 1024 if it's allowed in a smarter way
      for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
-        if all([st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts]):
+        if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
          self.shift_to(self.first_reduce, sz, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
          self.group_for_reduce.append(sz)
          break
@@ -17,7 +17,7 @@ def argfix(*x):
  except IndexError: return tuple()
  return tuple(x)
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
-def all_same(items): return all([x == items[0] for x in items]) if len(items) > 1 else True
+def all_same(items): return all(x == items[0] for x in items)
def colored(st, color, background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line
def ansilen(s): return len(re.sub('\x1b\\[(K|.*?m)', '', s))
def partition(lst, fxn): return [x for x in lst if fxn(x)], [x for x in lst if not fxn(x)]
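The all_same rewrite above is more than a mechanical swap: it drops the len(items) > 1 guard entirely. That stays safe because all() over an empty generator is vacuously True, and items[0] in the generator body is only evaluated once a first element is drawn, so an empty list never raises IndexError. A quick sanity check of the new definition:

def all_same(items): return all(x == items[0] for x in items)

assert all_same([])          # vacuously True: the body (and items[0]) never runs
assert all_same([7])         # single element compares equal to itself
assert all_same([3, 3, 3])
assert not all_same([1, 2])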
@@ -117,7 +117,7 @@ class LazyBuffer:
    for x in self.op.buffers: x.realize()

    # HACK: image shape can be wrong, hot cast it back to a normal float
-    if self.dtype.__class__ is ImageDType and self.optype != MovementOps and (prod(self.shape) != prod(cast(ImageDType, self.dtype).shape) or not any([self.shape[x]%4 == 0 for x in self.st.unit_stride_axes()])):
+    if self.dtype.__class__ is ImageDType and self.optype != MovementOps and (prod(self.shape) != prod(cast(ImageDType, self.dtype).shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())):
      if self.op.op == MovementOps.RESHAPE:
        # put CAST before the final RESHAPE
        self.op = LazyOp(MovementOps.RESHAPE, (LazyOp(UnaryOps.CAST, self.op.src, dtypes.float32),), self.op.arg)
@@ -190,7 +190,7 @@ class LazyBuffer:
    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).reshape(arg), MovementOps.RESHAPE, arg)

  def pad(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
-    if all([b == 0 and e == 0 for b,e in arg]): return self
+    if all(b == 0 and e == 0 for b,e in arg): return self
    if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).pad(arg), MovementOps.PAD, arg)
@@ -224,7 +224,7 @@ class LazyBuffer:
    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).permute(arg), MovementOps.PERMUTE, arg)

  def shrink(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
-    if all([b - a == s for s, (a, b) in zip(self.shape, arg)]): return self
+    if all(b - a == s for s, (a, b) in zip(self.shape, arg)): return self
    if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
    return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).shrink(arg), MovementOps.SHRINK, arg)
@@ -254,7 +254,7 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
      mops.append((bx.op.op, bx.op.arg))
      bx = cast(LazyBuffer, bx.op.src[0])
    # NOTE: can't push pads with a div
-    if not bx.realized and bx.optype == BinaryOps and len(bx.children) <= 1 and len(mops) and (all([x[0] != MovementOps.PAD for x in mops]) or all([x.op != BinaryOps.DIV for x in bx.op.get_lazyops()])):
+    if not bx.realized and bx.optype == BinaryOps and len(bx.children) <= 1 and len(mops) and (all(x[0] != MovementOps.PAD for x in mops) or all(x.op != BinaryOps.DIV for x in bx.op.get_lazyops())):
      new_srcs.append(bx.op.replace_with_movement_ops(mops[::-1]))
    else:
      new_srcs.append(x)
@@ -268,7 +268,7 @@ def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer, arg:Optional
  out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max([x.dtype for x in srcs]) if op != UnaryOps.CAST else cast(DType, arg)

  # push all contiguous to the end of BinaryOps. kernels 198 -> 196
-  if PUSH_CONTIGUOUS and any([not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs]):
+  if PUSH_CONTIGUOUS and any(not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs):
    new_srcs: List[LazyBuffer] = []
    for x in srcs:
      if not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1:
@@ -21,7 +21,7 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup
  return ret

@functools.lru_cache(maxsize=None)
-def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool: return all([s1 == s2 or s == 1 for s,s1,s2 in zip(shape, strides, strides_for_shape(shape))])
+def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool: return all(s1 == s2 or s == 1 for s,s1,s2 in zip(shape, strides, strides_for_shape(shape)))

@functools.lru_cache(maxsize=None)
def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
@@ -196,7 +196,7 @@ class ShapeTracker:
    return self._expr_idx(self.views[-1].expr_node(idx), self.views[-1].expr_node_mask(idx))

  def needs_valid(self) -> bool:
-    return any([v.mask is not None for v in self.views])
+    return any(v.mask is not None for v in self.views)

  # *** under this line are the movement ops ***
@@ -211,7 +211,7 @@ class ShapeTracker:

  def pad(self, arg: Tuple[Tuple[int, int], ...]):
    assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
-    if any([b or e for b, e in arg]):
+    if any(b or e for b, e in arg):
      zvarg, mask = get_pad_args(self.shape, arg)
      self.__unsafe_resize(zvarg, mask=mask)
    return self
@@ -108,7 +108,7 @@ class Node:
  def ands(nodes:List[Node]) -> Node:
    if not nodes: return NumNode(1)
    if len(nodes) == 1: return nodes[0]
-    if any([x.min == x.max == 0 for x in nodes]): return NumNode(0)
+    if any(x.min == x.max == 0 for x in nodes): return NumNode(0)

    # filter 1s
    nodes = [x for x in nodes if x.min != x.max]
@@ -192,7 +192,7 @@ class SumNode(RedNode):
      factor_term = [x.a * x.b//b if isinstance(x, MulNode) else NumNode(x.b//b) for x in factors]
      if nofactor_mul and not nofactor_nonmul:
        gcds = [gcd(x.b, b) for x in nofactor_mul]
-        if (t := min(gcds)) > 1 and all([x.b%t == 0 for x in nofactor_mul]):
+        if (t := min(gcds)) > 1 and all(x.b%t == 0 for x in nofactor_mul):
          nofactor_term = [Node.sum([x.a * x.b//t for x in nofactor_mul if isinstance(x, MulNode)])//(b//t)] # mypy wants the isinstance
        else:
          nofactor_term = [Node.sum(nofactor_mul)//b] if nofactor_mul else []
@@ -234,8 +234,8 @@ class Tensor:
  def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
  def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
  def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
-  def pad(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Pad.apply(self, arg=arg) if any([x != (0,0) for x in arg]) else self
-  def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any([x != (0,s) for x,s in zip(arg, self.shape)]) else self
+  def pad(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Pad.apply(self, arg=arg) if any(x != (0,0) for x in arg) else self
+  def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self

  # ***** movement hlops *****
@@ -314,7 +314,7 @@ class Tensor:

  def cat(self, *args, dim=0):
    dim = (dim + len(self.shape)) if dim < 0 else dim
-    assert all([len(y.shape) == len(self.shape) and all([y.shape[i] == s for i,s in enumerate(self.shape) if i != dim]) for y in args])
+    assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
    catargs = [self] + list(args)
    assert all(len(t.shape) != 0 for t in catargs), "zero-dimensional tensor cannot be concatenated"
    shape_cumsum = [0, *accumulate([y.shape[dim] for y in catargs])]
@@ -438,7 +438,7 @@ class Tensor:
    HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
    x, w = self, weight.reshape(groups, weight.shape[0]//groups, weight.shape[1], *weight.shape[2:]).permute(0,2,1,*trailing).flip(trailing)
    stride = make_pair(stride, len(HW))
-    if any([s>1 for s in stride]):
+    if any(s>1 for s in stride):
      x = x.reshape(*x.shape[:2], *flatten((k,1) for k in x.shape[2:]))
      x = x.pad(((0,0), (0,0), *flatten(((0,0),(0,s-1)) for s in stride)))
      x = x.reshape(*x.shape[:2], *[k*s for k,s in zip(x.shape[2::2], stride)])