From 651d080594771df43504953bf3ccf8aabe2e3d83 Mon Sep 17 00:00:00 2001 From: Mike Ovyan Date: Mon, 3 Jul 2023 20:49:23 +0300 Subject: [PATCH] [perf] Replace more list comprehension with * (#1106) * [perf] Replace more list comprehension with * * comeback * final fix? * blind me * kill me * ? * rev * [none] --- tinygrad/nn/__init__.py | 4 ++-- tinygrad/shape/shapetracker.py | 10 +++------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py index 65dbd2b9c1..67cc62c2ca 100644 --- a/tinygrad/nn/__init__.py +++ b/tinygrad/nn/__init__.py @@ -83,7 +83,7 @@ class GroupNorm: if self.weight is None or self.bias is None: return x # elementwise_affine on channels - return x * self.weight.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)]) + self.bias.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)]) + return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2)) class InstanceNorm: def __init__(self, num_features:int, eps:float=1e-5, affine:bool=True): @@ -94,7 +94,7 @@ class InstanceNorm: def __call__(self, x:Tensor): x = x.reshape(x.shape[0], self.num_features, -1).layernorm(eps=self.eps).reshape(x.shape) if self.weight is None or self.bias is None: return x - return x * self.weight.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)]) + self.bias.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)]) + return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2)) class LayerNorm: def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True): diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index b02c150e2a..de4ba019de 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -25,11 +25,7 @@ def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool: retur @functools.lru_cache(maxsize=None) def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]: - new_strides = [] - for stride, shp in zip(strides, shape): - if shp != 1: new_strides.append(stride) - else: new_strides.append(0) - return tuple(new_strides) + return tuple(stride if shp != 1 else 0 for stride, shp in zip(strides, shape)) class View: __slots__ = "shape", "strides", "offset", "mask", "shape_strides", "contiguous" @@ -108,7 +104,7 @@ def _reshape(view: View, new_shape: Tuple[int, ...]) -> Tuple[View, bool]: if mask: for x,y in zip(shape, mask): if x == 1 and y != (0, 1): - new_mask_tuple = tuple([(0,0) for _ in new_shape]) + new_mask_tuple = ((0,0),) * len(new_shape) break else: new_mask: List[Tuple[int, int]] = [y for x,y in zip(shape, mask) if x != 1] @@ -162,7 +158,7 @@ class ShapeTracker: if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)] idx, valid = self.expr_idxs(idxs) - ret: List[Optional[int]] = [None for _ in self.views[-1].shape] + ret: List[Optional[int]] = [None] * len(self.views[-1].shape) for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]): if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable): ret[idxs.index(this_dim.a)] = this_dim.b