add typing

This commit is contained in:
George Hotz
2023-02-28 10:54:46 -08:00
parent 922f96e527
commit 3c8da6bd03
8 changed files with 49 additions and 47 deletions

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union, Tuple
from tinygrad.tensor import Tensor
class BatchNorm2d:
@@ -59,7 +59,7 @@ class Linear:
return x.linear(self.weight.transpose(), self.bias)
class GroupNorm:
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
def __init__(self, num_groups:int, num_channels:int, eps:float=1e-5, affine:bool=True):
self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
self.weight : Optional[Tensor] = Tensor.ones(num_channels) if affine else None
self.bias : Optional[Tensor] = Tensor.zeros(num_channels) if affine else None
@@ -74,7 +74,7 @@ class GroupNorm:
return x * self.weight.reshape(1, -1, 1, 1) + self.bias.reshape(1, -1, 1, 1)
class LayerNorm:
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(normalized_shape))), eps, elementwise_affine
self.weight, self.bias = (Tensor.ones(*normalized_shape), Tensor.zeros(*normalized_shape)) if elementwise_affine else (None, None)