From 64ee757076db129f3b98a6694827db8d88d8ff1f Mon Sep 17 00:00:00 2001
From: Francis Lata
Date: Fri, 7 Mar 2025 16:33:03 -0800
Subject: [PATCH] initial work to support f16

---
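Background: with float16 as the default dtype, gradients of a small loss can
underflow to zero, which is why _train_step multiplies the loss by
loss_scaler before backward() and divides the gradients back out before
optim.step(). A minimal sketch of that scheme, assuming a generic tinygrad
model/optimizer pair (model, optim, and the cross-entropy loss below are
illustrative stand-ins, not this patch's API):

    from tinygrad import Tensor, dtypes

    # Static loss scale: large enough to keep fp16 gradients representable,
    # a no-op (1.0) when training in fp32.
    LOSS_SCALER = 256.0 if dtypes.default_float == dtypes.float16 else 1.0

    def train_step(model, optim, x, y):
      optim.zero_grad()
      loss = model(x).sparse_categorical_crossentropy(y) * LOSS_SCALER  # scale up
      loss.backward()                                  # grads carry the extra factor
      for t in optim.params: t.grad = t.grad / LOSS_SCALER  # unscale before the step
      optim.step()
      return loss / LOSS_SCALER                        # report the true (unscaled) loss

The float32 casts in the two head losses attack the same problem from the
other side: sigmoid_focal_loss and l1_loss reduce over many small terms, so
the head outputs are upcast and the reductions run in float32.
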
 examples/mlperf/model_train.py | 15 ++++++++++-----
 extra/models/retinanet.py      |  7 +++----
 2 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index bd594e1734..7d925c656e 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -391,13 +391,15 @@ def train_retinanet():
     return LambdaLR(optim, _lr_lambda)
 
   @TinyJit
-  def _train_step(model, optim, lr_scheduler, x, **kwargs):
+  def _train_step(model, optim, lr_scheduler, loss_scaler, x, **kwargs):
     optim.zero_grad()
 
     losses = model(normalize(x, GPUS), **kwargs)
-    loss = sum([l for l in losses.values()])
+    loss = (sum([l for l in losses.values()]) * loss_scaler)
 
     loss.backward()
+    for t in optim.params: t.grad = t.grad.contiguous() / loss_scaler
+
     optim.step()
     lr_scheduler.step()
 
@@ -410,7 +412,7 @@ def train_retinanet():
 
   # ** hyperparameters **
   config["seed"] = SEED = getenv("SEED", random.SystemRandom().randint(0, 2**32 - 1))
-  config["bs"] = BS = getenv("BS", 256)
+  config["bs"] = BS = getenv("BS", 12 * len(GPUS) if dtypes.default_float == dtypes.float16 else 12 * len(GPUS))  # TODO: update float16 to use larger BS
   config["eval_bs"] = EVAL_BS = getenv("EVAL_BS", BS)
   config["epochs"] = EPOCHS = getenv("EPOCHS", 4)
   config["train_beam"] = TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
@@ -418,6 +420,8 @@ def train_retinanet():
   config["lr"] = lr = getenv("LR", 0.0001 * (BS / 256))
   config["lr_warmup_epochs"] = lr_warmup_epochs = getenv("LR_WARMUP_EPOCHS", 1)
   config["lr_warmup_factor"] = lr_warmup_factor = getenv("LR_WARMUP_FACTOR", 1e-3)
+  config["loss_scaler"] = loss_scaler = getenv("LOSS_SCALER", 256.0 if dtypes.default_float == dtypes.float16 else 1.0)
+  config["default_float"] = dtypes.default_float.name
 
   if SEED: Tensor.manual_seed(SEED)
 
@@ -426,7 +430,8 @@ def train_retinanet():
 
   # ** model setup **
   backbone = resnet.ResNeXt50_32X4D(num_classes=None)
-  backbone.load_from_pretrained()
+  # TODO: Figure out if casting to float16 should be done
+  # backbone.load_from_pretrained()
   _freeze_backbone_layers(backbone, 3)
 
   model = RetinaNet(backbone, num_classes=NUM_CLASSES)
@@ -484,7 +489,7 @@ def train_retinanet():
       GlobalCounters.reset()
       x, y_bboxes, y_labels, matches, anchors, proc = proc
 
-      loss, losses = _train_step(model, optim, lr_scheduler, x, labels=y_labels, matches=matches, anchors=anchors, bboxes=y_bboxes)
+      loss, losses = _train_step(model, optim, lr_scheduler, loss_scaler, x, labels=y_labels, matches=matches, anchors=anchors, bboxes=y_bboxes)
 
       pt = time.perf_counter()
 
diff --git a/extra/models/retinanet.py b/extra/models/retinanet.py
index 5574e7d077..0f5993412a 100644
--- a/extra/models/retinanet.py
+++ b/extra/models/retinanet.py
@@ -1,7 +1,6 @@
 import math
-from tinygrad import Tensor
+from tinygrad import Tensor, dtypes
 from tinygrad.helpers import flatten, get_child
-import tinygrad.nn as nn
 from examples.mlperf.helpers import generate_anchors, BoxCoder
 from examples.mlperf.initializers import Conv2dNormal, Conv2dKaimingUniform
 from examples.mlperf.losses import sigmoid_focal_loss, l1_loss
@@ -149,7 +148,7 @@ class ClassificationHead:
   def _compute_loss(self, x:Tensor, labels:Tensor, matches:Tensor) -> Tensor:
     labels = ((labels + 1) * (fg_idxs := matches >= 0) - 1).one_hot(num_classes=x.shape[-1])
     valid_idxs = (matches != -2).reshape(matches.shape[0], -1, 1)
-    loss = valid_idxs.where(sigmoid_focal_loss(x, labels), 0).sum(-1).sum(-1)
+    loss = valid_idxs.where(sigmoid_focal_loss(x.cast(dtypes.float32), labels), 0).sum(-1).sum(-1)
     loss = (loss / fg_idxs.sum(-1)).sum() / matches.shape[0]
     return loss
 
@@ -174,7 +173,7 @@ class RegressionHead:
   def _compute_loss(self, x:Tensor, bboxes:Tensor, matches:Tensor, anchors:Tensor) -> Tensor:
     mask = (fg_idxs := matches >= 0).reshape(matches.shape[0], -1, 1)
-    x = x * mask
+    x = x.cast(dtypes.float32) * mask
     tgt = self.box_coder.encode(bboxes, anchors) * mask
     loss = l1_loss(x, tgt).sum(-1).sum(-1)
     loss = (loss / fg_idxs.sum(-1)).sum() / matches.shape[0]
     return loss
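
For reference, an f16 run of this script would presumably be launched along
these lines (assuming tinygrad's DEFAULT_FLOAT environment variable and the
script's existing MODEL/GPUS knobs; the values are illustrative):

    DEFAULT_FLOAT=HALF MODEL=retinanet GPUS=6 python3 examples/mlperf/model_train.py

BS then defaults to 12 per GPU and loss_scaler to 256.0, both still
overridable through BS= and LOSS_SCALER=.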