From 8cbe4009fcd46e7a0ec6ef17b470e1554a295993 Mon Sep 17 00:00:00 2001
From: Francis Lata
Date: Fri, 21 Mar 2025 15:52:54 -0400
Subject: [PATCH] RetinaNet losses (#9536)

* add sigmoid_focal_loss and l1_loss

* update ref implementation comment
---
 examples/mlperf/losses.py                    | 23 +++++++++
 test/external/external_test_losses.py        | 34 ++++++++++---
 test/external/mlperf_retinanet/focal_loss.py | 51 ++++++++++++++++++++
 3 files changed, 101 insertions(+), 7 deletions(-)
 create mode 100644 test/external/mlperf_retinanet/focal_loss.py

diff --git a/examples/mlperf/losses.py b/examples/mlperf/losses.py
index a7025a0eb5..d20aac063f 100644
--- a/examples/mlperf/losses.py
+++ b/examples/mlperf/losses.py
@@ -1,6 +1,29 @@
 from examples.mlperf.metrics import dice_score
+from tinygrad import Tensor
 
 def dice_ce_loss(pred, tgt):
   ce = pred.permute(0, 2, 3, 4, 1).sparse_categorical_crossentropy(tgt.squeeze(1))
   dice = (1.0 - dice_score(pred, tgt, argmax=False, to_one_hot_x=False)).mean()
   return (dice + ce) / 2
+
+def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float=0.25, gamma:float=2.0, reduction:str="none") -> Tensor:
+  assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
+  p, ce_loss = pred.sigmoid(), pred.binary_crossentropy_logits(tgt, reduction="none")
+  p_t = p * tgt + (1 - p) * (1 - tgt)
+  loss = ce_loss * ((1 - p_t) ** gamma)
+
+  if alpha >= 0:
+    alpha_t = alpha * tgt + (1 - alpha) * (1 - tgt)
+    loss = loss * alpha_t
+
+  if reduction == "mean": loss = loss.mean()
+  elif reduction == "sum": loss = loss.sum()
+  return loss
+
+def l1_loss(pred:Tensor, tgt:Tensor, reduction:str="none") -> Tensor:
+  assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
+  loss = (pred - tgt).abs()
+
+  if reduction == "mean": loss = loss.mean()
+  elif reduction == "sum": loss = loss.sum()
+  return loss
\ No newline at end of file
diff --git a/test/external/external_test_losses.py b/test/external/external_test_losses.py
index 6775f595fd..cd8efde21c 100644
--- a/test/external/external_test_losses.py
+++ b/test/external/external_test_losses.py
@@ -1,20 +1,40 @@
 from tinygrad import Tensor
+from test.external.mlperf_retinanet.focal_loss import sigmoid_focal_loss as ref_sigmoid_focal_loss
 from test.external.mlperf_unet3d.dice import DiceCELoss
-from examples.mlperf.losses import dice_ce_loss
+from examples.mlperf.losses import dice_ce_loss, sigmoid_focal_loss, l1_loss
 
 import numpy as np
 import torch
 import unittest
 
 class ExternalTestLosses(unittest.TestCase):
-  def _test_losses(self, tinygrad_metrics, orig_metrics, pred, label):
-    tinygrad_metrics_res = tinygrad_metrics(Tensor(pred), Tensor(label)).numpy()
-    orig_metrics_res = orig_metrics(torch.from_numpy(pred), torch.from_numpy(label)).numpy()
-    np.testing.assert_allclose(tinygrad_metrics_res, orig_metrics_res, atol=1e-4)
+  def setUp(self):
+    np.random.seed(42)
 
-  def test_dice_ce(self):
+  def _assert_loss(self, pred, tgt, tinygrad_metrics, ref_metrics, rtol=1e-07, atol=0, **kwargs):
+    tinygrad_metrics_res = tinygrad_metrics(Tensor(pred), Tensor(tgt), **kwargs)
+    ref_metrics_res = ref_metrics(torch.from_numpy(pred), torch.from_numpy(tgt), **kwargs)
+    np.testing.assert_allclose(tinygrad_metrics_res.numpy(), ref_metrics_res.numpy(), rtol=rtol, atol=atol)
+
+  def test_dice_ce_loss(self):
     pred, label = np.random.rand(1, 3, 128, 128, 128).astype(np.float32), np.ones((1, 1, 128, 128, 128)).astype(np.uint8)
-    self._test_losses(dice_ce_loss, DiceCELoss(True, True, "NCDHW", False), pred, label)
+    tinygrad_metrics_res, ref_metrics_res = dice_ce_loss, DiceCELoss(True, True, "NCDHW", False)
+    self._assert_loss(pred, label, tinygrad_metrics_res, ref_metrics_res, atol=1e-4)
+
+  def test_sigmoid_focal_loss(self):
+    def _apply_logit(p): return np.log(p / (1 - p))
+    pred, tgt = _apply_logit(np.random.rand(5,2).astype(np.float32)), np.random.randint(0, 2, (5, 2)).astype(np.float32)
+    for reduction in ["mean", "sum", "none"]:
+      for alpha, gamma in zip([-1, 0.58], [0, 2]):
+        self._assert_loss(pred, tgt, sigmoid_focal_loss, ref_sigmoid_focal_loss, rtol=1e-4, alpha=alpha, gamma=gamma, reduction=reduction)
+
+  def test_l1_loss(self):
+    N, C, H, W = 3, 4, 5, 6
+    shapes = ((N, C), (N, C, H), (N, C, H, W))
+    for reduction in ["mean", "sum", "none"]:
+      for shape in shapes:
+        pred, tgt = np.random.rand(*shape).astype(np.float32), np.random.rand(*shape).astype(np.float32)
+        self._assert_loss(pred, tgt, l1_loss, torch.nn.functional.l1_loss, reduction=reduction)
 
 if __name__ == '__main__':
   unittest.main()
\ No newline at end of file
diff --git a/test/external/mlperf_retinanet/focal_loss.py b/test/external/mlperf_retinanet/focal_loss.py
new file mode 100644
index 0000000000..7167b58a1d
--- /dev/null
+++ b/test/external/mlperf_retinanet/focal_loss.py
@@ -0,0 +1,51 @@
+# Copied from https://github.com/mlcommons/training/blob/cdd928d4596c142c15a7d86b2eeadbac718c8da2/single_stage_detector/ssd/model/focal_loss.py
+
+import torch
+import torch.nn.functional as F
+
+
+def sigmoid_focal_loss(
+    inputs: torch.Tensor,
+    targets: torch.Tensor,
+    alpha: float = 0.25,
+    gamma: float = 2,
+    reduction: str = "none",
+):
+    """
+    Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py .
+    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                 (0 for the negative class and 1 for the positive class).
+        alpha: (optional) Weighting factor in range (0,1) to balance
+               positive vs negative examples or -1 for ignore. Default = 0.25
+        gamma: Exponent of the modulating factor (1 - p_t) to
+               balance easy vs hard examples.
+        reduction: 'none' | 'mean' | 'sum'
+                   'none': No reduction will be applied to the output.
+                   'mean': The output will be averaged.
+                   'sum': The output will be summed.
+    Returns:
+        Loss tensor with the reduction option applied.
+    """
+    p = torch.sigmoid(inputs)
+    ce_loss = F.binary_cross_entropy_with_logits(
+        inputs, targets, reduction="none"
+    )
+    p_t = p * targets + (1 - p) * (1 - targets)
+    loss = ce_loss * ((1 - p_t) ** gamma)
+
+    if alpha >= 0:
+        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+        loss = alpha_t * loss
+
+    if reduction == "mean":
+        loss = loss.mean()
+    elif reduction == "sum":
+        loss = loss.sum()
+
+    return loss
\ No newline at end of file
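
Usage sketch (illustrative, not part of the patch itself): one way the two new
losses might be combined for a RetinaNet-style head. The shapes and the
normalization by foreground count are assumptions taken from the RetinaNet
paper, not something this PR specifies.

  from tinygrad import Tensor
  from examples.mlperf.losses import sigmoid_focal_loss, l1_loss

  # hypothetical head outputs: (batch, anchors, classes) logits with sparse binary targets
  cls_logits = Tensor.randn(2, 100, 91)
  cls_tgts = (Tensor.rand(2, 100, 91) > 0.99).float()

  # the paper sums the focal loss and divides by the count of positive targets
  # (for one-hot labels this is the number of foreground anchors)
  num_fg = cls_tgts.sum().maximum(1)
  cls_loss = sigmoid_focal_loss(cls_logits, cls_tgts, reduction="sum") / num_fg

  # (batch, anchors, 4) box deltas; a real model would only compare anchors
  # matched to ground truth, here all anchors are used for brevity
  box_preds, box_tgts = Tensor.randn(2, 100, 4), Tensor.randn(2, 100, 4)
  box_loss = l1_loss(box_preds, box_tgts, reduction="mean")

  print((cls_loss + box_loss).item())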