diff --git a/test/external/external_test_losses.py b/test/external/external_test_losses.py index 3ed8b334a9..cfd2e39efb 100644 --- a/test/external/external_test_losses.py +++ b/test/external/external_test_losses.py @@ -1,7 +1,7 @@ from tinygrad import Tensor from test.external.mlperf_unet3d.dice import DiceCELoss from test.external.mlperf_retinanet.focal_loss import sigmoid_focal_loss as ref_sigmoid_focal_loss -from examples.mlperf.losses import dice_ce_loss, sigmoid_focal_loss +from examples.mlperf.losses import dice_ce_loss, sigmoid_focal_loss, l1_loss import numpy as np import torch @@ -40,6 +40,18 @@ class TestSigmoidFocalLoss(TestLoss): pred, tgt = self._generate_samples() self.assert_loss(pred, tgt, sigmoid_focal_loss, ref_sigmoid_focal_loss, rtol=1e-4, alpha=0.58, gamma=2, reduction=reduction) +class TestL1Loss(TestLoss): + def _generate_samples(self, shape): + return np.random.randint(shape).astype(np.float32), np.random.randint(shape) + + def test_loss(self): + N, C, H, W = 3, 4, 5, 6 + shapes = ((N, C), (N, C, H), (N, C, H, W)) + + for reduction in ["mean", "sum", "none"]: + for shape in shapes: + pred, tgt = self._generate_samples(shape) + self.assert_loss(pred, tgt, l1_loss, torch.nn.functional.l1_loss, reduction=reduction) if __name__ == '__main__': unittest.main() \ No newline at end of file