mirror of https://github.com/tinygrad/tinygrad.git
typing fixup
@@ -1,3 +1,4 @@
+from typing import Optional
 from tinygrad.tensor import Tensor
 
 class BatchNorm2d:
@@ -59,15 +60,16 @@ class Linear:
 
 class GroupNorm:
   def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
-    self.num_groups, self.num_channels, self.eps, self.affine = num_groups, num_channels, eps, affine
-    self.weight, self.bias = (Tensor.ones(num_channels), Tensor.zeros(num_channels)) if affine else (None, None)
+    self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
+    self.weight : Optional[Tensor] = Tensor.ones(num_channels) if affine else None
+    self.bias : Optional[Tensor] = Tensor.zeros(num_channels) if affine else None
 
   def __call__(self, x:Tensor):
     # reshape for layernorm to work as group norm
     # subtract mean and divide stddev
     x = x.reshape(x.shape[0], self.num_groups, -1).layernorm(eps=self.eps).reshape(x.shape)
 
-    if not self.affine: return x
+    if self.weight is None or self.bias is None: return x
     # elementwise_affine on channels
     return x * self.weight.reshape(1, -1, 1, 1) + self.bias.reshape(1, -1, 1, 1)
 
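What the fixup buys: `weight` and `bias` are now annotated as `Optional[Tensor]`, and the `is None` guard in `__call__` replaces the old `self.affine` flag. After that guard, a type checker such as mypy narrows both attributes to `Tensor`, so the trailing `.reshape(...)` calls type-check; with the boolean flag, the checker could not prove the attributes were non-None at the call site.

A minimal usage sketch of the class above. The `tinygrad.nn` import path and the concrete shapes are assumptions for illustration, not part of this diff:

    from tinygrad.tensor import Tensor
    from tinygrad.nn import GroupNorm  # assumed module for the class above

    gn = GroupNorm(num_groups=2, num_channels=4)  # affine=True, so weight/bias are Tensors
    x = Tensor.ones(2, 4, 8, 8)                   # NCHW input; the shape is an arbitrary example
    y = gn(x)                                     # normalize per group, then scale/shift per channel
    assert y.shape == (2, 4, 8, 8)                # group norm preserves the input shape

With `affine=False`, both attributes stay `None` and `__call__` returns the normalized tensor early; the `is None` check is what makes that early return visible to the type checker.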