Files
tinygrad/test/external/external_test_datasets.py
Francis Lata bb849a57d1 [MLPerf] UNet3D dataloader (#4343)
* add support for train/val datasets for kits19

* split dataset into train and val sets

* add tests for kits19 dataloader

* add MLPerf dataset tests to CI

* update unet3d model_eval script

* fix linting

* add nibabel

* fix how mock dataset gets created

* update ref implementation with permalink and no edits

* clean up test and update rand_flip implementation

* cleanups
2024-04-28 22:34:18 -04:00

72 lines
2.4 KiB
Python

from extra.datasets.kits19 import iterate, preprocess
from test.external.mlperf_unet3d.kits19 import PytTrain, PytVal
from tinygrad.helpers import temp
from pathlib import Path
import nibabel as nib
import numpy as np
import os
import random
import tempfile
import unittest
class ExternalTestDatasets(unittest.TestCase):
  """Check that the kits19 `iterate` pipeline reproduces the reference PytTrain/PytVal samples."""

  def _set_seed(self):
    """Reset Python's and NumPy's global RNGs to the fixed seed 42."""
    random.seed(42)
    np.random.seed(42)

  def _create_sample(self):
    """Write a synthetic imaging/segmentation NIfTI pair and return its case directory.

    Both volumes have shape (190, 392, 392): the image holds random float32
    values, the label holds random integers in [0, 100) cast to uint8. Each is
    wrapped in a `nib.Nifti1Image` with an identity affine and saved to the
    `temp("case_0000/...")` paths.

    Returns:
      `Path` to the "case_0000" directory inside `tempfile.gettempdir()`.
    """
    self._set_seed()
    shape = (190, 392, 392)
    # The image is drawn before the label so the seeded RNG stream is consumed in a fixed order.
    img_data = np.random.rand(*shape).astype(np.float32)
    lbl_data = np.random.randint(0, 100, size=shape).astype(np.uint8)
    os.makedirs(tempfile.gettempdir() + "/case_0000", exist_ok=True)
    nib.save(nib.Nifti1Image(img_data, np.eye(4)), temp("case_0000/imaging.nii.gz"))
    nib.save(nib.Nifti1Image(lbl_data, np.eye(4)), temp("case_0000/segmentation.nii.gz"))
    return Path(tempfile.gettempdir()) / "case_0000"

  def _create_kits19_ref_sample(self, sample_pth, val):
    """Return item 0 of the reference dataset built from the preprocessed sample.

    `sample_pth` is passed through `preprocess`; the two resulting arrays are
    stored as .npy files in a "val" or "train" folder (selected by `val`) and
    then handed to `PytVal` or `PytTrain` respectively.

    Args:
      sample_pth: case directory as returned by `_create_sample`.
      val: truthy to build the validation dataset, falsy for the training one.
    """
    self._set_seed()
    image, label = preprocess(sample_pth)
    dataset = "val" if val else "train"
    os.makedirs(tempfile.gettempdir() + f"/{dataset}", exist_ok=True)
    x_path = temp(f"{dataset}/case_0000_x.npy")
    y_path = temp(f"{dataset}/case_0000_y.npy")
    np.save(x_path, image, allow_pickle=False)
    np.save(y_path, label, allow_pickle=False)
    if val:
      return PytVal([x_path], [y_path])[0]
    return PytTrain([x_path], [y_path], patch_size=(128, 128, 128), oversampling=0.4)[0]

  def _create_kits19_tinygrad_sample(self, sample_pth, val):
    """Reseed the RNGs and return the first item yielded by `iterate` for `sample_pth`."""
    self._set_seed()
    batches = iterate([sample_pth], val=val)
    return next(batches)

  def test_kits19_training_set(self):
    """The training sample from `iterate` must equal the `PytTrain` reference exactly."""
    sample_pth = self._create_sample()
    tiny_x, tiny_y = self._create_kits19_tinygrad_sample(sample_pth, False)
    ref_x, ref_y = self._create_kits19_ref_sample(sample_pth, False)
    # Both tinygrad arrays are sliced with [:, 0] before being compared to the reference.
    np.testing.assert_equal(tiny_x[:, 0], ref_x)
    np.testing.assert_equal(tiny_y[:, 0], ref_y)

  def test_kits19_validation_set(self):
    """The validation sample from `iterate` must equal the `PytVal` reference exactly."""
    sample_pth = self._create_sample()
    tiny_x, tiny_y = self._create_kits19_tinygrad_sample(sample_pth, True)
    ref_x, ref_y = self._create_kits19_ref_sample(sample_pth, True)
    # Only the image is sliced with [:, 0] here; the label is compared as returned.
    np.testing.assert_equal(tiny_x[:, 0], ref_x)
    np.testing.assert_equal(tiny_y, ref_y)
# Run the test cases through unittest's runner when this file is executed as a script.
if __name__ == '__main__':
  unittest.main()