Files
tinygrad/test/external/external_test_metrics.py
Francis Lata 3644077a42 [MLPerf][UNet3D] Add DICE loss + metrics (#4204)
* add DICE loss and metrics

* update dice to include reference implementation's link

* remove unused imports

* remove unnecessary test file and update pred + label for metrics and losses test

* add tests to CI + add exclusion of mlperf_unet3d

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
2024-04-17 20:09:33 -04:00

20 lines
802 B
Python

from tinygrad import Tensor
from test.external.mlperf_unet3d.dice import DiceScore
from examples.mlperf.metrics import dice_score
import numpy as np
import torch
import unittest
class ExternalTestMetrics(unittest.TestCase):
  """Cross-check tinygrad metric implementations against their PyTorch counterparts."""

  def _test_metrics(self, tinygrad_metrics, orig_metrics, pred, label, *, rtol=1e-5, atol=1e-6):
    """Run both metric implementations on the same numpy inputs and assert they agree.

    Args:
      tinygrad_metrics: callable taking two tinygrad ``Tensor``s ``(pred, label)``.
      orig_metrics: callable taking two torch tensors ``(pred, label)``; its result
        must support ``.numpy()``.
      pred: numpy array passed unchanged to both implementations.
      label: numpy array passed unchanged to both implementations.
      rtol: relative tolerance for ``np.testing.assert_allclose``.
      atol: absolute tolerance for ``np.testing.assert_allclose``.

    The two frameworks compute in floating point with different kernels and
    reduction orders, so results are compared within a tolerance rather than
    bit-for-bit.
    """
    # NOTE(review): squeeze() presumably drops singleton dims so the tinygrad
    # output matches the reference output's shape -- confirm against dice_score.
    tinygrad_metrics_res = tinygrad_metrics(Tensor(pred), Tensor(label)).squeeze().numpy()
    orig_metrics_res = orig_metrics(torch.from_numpy(pred), torch.from_numpy(label)).numpy()
    np.testing.assert_allclose(tinygrad_metrics_res, orig_metrics_res, rtol=rtol, atol=atol)

  def test_dice(self):
    """tinygrad's dice_score must match DiceScore on a fixed random volume."""
    # Seeded local generator: a failure reproduces on rerun, and the global
    # numpy RNG state is left untouched for other tests.
    rng = np.random.default_rng(0)
    # presumably (N, C, D, H, W): 3-channel predictions vs. a 1-channel label
    # map where every voxel is 1 -- TODO confirm layout against DiceScore.
    pred = rng.random((1, 3, 128, 128, 128), dtype=np.float32)
    label = np.ones((1, 1, 128, 128, 128), dtype=np.uint8)
    self._test_metrics(dice_score, DiceScore(), pred, label)
if __name__ == "__main__":
  # Executed directly (not imported by a test collector): hand off to unittest's CLI runner.
  unittest.main()