update layers to support float32 weight loading + float16 training
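The pattern throughout the diff below: parameters are created and stored in float32 (so pretrained float32 checkpoints load without casting, and optimizer updates don't underflow), while each forward pass casts them to dtypes.default_float, which is float16 during half-precision training. A minimal sketch of the idea, using a hypothetical LinearFP32Master layer that is not part of this commit:

from tinygrad import Tensor, dtypes

class LinearFP32Master:  # hypothetical layer, for illustration only
  def __init__(self, in_features:int, out_features:int):
    # master weights live in float32 regardless of the compute dtype
    self.weight = Tensor.normal(out_features, in_features, std=0.01, dtype=dtypes.float32)

  def __call__(self, x:Tensor) -> Tensor:
    # cast to the active compute dtype (float16 during mixed-precision training)
    return x.linear(self.weight.cast(dtypes.default_float).transpose())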
@@ -68,27 +68,42 @@ class LayerNormBert:
     return (xn * self.weight.cast(dtypes.default_float) + self.bias.cast(dtypes.default_float))
 
 class FrozenBatchNorm2d(nn.BatchNorm2d):
-  def __init__(self, num_features:int):
+  def __init__(self, num_features:int, affine:bool=True, track_running_stats:bool=True):
     super().__init__(num_features)
-    self.weight.requires_grad = False
-    self.bias.requires_grad = False
+
+    if affine:
+      self.weight = self.weight.cast(dtypes.float32)
+      self.bias = self.bias.cast(dtypes.float32)
+
+      self.weight.requires_grad = False
+      self.bias.requires_grad = False
+
+    if track_running_stats: self.running_mean, self.running_var = self.running_mean.cast(dtypes.float32), self.running_var.cast(dtypes.float32)
 
 class Conv2dNormal(nn.Conv2d):
   def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...],
                stride:int=1, padding:int|tuple[int, ...]|str=0, dilation:int=1, groups:int=1,
                bias:bool=True, prior_prob:float|None=None):
     super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
-    self.weight = Tensor.normal(*self.weight.shape, std=0.01)
+    self.weight = Tensor.normal(*self.weight.shape, std=0.01, dtype=dtypes.float32)
     if bias:
       if prior_prob:
-        prior_prob = Tensor(prior_prob, device=self.bias.device, dtype=self.bias.dtype).expand(*self.bias.shape)
+        prior_prob = Tensor(prior_prob, device=self.bias.device, dtype=dtypes.float32).expand(*self.bias.shape)
         self.bias = -(((1 - prior_prob) / prior_prob).log())
-      else: self.bias = Tensor.zeros_like(self.bias)
+      else: self.bias = Tensor.zeros_like(self.bias, dtype=dtypes.float32)
 
   def __call__(self, x:Tensor) -> Tensor:
     return x.conv2d(self.weight.cast(dtypes.default_float), self.bias.cast(dtypes.default_float) if self.bias is not None else None,
                     groups=self.groups, stride=self.stride, padding=self.padding)
 
 class Conv2dKaimingUniform(nn.Conv2d):
   def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...],
                stride:int=1, padding:int|tuple[int, ...]|str=0, dilation:int=1, groups:int=1,
                bias:bool=True):
     super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
-    self.weight = Tensor.kaiming_uniform(*self.weight.shape, a=1)
-    if bias: self.bias = Tensor.zeros_like(self.bias)
+    self.weight = Tensor.kaiming_uniform(*self.weight.shape, a=1, dtype=dtypes.float32)
+    if bias: self.bias = Tensor.zeros_like(self.bias, dtype=dtypes.float32)
 
   def __call__(self, x:Tensor) -> Tensor:
     return x.conv2d(self.weight.cast(dtypes.default_float), self.bias.cast(dtypes.default_float) if self.bias is not None else None,
                     groups=self.groups, stride=self.stride, padding=self.padding)
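A hedged usage sketch of the initializers above (assumed setup, not part of the commit): with float16 set as tinygrad's default float, parameters stay float32 while activations come out in half precision.

from tinygrad import Tensor, dtypes
from examples.mlperf.initializers import Conv2dNormal

dtypes.default_float = dtypes.half              # float16 training run
conv = Conv2dNormal(3, 8, 3, padding=1)
assert conv.weight.dtype == dtypes.float32      # master weights stay float32
out = conv(Tensor.rand(1, 3, 32, 32))
assert out.dtype == dtypes.half                 # compute happens in float16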
@@ -345,7 +345,7 @@ def train_resnet():
 def train_retinanet():
   from contextlib import redirect_stdout
   from examples.mlperf.dataloader import batch_load_retinanet
-  from examples.mlperf.initializers import FrozenBatchNorm2d, Conv2dNormal, Conv2dKaimingUniform
+  from examples.mlperf.initializers import FrozenBatchNorm2d, Conv2dNormal, Conv2dKaimingUniform, Conv2dHeNormal
   from extra.datasets.openimages import MLPERF_CLASSES, BASEDIR, download_dataset, normalize
   from extra.models import resnet
   from extra.lr_scheduler import LambdaLR
@@ -427,14 +427,15 @@ def train_retinanet():
 
   # ** model initializers **
   resnet.BatchNorm = FrozenBatchNorm2d
+  resnet.Conv2d = Conv2dHeNormal  # NOTE: overriding to support float32 weights when training float16
 
   retinanet.ConvHead = Conv2dNormal
   retinanet.ConvClassificationHeadLogits = functools.partial(Conv2dNormal, prior_prob=0.01)
   retinanet.ConvFPN = Conv2dKaimingUniform
 
   # ** model setup **
   backbone = resnet.ResNeXt50_32X4D(num_classes=None)
-  # TODO: Figure out if casting to float16 should be done
-  # backbone.load_from_pretrained()
+  backbone.load_from_pretrained()
   _freeze_backbone_layers(backbone, 3)
 
   model = retinanet.RetinaNet(backbone, num_classes=NUM_CLASSES)
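Why the float32 masters matter at train time: with a small learning rate, per-step weight updates can be smaller than float16 can represent, so updates applied to half-precision weights silently vanish. A rough illustrative sketch of one mixed-precision step under these assumptions (not from this diff):

from tinygrad import Tensor, dtypes
from tinygrad.nn.optim import SGD

Tensor.training = True
w = Tensor.normal(8, 8, std=0.01, dtype=dtypes.float32)        # float32 master weight
opt = SGD([w], lr=1e-4)
x = Tensor.rand(4, 8, dtype=dtypes.half)
loss = x.matmul(w.cast(dtypes.half)).float().square().mean()   # forward in fp16
opt.zero_grad()
loss.backward()   # grads flow back through the cast
opt.step()        # update lands on the float32 master copy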