diff --git a/examples/mlperf/initializers.py b/examples/mlperf/initializers.py index dec1550d7b..96b494fd92 100644 --- a/examples/mlperf/initializers.py +++ b/examples/mlperf/initializers.py @@ -68,27 +68,42 @@ class LayerNormBert: return (xn * self.weight.cast(dtypes.default_float) + self.bias.cast(dtypes.default_float)) class FrozenBatchNorm2d(nn.BatchNorm2d): - def __init__(self, num_features:int): + def __init__(self, num_features:int, affine:bool=True, track_running_stats:bool=True): super().__init__(num_features) - self.weight.requires_grad = False - self.bias.requires_grad = False + + if affine: + self.weight = self.weight.cast(dtypes.float32) + self.bias = self.weight.cast(dtypes.float32) + + self.weight.requires_grad = False + self.bias.requires_grad = False + + if track_running_stats: self.running_mean, self.running_var = self.running_mean.cast(dtypes.float32), self.running_var.cast(dtypes.float32) class Conv2dNormal(nn.Conv2d): def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...], stride:int=1, padding:int|tuple[int, ...]|str=0, dilation:int=1, groups:int=1, bias:bool=True, prior_prob:float|None=None): super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) - self.weight = Tensor.normal(*self.weight.shape, std=0.01) + self.weight = Tensor.normal(*self.weight.shape, std=0.01, dtype=dtypes.float32) if bias: if prior_prob: - prior_prob = Tensor(prior_prob, device=self.bias.device, dtype=self.bias.dtype).expand(*self.bias.shape) + prior_prob = Tensor(prior_prob, device=self.bias.device, dtype=dtypes.float32).expand(*self.bias.shape) self.bias = -(((1 - prior_prob) / prior_prob).log()) - else: self.bias = Tensor.zeros_like(self.bias) + else: self.bias = Tensor.zeros_like(self.bias, dtype=dtypes.float32) + + def __call__(self, x:Tensor) -> Tensor: + return x.conv2d(self.weight.cast(dtypes.default_float), self.bias.cast(dtypes.default_float) if self.bias is not None else None, + groups=self.groups, stride=self.stride, padding=self.padding) class Conv2dKaimingUniform(nn.Conv2d): def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...], stride:int=1, padding:int|tuple[int, ...]|str=0, dilation:int=1, groups:int=1, bias:bool=True): super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) - self.weight = Tensor.kaiming_uniform(*self.weight.shape, a=1) - if bias: self.bias = Tensor.zeros_like(self.bias) + self.weight = Tensor.kaiming_uniform(*self.weight.shape, a=1, dtype=dtypes.float32) + if bias: self.bias = Tensor.zeros_like(self.bias, dtype=dtypes.float32) + + def __call__(self, x:Tensor) -> Tensor: + return x.conv2d(self.weight.cast(dtypes.default_float), self.bias.cast(dtypes.default_float) if self.bias is not None else None, + groups=self.groups, stride=self.stride, padding=self.padding) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 7d16b46fe0..f9023bd6e7 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -345,7 +345,7 @@ def train_resnet(): def train_retinanet(): from contextlib import redirect_stdout from examples.mlperf.dataloader import batch_load_retinanet - from examples.mlperf.initializers import FrozenBatchNorm2d, Conv2dNormal, Conv2dKaimingUniform + from examples.mlperf.initializers import FrozenBatchNorm2d, Conv2dNormal, Conv2dKaimingUniform, Conv2dHeNormal from extra.datasets.openimages import MLPERF_CLASSES, BASEDIR, download_dataset, normalize from extra.models import resnet from extra.lr_scheduler import LambdaLR @@ -427,14 +427,15 @@ def train_retinanet(): # ** model initializers ** resnet.BatchNorm = FrozenBatchNorm2d + resnet.Conv2d = Conv2dHeNormal # NOTE: overriding to support float32 weights when training float16 + retinanet.ConvHead = Conv2dNormal retinanet.ConvClassificationHeadLogits = functools.partial(Conv2dNormal, prior_prob=0.01) retinanet.ConvFPN = Conv2dKaimingUniform # ** model setup ** backbone = resnet.ResNeXt50_32X4D(num_classes=None) - # TODO: Figure out if casting to float16 should be done - # backbone.load_from_pretrained() + backbone.load_from_pretrained() _freeze_backbone_layers(backbone, 3) model = retinanet.RetinaNet(backbone, num_classes=NUM_CLASSES)