mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
add typing
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, Union, Tuple
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
class BatchNorm2d:
|
||||
@@ -59,7 +59,7 @@ class Linear:
|
||||
return x.linear(self.weight.transpose(), self.bias)
|
||||
|
||||
class GroupNorm:
|
||||
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
|
||||
def __init__(self, num_groups:int, num_channels:int, eps:float=1e-5, affine:bool=True):
|
||||
self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
|
||||
self.weight : Optional[Tensor] = Tensor.ones(num_channels) if affine else None
|
||||
self.bias : Optional[Tensor] = Tensor.zeros(num_channels) if affine else None
|
||||
@@ -74,7 +74,7 @@ class GroupNorm:
|
||||
return x * self.weight.reshape(1, -1, 1, 1) + self.bias.reshape(1, -1, 1, 1)
|
||||
|
||||
class LayerNorm:
|
||||
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
|
||||
def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
|
||||
normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
|
||||
self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(normalized_shape))), eps, elementwise_affine
|
||||
self.weight, self.bias = (Tensor.ones(*normalized_shape), Tensor.zeros(*normalized_shape)) if elementwise_affine else (None, None)
|
||||
|
||||
Reference in New Issue
Block a user