mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add tests for l1_loss
This commit is contained in:
14
test/external/external_test_losses.py
vendored
14
test/external/external_test_losses.py
vendored
@@ -1,7 +1,7 @@
|
||||
from tinygrad import Tensor
|
||||
from test.external.mlperf_unet3d.dice import DiceCELoss
|
||||
from test.external.mlperf_retinanet.focal_loss import sigmoid_focal_loss as ref_sigmoid_focal_loss
|
||||
from examples.mlperf.losses import dice_ce_loss, sigmoid_focal_loss
|
||||
from examples.mlperf.losses import dice_ce_loss, sigmoid_focal_loss, l1_loss
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -40,6 +40,18 @@ class TestSigmoidFocalLoss(TestLoss):
|
||||
pred, tgt = self._generate_samples()
|
||||
self.assert_loss(pred, tgt, sigmoid_focal_loss, ref_sigmoid_focal_loss, rtol=1e-4, alpha=0.58, gamma=2, reduction=reduction)
|
||||
|
||||
class TestL1Loss(TestLoss):
|
||||
def _generate_samples(self, shape):
|
||||
return np.random.randint(shape).astype(np.float32), np.random.randint(shape)
|
||||
|
||||
def test_loss(self):
|
||||
N, C, H, W = 3, 4, 5, 6
|
||||
shapes = ((N, C), (N, C, H), (N, C, H, W))
|
||||
|
||||
for reduction in ["mean", "sum", "none"]:
|
||||
for shape in shapes:
|
||||
pred, tgt = self._generate_samples(shape)
|
||||
self.assert_loss(pred, tgt, l1_loss, torch.nn.functional.l1_loss, reduction=reduction)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user