From e5bc0c0485057df67043f6194976409e6bb02ad4 Mon Sep 17 00:00:00 2001
From: Francis Lata
Date: Tue, 10 Dec 2024 16:33:16 +0000
Subject: [PATCH] start some work on classification loss

---
 examples/mlperf/losses.py      |  6 +++++-
 examples/mlperf/model_train.py |  9 +++++----
 extra/models/retinanet.py      | 41 +++++++++++++++++++++++++++++++++++++++--------
 3 files changed, 43 insertions(+), 13 deletions(-)

diff --git a/examples/mlperf/losses.py b/examples/mlperf/losses.py
index 38062830c4..ce2e8d9d79 100644
--- a/examples/mlperf/losses.py
+++ b/examples/mlperf/losses.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 from examples.mlperf.metrics import dice_score
 from tinygrad import Tensor
 
@@ -6,9 +8,10 @@ def dice_ce_loss(pred, tgt):
   dice = (1.0 - dice_score(pred, tgt, argmax=False, to_one_hot_x=False)).mean()
   return (dice + ce) / 2
 
-def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float = 0.25, gamma:float = 2, reduction:str = "none"):
+def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, mask:Optional[Tensor]=None, alpha:float = 0.25, gamma:float = 2, reduction:str = "none") -> Tensor:
   assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
   p, ce_loss = pred.sigmoid(), pred.binary_crossentropy_logits(tgt, reduction="none")
+  if mask is not None: p = p * mask
   p_t = p * tgt + (1 - p) * (1 - tgt)
   loss = ce_loss * ((1 - p_t) ** gamma)
 
@@ -16,6 +19,7 @@ def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float = 0.25, gamma:float
     alpha_t = alpha * tgt + (1 - alpha) * (1 - tgt)
     loss *= alpha_t
 
+  if mask is not None: loss = loss * mask
   if reduction == "mean": loss = loss.mean()
   elif reduction == "sum": loss = loss.sum()
   return loss
diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index 9a12402727..7e5b3b9c43 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -370,7 +370,7 @@ def train_retinanet():
 
   def _data_get(it):
     x, y_boxes, y_labels, matches, cookie = next(it)
-    return x.shard(GPUS, axis=0).realize(), y_boxes, y_labels, matches.shard(GPUS, axis=0), cookie
+    return x.shard(GPUS, axis=0).realize(), y_boxes, y_labels.shard(GPUS, axis=0), matches.shard(GPUS, axis=0), cookie
 
   def _create_lr_scheduler(optim, start_iter, warmup_iters, warmup_factor):
     # TODO: refactor this a bit more so we don't have to recreate it, unlike what MLPerf script is doing
@@ -381,10 +381,11 @@ def train_retinanet():
       return warmup_factor * (1 - alpha) + alpha
     return LambdaLR(optim, _lr_lambda)
 
-  def _train_step(model, optim, lr_scheduler, x, matches):
+  @Tensor.train()
+  def _train_step(model, optim, lr_scheduler, x, y, matches):
     optim.zero_grad()
 
-    y_hat = model(normalize(x, GPUS))
+    y_hat = model(normalize(x, GPUS), y=y, matches=matches)
 
     # ** hyperparameters **
     # using https://github.com/mlcommons/logging/blob/96d0acee011ba97702532dcc39e6eeaa99ebef24/mlperf_logging/rcp_checker/training_4.1.0/rcps_ssd.json#L3
@@ -435,7 +436,7 @@ def train_retinanet():
 
     while proc is not None:
       x, y_boxes, y_labels, matches, proc = proc
-      _train_step(model, optim, lr_scheduler, x, matches)  # TODO: enable once full model has been integrated
+      _train_step(model, optim, lr_scheduler, x, y_labels, matches)  # TODO: enable once full model has been integrated
 
       if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = []  # free previous cookies after gpu work has been enqueued
       try:
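tinygrad raises a TypeError when a Tensor is used in boolean context, which is why the mask guards in `sigmoid_focal_loss` check `is not None` instead of relying on truthiness. A minimal smoke test for the masked loss is sketched below; the shapes and the anchor-validity mask are made up for illustration and are not part of this patch:

```python
# Standalone sketch: exercise the patched sigmoid_focal_loss with a validity mask.
# Shapes are hypothetical (batch=2, anchors=120, classes=264).
from tinygrad import Tensor
from examples.mlperf.losses import sigmoid_focal_loss

logits = Tensor.randn(2, 120, 264)                   # raw per-anchor class logits
targets = (Tensor.rand(2, 120, 264) > 0.99).float()  # sparse binary targets
valid = (Tensor.rand(2, 120, 1) > 0.1).float()       # 1 = keep anchor, 0 = ignore; broadcasts over classes

loss = sigmoid_focal_loss(logits, targets, mask=valid, reduction="sum")
print(loss.item())

# the mask multiplies the per-element loss before reduction,
# so a fully zero mask must collapse the summed loss to zero
zero = sigmoid_focal_loss(logits, targets, mask=Tensor.zeros(2, 120, 1), reduction="sum")
print(zero.item())  # ~0.0
```
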
diff --git a/extra/models/retinanet.py b/extra/models/retinanet.py
index 8e188787dc..09c6443de9 100644
--- a/extra/models/retinanet.py
+++ b/extra/models/retinanet.py
@@ -1,8 +1,11 @@
+from typing import Optional
+
 import math
 from tinygrad import Tensor, dtypes
 from tinygrad.helpers import flatten, get_child
 import tinygrad.nn as nn
 from examples.mlperf.initializers import Conv2dNormal, Conv2dKaimingUniform
+from examples.mlperf.losses import sigmoid_focal_loss
 from extra.models.helpers import meshgrid, nms
 from extra.models.resnet import ResNet
 import numpy as np
@@ -47,10 +50,11 @@ class RetinaNet:
     self.head = RetinaHead(self.backbone.out_channels, num_anchors=num_anchors, num_classes=num_classes)
     self.anchor_gen = lambda input_size: generate_anchors(input_size, self.backbone.compute_grid_sizes(input_size), scales, aspect_ratios)
 
-  def __call__(self, x):
-    return self.forward(x)
-  def forward(self, x):
-    return self.head(self.backbone(x))
+  def __call__(self, x:Tensor, y:Optional[Tensor] = None, matches:Optional[Tensor] = None):
+    return self.forward(x, y=y, matches=matches)
+
+  def forward(self, x:Tensor, y:Optional[Tensor] = None, matches:Optional[Tensor] = None):
+    return self.head(self.backbone(x), y=y, matches=matches)
 
   def load_from_pretrained(self):
     model_urls = {
@@ -136,9 +140,27 @@ class ClassificationHead:
     self.num_classes = num_classes
     self.conv = flatten([(Conv2dNormal(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)])
     self.cls_logits = Conv2dNormal(in_channels, num_anchors * num_classes, kernel_size=3, padding=1, prior_prob=prior_prob)
+
+  def __call__(self, x:Tensor, y:Optional[Tensor] = None, matches:Optional[Tensor] = None):
     out = [self.cls_logits(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, self.num_classes) for feat in x]
-  def __call__(self, x):
-    return out[0].cat(*out[1:], dim=1).sigmoid()
+    # keep logits here: sigmoid_focal_loss applies its own sigmoid, so only squash on the inference path
+    out = out[0].cat(*out[1:], dim=1)
+
+    if Tensor.training:
+      assert y is not None and matches is not None, "y and matches should be passed in when training"
+      return self._compute_loss(out, y, matches)
+    return out.sigmoid()
+
+  def _compute_loss(self, x:Tensor, y:Tensor, matches:Tensor) -> Tensor:
+    # matches encodes 1 for foreground anchors, 0 for background and -2 for ignored ones;
+    # background and ignored anchors map to negative labels, which one_hot turns into all-zero rows
+    y = ((y + 1) * matches - 1).one_hot(num_classes=x.shape[-1])
+
+    # mask out the anchors that should be ignored (matches == -2)
+    valid_idxs = (matches != -2).reshape(matches.shape[0], -1, 1)
+
+    # TODO: normalize by the number of foreground anchors like the reference implementation does
+    return sigmoid_focal_loss(x, y, mask=valid_idxs, reduction="sum")
 
 class RegressionHead:
   def __init__(self, in_channels, num_anchors):
@@ -152,8 +174,11 @@ class RetinaHead:
   def __init__(self, in_channels, num_anchors, num_classes):
     self.classification_head = ClassificationHead(in_channels, num_anchors, num_classes)
     self.regression_head = RegressionHead(in_channels, num_anchors)
-  def __call__(self, x):
-    pred_bbox, pred_class = self.regression_head(x), self.classification_head(x)
+
+  def __call__(self, x:Tensor, y:Optional[Tensor] = None, matches:Optional[Tensor] = None):
+    pred_bbox, pred_class = self.regression_head(x), self.classification_head(x, y=y, matches=matches)
+    # during training the classification head returns its loss; the regression loss is still TODO
+    if Tensor.training: return pred_class
     out = pred_bbox.cat(pred_class, dim=-1)
     return out
 
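The target construction in `_compute_loss` relies on `matches` encoding foreground anchors as 1, background as 0, and ignored anchors as -2; `(y + 1) * matches - 1` then maps background and ignored anchors to negative labels, which `one_hot` turns into all-zero rows. A worked example on dummy values, assuming that encoding:

```python
# Worked example of the target construction in ClassificationHead._compute_loss
# (dummy values; assumes matches uses 1=foreground, 0=background, -2=ignore,
# which is what the (y + 1) * matches - 1 arithmetic implies).
from tinygrad import Tensor

num_classes = 4
y = Tensor([[2, 0, 3]])         # per-anchor class labels, (batch=1, anchors=3)
matches = Tensor([[1, 0, -2]])  # foreground, background, ignored

tgt = (y + 1) * matches - 1     # -> [[2, -1, -9]]
one_hot = tgt.one_hot(num_classes=num_classes)
print(one_hot.numpy())
# foreground anchor: one-hot row for class 2
# background / ignored anchors: all-zero rows (negative labels match no class)

valid = (matches != -2).reshape(1, -1, 1)
print(valid.numpy())  # [[[True], [True], [False]]] -- only the ignored anchor is masked out
```
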
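The summed focal loss is not normalized yet (hence the TODO above): the reference implementation divides each image's classification loss by the number of foreground anchors, clamped to at least one, before averaging. A sketch of that normalization on top of the patched loss; `normalized_cls_loss` is a hypothetical helper, not part of the patch, and it assumes the same `matches` encoding as above:

```python
# Hypothetical helper showing foreground normalization on top of the patched loss,
# mirroring the reference implementation's sum / max(1, num_foreground).
from tinygrad import Tensor
from examples.mlperf.losses import sigmoid_focal_loss

def normalized_cls_loss(logits:Tensor, y_one_hot:Tensor, matches:Tensor) -> Tensor:
  valid = (matches != -2).reshape(matches.shape[0], -1, 1)  # drop ignored anchors
  num_foreground = (matches == 1).sum().float().maximum(1)  # clamp to avoid division by zero
  return sigmoid_focal_loss(logits, y_one_hot, mask=valid, reduction="sum") / num_foreground
```
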