[perf] Replace more list comprehension with * (#1106)

* [perf] Replace more list comprehension with *

* comeback

* final fix?

* blind me

* kill me

* ?

* rev

* [none]
This commit is contained in:
Mike Ovyan
2023-07-03 20:49:23 +03:00
committed by GitHub
parent 2071e53da8
commit 651d080594
2 changed files with 5 additions and 9 deletions

View File

@@ -83,7 +83,7 @@ class GroupNorm:
if self.weight is None or self.bias is None: return x
# elementwise_affine on channels
return x * self.weight.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)]) + self.bias.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)])
return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
class InstanceNorm:
def __init__(self, num_features:int, eps:float=1e-5, affine:bool=True):
@@ -94,7 +94,7 @@ class InstanceNorm:
def __call__(self, x:Tensor):
x = x.reshape(x.shape[0], self.num_features, -1).layernorm(eps=self.eps).reshape(x.shape)
if self.weight is None or self.bias is None: return x
return x * self.weight.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)]) + self.bias.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)])
return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
class LayerNorm:
def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):

View File

@@ -25,11 +25,7 @@ def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool: retur
@functools.lru_cache(maxsize=None)
def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
new_strides = []
for stride, shp in zip(strides, shape):
if shp != 1: new_strides.append(stride)
else: new_strides.append(0)
return tuple(new_strides)
return tuple(stride if shp != 1 else 0 for stride, shp in zip(strides, shape))
class View:
__slots__ = "shape", "strides", "offset", "mask", "shape_strides", "contiguous"
@@ -108,7 +104,7 @@ def _reshape(view: View, new_shape: Tuple[int, ...]) -> Tuple[View, bool]:
if mask:
for x,y in zip(shape, mask):
if x == 1 and y != (0, 1):
new_mask_tuple = tuple([(0,0) for _ in new_shape])
new_mask_tuple = ((0,0),) * len(new_shape)
break
else:
new_mask: List[Tuple[int, int]] = [y for x,y in zip(shape, mask) if x != 1]
@@ -162,7 +158,7 @@ class ShapeTracker:
if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
idx, valid = self.expr_idxs(idxs)
ret: List[Optional[int]] = [None for _ in self.views[-1].shape]
ret: List[Optional[int]] = [None] * len(self.views[-1].shape)
for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable):
ret[idxs.index(this_dim.a)] = this_dim.b