mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
RetinaNet losses (#9536)
* add sigmoid_focal_loss and l1_loss * update ref implementation comment
This commit is contained in:
@@ -1,6 +1,29 @@
|
||||
from examples.mlperf.metrics import dice_score
|
||||
from tinygrad import Tensor
|
||||
|
||||
def dice_ce_loss(pred, tgt):
|
||||
ce = pred.permute(0, 2, 3, 4, 1).sparse_categorical_crossentropy(tgt.squeeze(1))
|
||||
dice = (1.0 - dice_score(pred, tgt, argmax=False, to_one_hot_x=False)).mean()
|
||||
return (dice + ce) / 2
|
||||
|
||||
def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float=0.25, gamma:float=2.0, reduction:str="none") -> Tensor:
|
||||
assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
|
||||
p, ce_loss = pred.sigmoid(), pred.binary_crossentropy_logits(tgt, reduction="none")
|
||||
p_t = p * tgt + (1 - p) * (1 - tgt)
|
||||
loss = ce_loss * ((1 - p_t) ** gamma)
|
||||
|
||||
if alpha >= 0:
|
||||
alpha_t = alpha * tgt + (1 - alpha) * (1 - tgt)
|
||||
loss = loss * alpha_t
|
||||
|
||||
if reduction == "mean": loss = loss.mean()
|
||||
elif reduction == "sum": loss = loss.sum()
|
||||
return loss
|
||||
|
||||
def l1_loss(pred:Tensor, tgt:Tensor, reduction:str="none") -> Tensor:
|
||||
assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
|
||||
loss = (pred - tgt).abs()
|
||||
|
||||
if reduction == "mean": loss = loss.mean()
|
||||
elif reduction == "sum": loss = loss.sum()
|
||||
return loss
|
||||
34
test/external/external_test_losses.py
vendored
34
test/external/external_test_losses.py
vendored
@@ -1,20 +1,40 @@
|
||||
from tinygrad import Tensor
|
||||
from test.external.mlperf_retinanet.focal_loss import sigmoid_focal_loss as ref_sigmoid_focal_loss
|
||||
from test.external.mlperf_unet3d.dice import DiceCELoss
|
||||
from examples.mlperf.losses import dice_ce_loss
|
||||
from examples.mlperf.losses import dice_ce_loss, sigmoid_focal_loss, l1_loss
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import unittest
|
||||
|
||||
class ExternalTestLosses(unittest.TestCase):
|
||||
def _test_losses(self, tinygrad_metrics, orig_metrics, pred, label):
|
||||
tinygrad_metrics_res = tinygrad_metrics(Tensor(pred), Tensor(label)).numpy()
|
||||
orig_metrics_res = orig_metrics(torch.from_numpy(pred), torch.from_numpy(label)).numpy()
|
||||
np.testing.assert_allclose(tinygrad_metrics_res, orig_metrics_res, atol=1e-4)
|
||||
def setUp(self):
|
||||
np.random.seed(42)
|
||||
|
||||
def test_dice_ce(self):
|
||||
def _assert_loss(self, pred, tgt, tinygrad_metrics, ref_metrics, rtol=1e-07, atol=0, **kwargs):
|
||||
tinygrad_metrics_res = tinygrad_metrics(Tensor(pred), Tensor(tgt), **kwargs)
|
||||
ref_metrics_res = ref_metrics(torch.from_numpy(pred), torch.from_numpy(tgt), **kwargs)
|
||||
np.testing.assert_allclose(tinygrad_metrics_res.numpy(), ref_metrics_res.numpy(), rtol=rtol, atol=atol)
|
||||
|
||||
def test_dice_ce_loss(self):
|
||||
pred, label = np.random.rand(1, 3, 128, 128, 128).astype(np.float32), np.ones((1, 1, 128, 128, 128)).astype(np.uint8)
|
||||
self._test_losses(dice_ce_loss, DiceCELoss(True, True, "NCDHW", False), pred, label)
|
||||
tinygrad_metrics_res, ref_metrics_res = dice_ce_loss, DiceCELoss(True, True, "NCDHW", False)
|
||||
self._assert_loss(pred, label, tinygrad_metrics_res, ref_metrics_res, atol=1e-4)
|
||||
|
||||
def test_sigmoid_focal_loss(self):
|
||||
def _apply_logit(p): return np.log(p / (1 - p))
|
||||
pred, tgt = _apply_logit(np.random.rand(5,2).astype(np.float32)), np.random.randint(0, 2, (5, 2)).astype(np.float32)
|
||||
for reduction in ["mean", "sum", "none"]:
|
||||
for alpha, gamma in zip([-1, 0.58], [0, 2]):
|
||||
self._assert_loss(pred, tgt, sigmoid_focal_loss, ref_sigmoid_focal_loss, rtol=1e-4, alpha=alpha, gamma=gamma, reduction=reduction)
|
||||
|
||||
def test_l1_loss(self):
|
||||
N, C, H, W = 3, 4, 5, 6
|
||||
shapes = ((N, C), (N, C, H), (N, C, H, W))
|
||||
for reduction in ["mean", "sum", "none"]:
|
||||
for shape in shapes:
|
||||
pred, tgt = np.random.randint(shape).astype(np.float32), np.random.randint(shape)
|
||||
self._assert_loss(pred, tgt, l1_loss, torch.nn.functional.l1_loss, reduction=reduction)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
51
test/external/mlperf_retinanet/focal_loss.py
vendored
Normal file
51
test/external/mlperf_retinanet/focal_loss.py
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
# Copied from https://github.com/mlcommons/training/blob/cdd928d4596c142c15a7d86b2eeadbac718c8da2/single_stage_detector/ssd/model/focal_loss.py
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def sigmoid_focal_loss(
|
||||
inputs: torch.Tensor,
|
||||
targets: torch.Tensor,
|
||||
alpha: float = 0.25,
|
||||
gamma: float = 2,
|
||||
reduction: str = "none",
|
||||
):
|
||||
"""
|
||||
Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py .
|
||||
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
|
||||
|
||||
Args:
|
||||
inputs: A float tensor of arbitrary shape.
|
||||
The predictions for each example.
|
||||
targets: A float tensor with the same shape as inputs. Stores the binary
|
||||
classification label for each element in inputs
|
||||
(0 for the negative class and 1 for the positive class).
|
||||
alpha: (optional) Weighting factor in range (0,1) to balance
|
||||
positive vs negative examples or -1 for ignore. Default = 0.25
|
||||
gamma: Exponent of the modulating factor (1 - p_t) to
|
||||
balance easy vs hard examples.
|
||||
reduction: 'none' | 'mean' | 'sum'
|
||||
'none': No reduction will be applied to the output.
|
||||
'mean': The output will be averaged.
|
||||
'sum': The output will be summed.
|
||||
Returns:
|
||||
Loss tensor with the reduction option applied.
|
||||
"""
|
||||
p = torch.sigmoid(inputs)
|
||||
ce_loss = F.binary_cross_entropy_with_logits(
|
||||
inputs, targets, reduction="none"
|
||||
)
|
||||
p_t = p * targets + (1 - p) * (1 - targets)
|
||||
loss = ce_loss * ((1 - p_t) ** gamma)
|
||||
|
||||
if alpha >= 0:
|
||||
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
|
||||
loss = alpha_t * loss
|
||||
|
||||
if reduction == "mean":
|
||||
loss = loss.mean()
|
||||
elif reduction == "sum":
|
||||
loss = loss.sum()
|
||||
|
||||
return loss
|
||||
Reference in New Issue
Block a user