mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add missing test and cleanup focal loss
This commit is contained in:
@@ -22,9 +22,9 @@ def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, mask:Optional[Tensor] = None, al
|
||||
|
||||
if reduction == "mean":
|
||||
if mask is not None:
|
||||
loss = (avg_loss := loss.sum(0).mean()) / (Tensor.ones_like(avg_loss) if mask is None else mask.sum())
|
||||
loss = loss.sum(axis=0).mean() / mask.sum()
|
||||
else:
|
||||
loss = loss.mean()
|
||||
elif reduction == "sum": loss = loss.sum(0).sum()
|
||||
elif reduction == "sum": loss = loss.sum(axis=0).sum()
|
||||
|
||||
return loss
|
||||
|
||||
7
test/external/external_test_losses.py
vendored
7
test/external/external_test_losses.py
vendored
@@ -30,7 +30,7 @@ class TestSigmoidFocalLoss(unittest.TestCase):
|
||||
ref_metrics_res = ref_metrics(torch.from_numpy(pred), torch.from_numpy(tgt), **kwargs)
|
||||
|
||||
# NOTE: since boolean indexing is not supported in tinygrad, compare the sum instead.
|
||||
np.testing.assert_allclose(tinygrad_metrics_res.sum().numpy(), ref_metrics_res.sum().numpy(), atol=1e-5)
|
||||
np.testing.assert_allclose(tinygrad_metrics_res.sum().numpy(), ref_metrics_res.sum().numpy(), rtol=1e-6)
|
||||
|
||||
def _generate_samples(self, generate_mask=False):
|
||||
def _apply_logit(p): return np.log(p / (1 - p))
|
||||
@@ -51,5 +51,10 @@ class TestSigmoidFocalLoss(unittest.TestCase):
|
||||
pred, tgt, _ = self._generate_samples()
|
||||
self._test_loss(pred, tgt, sigmoid_focal_loss, ref_sigmoid_focal_loss, alpha=0.58, gamma=2, reduction=reduction)
|
||||
|
||||
def test_loss_correct_ratio_mask(self):
|
||||
for reduction in ["mean", "sum", "none"]:
|
||||
pred, tgt, mask = self._generate_samples(generate_mask=True)
|
||||
self._test_loss(pred, tgt, sigmoid_focal_loss, ref_sigmoid_focal_loss, mask=mask, alpha=0.58, gamma=2, reduction=reduction)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user