Mirror of https://github.com/tinygrad/tinygrad.git
RetinaNet losses (#9536)
* add sigmoid_focal_loss and l1_loss
* update ref implementation comment
examples/mlperf/losses.py

@@ -1,6 +1,29 @@
 from examples.mlperf.metrics import dice_score
+from tinygrad import Tensor
 
 def dice_ce_loss(pred, tgt):
   ce = pred.permute(0, 2, 3, 4, 1).sparse_categorical_crossentropy(tgt.squeeze(1))
   dice = (1.0 - dice_score(pred, tgt, argmax=False, to_one_hot_x=False)).mean()
   return (dice + ce) / 2
+
+def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float=0.25, gamma:float=2.0, reduction:str="none") -> Tensor:
+  assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
+  p, ce_loss = pred.sigmoid(), pred.binary_crossentropy_logits(tgt, reduction="none")
+  p_t = p * tgt + (1 - p) * (1 - tgt)
+  loss = ce_loss * ((1 - p_t) ** gamma)
+
+  if alpha >= 0:
+    alpha_t = alpha * tgt + (1 - alpha) * (1 - tgt)
+    loss = loss * alpha_t
+
+  if reduction == "mean": loss = loss.mean()
+  elif reduction == "sum": loss = loss.sum()
+  return loss
+
+def l1_loss(pred:Tensor, tgt:Tensor, reduction:str="none") -> Tensor:
+  assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
+  loss = (pred - tgt).abs()
+
+  if reduction == "mean": loss = loss.mean()
+  elif reduction == "sum": loss = loss.sum()
+  return loss
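
For context, a minimal usage sketch of the two new losses follows; the shapes and values are invented for illustration and are not part of the commit.

# usage sketch (hypothetical shapes/values, not part of the commit)
from tinygrad import Tensor
from examples.mlperf.losses import sigmoid_focal_loss, l1_loss

logits = Tensor.randn(5, 2)   # raw classification scores (pre-sigmoid)
targets = Tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [1.0, 0.0]])  # binary labels

cls_loss = sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0, reduction="mean")
box_loss = l1_loss(Tensor.randn(5, 4), Tensor.randn(5, 4), reduction="sum")
print(cls_loss.item(), box_loss.item())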
test/external/external_test_losses.py (vendored, 34 lines changed)
@@ -1,20 +1,40 @@
 from tinygrad import Tensor
+from test.external.mlperf_retinanet.focal_loss import sigmoid_focal_loss as ref_sigmoid_focal_loss
 from test.external.mlperf_unet3d.dice import DiceCELoss
-from examples.mlperf.losses import dice_ce_loss
+from examples.mlperf.losses import dice_ce_loss, sigmoid_focal_loss, l1_loss
 
 import numpy as np
 import torch
 import unittest
 
 class ExternalTestLosses(unittest.TestCase):
-  def _test_losses(self, tinygrad_metrics, orig_metrics, pred, label):
-    tinygrad_metrics_res = tinygrad_metrics(Tensor(pred), Tensor(label)).numpy()
-    orig_metrics_res = orig_metrics(torch.from_numpy(pred), torch.from_numpy(label)).numpy()
-    np.testing.assert_allclose(tinygrad_metrics_res, orig_metrics_res, atol=1e-4)
+  def setUp(self):
+    np.random.seed(42)
+
+  def _assert_loss(self, pred, tgt, tinygrad_metrics, ref_metrics, rtol=1e-07, atol=0, **kwargs):
+    tinygrad_metrics_res = tinygrad_metrics(Tensor(pred), Tensor(tgt), **kwargs)
+    ref_metrics_res = ref_metrics(torch.from_numpy(pred), torch.from_numpy(tgt), **kwargs)
+    np.testing.assert_allclose(tinygrad_metrics_res.numpy(), ref_metrics_res.numpy(), rtol=rtol, atol=atol)
 
-  def test_dice_ce(self):
+  def test_dice_ce_loss(self):
     pred, label = np.random.rand(1, 3, 128, 128, 128).astype(np.float32), np.ones((1, 1, 128, 128, 128)).astype(np.uint8)
-    self._test_losses(dice_ce_loss, DiceCELoss(True, True, "NCDHW", False), pred, label)
+    tinygrad_metrics_res, ref_metrics_res = dice_ce_loss, DiceCELoss(True, True, "NCDHW", False)
+    self._assert_loss(pred, label, tinygrad_metrics_res, ref_metrics_res, atol=1e-4)
+
+  def test_sigmoid_focal_loss(self):
+    def _apply_logit(p): return np.log(p / (1 - p))
+    pred, tgt = _apply_logit(np.random.rand(5,2).astype(np.float32)), np.random.randint(0, 2, (5, 2)).astype(np.float32)
+    for reduction in ["mean", "sum", "none"]:
+      for alpha, gamma in zip([-1, 0.58], [0, 2]):
+        self._assert_loss(pred, tgt, sigmoid_focal_loss, ref_sigmoid_focal_loss, rtol=1e-4, alpha=alpha, gamma=gamma, reduction=reduction)
+
+  def test_l1_loss(self):
+    N, C, H, W = 3, 4, 5, 6
+    shapes = ((N, C), (N, C, H), (N, C, H, W))
+    for reduction in ["mean", "sum", "none"]:
+      for shape in shapes:
+        pred, tgt = np.random.randn(*shape).astype(np.float32), np.random.randn(*shape).astype(np.float32)
+        self._assert_loss(pred, tgt, l1_loss, torch.nn.functional.l1_loss, reduction=reduction)
 
 if __name__ == '__main__':
   unittest.main()
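
As a quick standalone sanity check (illustrative only, not part of the commit; run from the tinygrad repo root), the tinygrad-vs-torch parity comparison the test performs can be reproduced directly:

# standalone parity check for l1_loss against torch.nn.functional.l1_loss
import numpy as np, torch
from tinygrad import Tensor
from examples.mlperf.losses import l1_loss

np.random.seed(42)
pred, tgt = np.random.randn(3, 4).astype(np.float32), np.random.randn(3, 4).astype(np.float32)
ours = l1_loss(Tensor(pred), Tensor(tgt), reduction="mean").numpy()
ref = torch.nn.functional.l1_loss(torch.from_numpy(pred), torch.from_numpy(tgt), reduction="mean").numpy()
np.testing.assert_allclose(ours, ref, rtol=1e-07)  # same default tolerance as the test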
test/external/mlperf_retinanet/focal_loss.py (vendored, new file, 51 lines)
@@ -0,0 +1,51 @@
|
|||||||
|
# Copied from https://github.com/mlcommons/training/blob/cdd928d4596c142c15a7d86b2eeadbac718c8da2/single_stage_detector/ssd/model/focal_loss.py
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def sigmoid_focal_loss(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
targets: torch.Tensor,
|
||||||
|
alpha: float = 0.25,
|
||||||
|
gamma: float = 2,
|
||||||
|
reduction: str = "none",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py .
|
||||||
|
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: A float tensor of arbitrary shape.
|
||||||
|
The predictions for each example.
|
||||||
|
targets: A float tensor with the same shape as inputs. Stores the binary
|
||||||
|
classification label for each element in inputs
|
||||||
|
(0 for the negative class and 1 for the positive class).
|
||||||
|
alpha: (optional) Weighting factor in range (0,1) to balance
|
||||||
|
positive vs negative examples or -1 for ignore. Default = 0.25
|
||||||
|
gamma: Exponent of the modulating factor (1 - p_t) to
|
||||||
|
balance easy vs hard examples.
|
||||||
|
reduction: 'none' | 'mean' | 'sum'
|
||||||
|
'none': No reduction will be applied to the output.
|
||||||
|
'mean': The output will be averaged.
|
||||||
|
'sum': The output will be summed.
|
||||||
|
Returns:
|
||||||
|
Loss tensor with the reduction option applied.
|
||||||
|
"""
|
||||||
|
p = torch.sigmoid(inputs)
|
||||||
|
ce_loss = F.binary_cross_entropy_with_logits(
|
||||||
|
inputs, targets, reduction="none"
|
||||||
|
)
|
||||||
|
p_t = p * targets + (1 - p) * (1 - targets)
|
||||||
|
loss = ce_loss * ((1 - p_t) ** gamma)
|
||||||
|
|
||||||
|
if alpha >= 0:
|
||||||
|
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
|
||||||
|
loss = alpha_t * loss
|
||||||
|
|
||||||
|
if reduction == "mean":
|
||||||
|
loss = loss.mean()
|
||||||
|
elif reduction == "sum":
|
||||||
|
loss = loss.sum()
|
||||||
|
|
||||||
|
return loss
|
||||||
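
For reference, both the tinygrad version and this vendored torch reference compute the focal loss from the RetinaNet paper (https://arxiv.org/abs/1708.02002). In the notation of the code above, with label y and predicted probability p = sigmoid(inputs):

    p_t = p y + (1 - p)(1 - y), \quad \alpha_t = \alpha y + (1 - \alpha)(1 - y)
    FL = -\alpha_t (1 - p_t)^\gamma \log(p_t)

Since binary cross-entropy with logits is -[y \log p + (1 - y) \log(1 - p)] = -\log(p_t), the code's ce_loss * ((1 - p_t) ** gamma) followed by the alpha_t scaling is exactly this quantity: gamma down-weights easy examples (p_t near 1), and alpha_t balances positive against negative labels.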