mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
* add support for train/val datasets for kits19 * split dataset into train and val sets * add tests for kits19 dataloader * add MLPerf dataset tests to CI * update unet3d model_eval script * fix linting * add nibabel * fix how mock dataset gets created * update ref implementation with permalink and no edits * clean up test and update rand_flip implementation * cleanups
72 lines
2.4 KiB
Python
72 lines
2.4 KiB
Python
from extra.datasets.kits19 import iterate, preprocess
|
|
from test.external.mlperf_unet3d.kits19 import PytTrain, PytVal
|
|
from tinygrad.helpers import temp
|
|
from pathlib import Path
|
|
|
|
import nibabel as nib
|
|
import numpy as np
|
|
import os
|
|
import random
|
|
import tempfile
|
|
import unittest
|
|
|
|
class ExternalTestDatasets(unittest.TestCase):
|
|
def _set_seed(self):
|
|
np.random.seed(42)
|
|
random.seed(42)
|
|
|
|
def _create_sample(self):
|
|
self._set_seed()
|
|
|
|
img, lbl = np.random.rand(190, 392, 392).astype(np.float32), np.random.randint(0, 100, size=(190, 392, 392)).astype(np.uint8)
|
|
img, lbl = nib.Nifti1Image(img, np.eye(4)), nib.Nifti1Image(lbl, np.eye(4))
|
|
|
|
os.makedirs(tempfile.gettempdir() + "/case_0000", exist_ok=True)
|
|
nib.save(img, temp("case_0000/imaging.nii.gz"))
|
|
nib.save(lbl, temp("case_0000/segmentation.nii.gz"))
|
|
|
|
return Path(tempfile.gettempdir()) / "case_0000"
|
|
|
|
def _create_kits19_ref_sample(self, sample_pth, val):
|
|
self._set_seed()
|
|
|
|
img, lbl = preprocess(sample_pth)
|
|
dataset = "val" if val else "train"
|
|
preproc_img_pth, preproc_lbl_pth = temp(f"{dataset}/case_0000_x.npy"), temp(f"{dataset}/case_0000_y.npy")
|
|
|
|
os.makedirs(tempfile.gettempdir() + f"/{dataset}", exist_ok=True)
|
|
np.save(preproc_img_pth, img, allow_pickle=False)
|
|
np.save(preproc_lbl_pth, lbl, allow_pickle=False)
|
|
|
|
if val:
|
|
dataset = PytVal([preproc_img_pth], [preproc_lbl_pth])
|
|
else:
|
|
dataset = PytTrain([preproc_img_pth], [preproc_lbl_pth], patch_size=(128, 128, 128), oversampling=0.4)
|
|
|
|
return dataset[0]
|
|
|
|
def _create_kits19_tinygrad_sample(self, sample_pth, val):
|
|
self._set_seed()
|
|
return next(iterate([sample_pth], val=val))
|
|
|
|
def test_kits19_training_set(self):
|
|
sample_pth = self._create_sample()
|
|
|
|
tinygrad_img, tinygrad_lbl = self._create_kits19_tinygrad_sample(sample_pth, False)
|
|
ref_img, ref_lbl = self._create_kits19_ref_sample(sample_pth, False)
|
|
|
|
np.testing.assert_equal(tinygrad_img[:, 0], ref_img)
|
|
np.testing.assert_equal(tinygrad_lbl[:, 0], ref_lbl)
|
|
|
|
def test_kits19_validation_set(self):
|
|
sample_pth = self._create_sample()
|
|
|
|
tinygrad_img, tinygrad_lbl = self._create_kits19_tinygrad_sample(sample_pth, True)
|
|
ref_img, ref_lbl = self._create_kits19_ref_sample(sample_pth, True)
|
|
|
|
np.testing.assert_equal(tinygrad_img[:, 0], ref_img)
|
|
np.testing.assert_equal(tinygrad_lbl, ref_lbl)
|
|
|
|
if __name__ == '__main__':
|
|
unittest.main()
|