RetinaNet losses (#9536)

* add sigmoid_focal_loss and l1_loss

* update ref implementation comment
This commit is contained in:
Francis Lata
2025-03-21 15:52:54 -04:00
committed by GitHub
parent e6389184c5
commit 8cbe4009fc
3 changed files with 101 additions and 7 deletions

View File

@@ -1,6 +1,29 @@
from examples.mlperf.metrics import dice_score
from tinygrad import Tensor
def dice_ce_loss(pred, tgt):
ce = pred.permute(0, 2, 3, 4, 1).sparse_categorical_crossentropy(tgt.squeeze(1))
dice = (1.0 - dice_score(pred, tgt, argmax=False, to_one_hot_x=False)).mean()
return (dice + ce) / 2
def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float=0.25, gamma:float=2.0, reduction:str="none") -> Tensor:
assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
p, ce_loss = pred.sigmoid(), pred.binary_crossentropy_logits(tgt, reduction="none")
p_t = p * tgt + (1 - p) * (1 - tgt)
loss = ce_loss * ((1 - p_t) ** gamma)
if alpha >= 0:
alpha_t = alpha * tgt + (1 - alpha) * (1 - tgt)
loss = loss * alpha_t
if reduction == "mean": loss = loss.mean()
elif reduction == "sum": loss = loss.sum()
return loss
def l1_loss(pred:Tensor, tgt:Tensor, reduction:str="none") -> Tensor:
assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
loss = (pred - tgt).abs()
if reduction == "mean": loss = loss.mean()
elif reduction == "sum": loss = loss.sum()
return loss

View File

@@ -1,20 +1,40 @@
from tinygrad import Tensor
from test.external.mlperf_retinanet.focal_loss import sigmoid_focal_loss as ref_sigmoid_focal_loss
from test.external.mlperf_unet3d.dice import DiceCELoss
from examples.mlperf.losses import dice_ce_loss
from examples.mlperf.losses import dice_ce_loss, sigmoid_focal_loss, l1_loss
import numpy as np
import torch
import unittest
class ExternalTestLosses(unittest.TestCase):
def _test_losses(self, tinygrad_metrics, orig_metrics, pred, label):
tinygrad_metrics_res = tinygrad_metrics(Tensor(pred), Tensor(label)).numpy()
orig_metrics_res = orig_metrics(torch.from_numpy(pred), torch.from_numpy(label)).numpy()
np.testing.assert_allclose(tinygrad_metrics_res, orig_metrics_res, atol=1e-4)
def setUp(self):
np.random.seed(42)
def test_dice_ce(self):
def _assert_loss(self, pred, tgt, tinygrad_metrics, ref_metrics, rtol=1e-07, atol=0, **kwargs):
tinygrad_metrics_res = tinygrad_metrics(Tensor(pred), Tensor(tgt), **kwargs)
ref_metrics_res = ref_metrics(torch.from_numpy(pred), torch.from_numpy(tgt), **kwargs)
np.testing.assert_allclose(tinygrad_metrics_res.numpy(), ref_metrics_res.numpy(), rtol=rtol, atol=atol)
def test_dice_ce_loss(self):
pred, label = np.random.rand(1, 3, 128, 128, 128).astype(np.float32), np.ones((1, 1, 128, 128, 128)).astype(np.uint8)
self._test_losses(dice_ce_loss, DiceCELoss(True, True, "NCDHW", False), pred, label)
tinygrad_metrics_res, ref_metrics_res = dice_ce_loss, DiceCELoss(True, True, "NCDHW", False)
self._assert_loss(pred, label, tinygrad_metrics_res, ref_metrics_res, atol=1e-4)
def test_sigmoid_focal_loss(self):
def _apply_logit(p): return np.log(p / (1 - p))
pred, tgt = _apply_logit(np.random.rand(5,2).astype(np.float32)), np.random.randint(0, 2, (5, 2)).astype(np.float32)
for reduction in ["mean", "sum", "none"]:
for alpha, gamma in zip([-1, 0.58], [0, 2]):
self._assert_loss(pred, tgt, sigmoid_focal_loss, ref_sigmoid_focal_loss, rtol=1e-4, alpha=alpha, gamma=gamma, reduction=reduction)
def test_l1_loss(self):
N, C, H, W = 3, 4, 5, 6
shapes = ((N, C), (N, C, H), (N, C, H, W))
for reduction in ["mean", "sum", "none"]:
for shape in shapes:
pred, tgt = np.random.randint(shape).astype(np.float32), np.random.randint(shape)
self._assert_loss(pred, tgt, l1_loss, torch.nn.functional.l1_loss, reduction=reduction)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,51 @@
# Copied from https://github.com/mlcommons/training/blob/cdd928d4596c142c15a7d86b2eeadbac718c8da2/single_stage_detector/ssd/model/focal_loss.py
import torch
import torch.nn.functional as F
def sigmoid_focal_loss(
inputs: torch.Tensor,
targets: torch.Tensor,
alpha: float = 0.25,
gamma: float = 2,
reduction: str = "none",
):
"""
Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py .
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
alpha: (optional) Weighting factor in range (0,1) to balance
positive vs negative examples or -1 for ignore. Default = 0.25
gamma: Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples.
reduction: 'none' | 'mean' | 'sum'
'none': No reduction will be applied to the output.
'mean': The output will be averaged.
'sum': The output will be summed.
Returns:
Loss tensor with the reduction option applied.
"""
p = torch.sigmoid(inputs)
ce_loss = F.binary_cross_entropy_with_logits(
inputs, targets, reduction="none"
)
p_t = p * targets + (1 - p) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss
if reduction == "mean":
loss = loss.mean()
elif reduction == "sum":
loss = loss.sum()
return loss