typing fixup

This commit is contained in:
George Hotz
2023-02-27 09:52:04 -08:00
parent 9aaa7edd74
commit e74779f19d
8 changed files with 48 additions and 35 deletions

View File

@@ -1,3 +1,4 @@
from typing import Optional
from tinygrad.tensor import Tensor
class BatchNorm2d:
@@ -59,15 +60,16 @@ class Linear:
class GroupNorm:
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
self.num_groups, self.num_channels, self.eps, self.affine = num_groups, num_channels, eps, affine
self.weight, self.bias = (Tensor.ones(num_channels), Tensor.zeros(num_channels)) if affine else (None, None)
self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
self.weight : Optional[Tensor] = Tensor.ones(num_channels) if affine else None
self.bias : Optional[Tensor] = Tensor.zeros(num_channels) if affine else None
def __call__(self, x:Tensor):
# reshape for layernorm to work as group norm
# subtract mean and divide stddev
x = x.reshape(x.shape[0], self.num_groups, -1).layernorm(eps=self.eps).reshape(x.shape)
if not self.affine: return x
if self.weight is None or self.bias is None: return x
# elementwise_affine on channels
return x * self.weight.reshape(1, -1, 1, 1) + self.bias.reshape(1, -1, 1, 1)