mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 18:35:12 -05:00
[perf] Replace more list comprehension with * (#1106)
* [perf] Replace more list comprehension with * * comeback * final fix? * blind me * kill me * ? * rev * [none]
This commit is contained in:
@@ -83,7 +83,7 @@ class GroupNorm:
|
||||
|
||||
if self.weight is None or self.bias is None: return x
|
||||
# elementwise_affine on channels
|
||||
return x * self.weight.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)]) + self.bias.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)])
|
||||
return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
|
||||
|
||||
class InstanceNorm:
|
||||
def __init__(self, num_features:int, eps:float=1e-5, affine:bool=True):
|
||||
@@ -94,7 +94,7 @@ class InstanceNorm:
|
||||
def __call__(self, x:Tensor):
  """Instance-normalize x, then optionally apply the learned affine transform.

  Collapses every dim after (batch, channel) into one, layernorms over that
  flattened axis (i.e. normalizes each sample/channel independently), and
  restores the original shape. If affine params are absent, returns the
  normalized tensor as-is.
  """
  x = x.reshape(x.shape[0], self.num_features, -1).layernorm(eps=self.eps).reshape(x.shape)
  if self.weight is None or self.bias is None: return x
  # broadcast weight/bias over every spatial dim: shape (1, C, 1, 1, ...)
  trailing = (1,) * (len(x.shape) - 2)
  return x * self.weight.reshape(1, -1, *trailing) + self.bias.reshape(1, -1, *trailing)
|
||||
|
||||
class LayerNorm:
|
||||
def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
|
||||
|
||||
@@ -25,11 +25,7 @@ def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool: retur
|
||||
|
||||
@functools.lru_cache(maxsize=None)
def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
  """Canonicalize strides: any size-1 dim gets stride 0.

  A dim of extent 1 is never actually strided over, so its stride is
  meaningless; forcing it to 0 gives views with identical iteration
  behavior identical stride tuples (which also makes the lru_cache and
  equality checks elsewhere more effective).
  """
  return tuple(0 if size == 1 else stride for size, stride in zip(shape, strides))
|
||||
|
||||
class View:
|
||||
__slots__ = "shape", "strides", "offset", "mask", "shape_strides", "contiguous"
|
||||
@@ -108,7 +104,7 @@ def _reshape(view: View, new_shape: Tuple[int, ...]) -> Tuple[View, bool]:
|
||||
if mask:
|
||||
for x,y in zip(shape, mask):
|
||||
if x == 1 and y != (0, 1):
|
||||
new_mask_tuple = tuple([(0,0) for _ in new_shape])
|
||||
new_mask_tuple = ((0,0),) * len(new_shape)
|
||||
break
|
||||
else:
|
||||
new_mask: List[Tuple[int, int]] = [y for x,y in zip(shape, mask) if x != 1]
|
||||
@@ -162,7 +158,7 @@ class ShapeTracker:
|
||||
if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
|
||||
idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
|
||||
idx, valid = self.expr_idxs(idxs)
|
||||
ret: List[Optional[int]] = [None for _ in self.views[-1].shape]
|
||||
ret: List[Optional[int]] = [None] * len(self.views[-1].shape)
|
||||
for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
|
||||
if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable):
|
||||
ret[idxs.index(this_dim.a)] = this_dim.b
|
||||
|
||||
Reference in New Issue
Block a user