mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-10 14:45:35 -05:00
- Add DICE loss and metrics.
- Update dice to include a link to the reference implementation.
- Remove unused imports.
- Remove the unnecessary test file and update pred and label for the metrics and losses test.
- Add the tests to CI and exclude mlperf_unet3d.

Co-authored-by: chenyu <chenyu@fastmail.com>
20 lines
837 B
Python
20 lines
837 B
Python
from tinygrad import Tensor
|
|
from test.external.mlperf_unet3d.dice import DiceCELoss
|
|
from examples.mlperf.losses import dice_ce_loss
|
|
|
|
import numpy as np
|
|
import torch
|
|
import unittest
|
|
|
|
class ExternalTestLosses(unittest.TestCase):
|
|
def _test_losses(self, tinygrad_metrics, orig_metrics, pred, label):
|
|
tinygrad_metrics_res = tinygrad_metrics(Tensor(pred), Tensor(label)).numpy()
|
|
orig_metrics_res = orig_metrics(torch.from_numpy(pred), torch.from_numpy(label)).numpy()
|
|
np.testing.assert_allclose(tinygrad_metrics_res, orig_metrics_res, atol=1e-4)
|
|
|
|
def test_dice_ce(self):
|
|
pred, label = np.random.rand(1, 3, 128, 128, 128).astype(np.float32), np.ones((1, 1, 128, 128, 128)).astype(np.uint8)
|
|
self._test_losses(dice_ce_loss, DiceCELoss(True, True, "NCDHW", False), pred, label)
|
|
|
|
if __name__ == '__main__':
  # Run this module's test cases when it is executed directly as a script.
  unittest.main()