clean up how preprocessed folder is defined (#5813)

Author: Francis Lata
Date: 2024-07-30 12:35:26 -04:00
Committed by: GitHub
Commit: ce61be16f1
Parent: ca674c31f9

2 changed files with 9 additions and 10 deletions


@@ -12,7 +12,8 @@ from tinygrad.tensor import Tensor
 from tinygrad.helpers import fetch

 BASEDIR = Path(__file__).parent / "kits19" / "data"
-PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed"
+TRAIN_PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed" / "train"
+VAL_PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed" / "val"

 """
 To download the dataset:
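The two new constants resolve to the same locations the old code built at call time by tacking the split name onto PREPROCESSED_DIR. A small standalone sketch of that equivalence (the "base" name below is illustrative only, not part of the module):

from pathlib import Path

# old layout: a single parent folder, with "train"/"val" appended later by the helpers
base = Path(__file__).parent / "kits19" / "preprocessed"   # what PREPROCESSED_DIR used to point at

# new layout: the split is baked into the constant itself
TRAIN_PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed" / "train"
VAL_PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed" / "val"

assert TRAIN_PREPROCESSED_DIR == base / "train" and VAL_PREPROCESSED_DIR == base / "val"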
@@ -73,24 +74,22 @@ def preprocess(file_path):
   return image, label

 def preprocess_dataset(filenames, preprocessed_dir, val):
-  preprocessed_dataset_dir = (preprocessed_dir / ("val" if val else "train")) if preprocessed_dir is not None else None
-  if not preprocessed_dataset_dir.is_dir(): os.makedirs(preprocessed_dataset_dir)
+  if not preprocessed_dir.is_dir(): os.makedirs(preprocessed_dir)
   for fn in tqdm(filenames, desc=f"preprocessing {'validation' if val else 'training'}"):
     case = os.path.basename(fn)
     image, label = preprocess(fn)
     image, label = image.astype(np.float32), label.astype(np.uint8)
-    np.save(preprocessed_dataset_dir / f"{case}_x.npy", image, allow_pickle=False)
-    np.save(preprocessed_dataset_dir / f"{case}_y.npy", label, allow_pickle=False)
+    np.save(preprocessed_dir / f"{case}_x.npy", image, allow_pickle=False)
+    np.save(preprocessed_dir / f"{case}_y.npy", label, allow_pickle=False)

 def iterate(files, preprocessed_dir=None, val=True, shuffle=False, bs=1):
   order = list(range(0, len(files)))
-  preprocessed_dataset_dir = (preprocessed_dir / ("val" if val else "train")) if preprocessed_dir is not None else None
   if shuffle: random.shuffle(order)
   for i in range(0, len(files), bs):
     samples = []
     for i in order[i:i+bs]:
-      if preprocessed_dataset_dir is not None:
-        x_cached_path, y_cached_path = preprocessed_dataset_dir / f"{os.path.basename(files[i])}_x.npy", preprocessed_dataset_dir / f"{os.path.basename(files[i])}_y.npy"
+      if preprocessed_dir is not None:
+        x_cached_path, y_cached_path = preprocessed_dir / f"{os.path.basename(files[i])}_x.npy", preprocessed_dir / f"{os.path.basename(files[i])}_y.npy"
         if x_cached_path.exists() and y_cached_path.exists():
           samples += [(np.load(x_cached_path), np.load(y_cached_path))]
       else: samples += [preprocess(files[i])]
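With the split subfolder no longer derived inside the helpers, preprocess_dataset and iterate expect to be handed the final directory. A minimal sketch of the resulting call pattern, assuming the module lives at extra/datasets/kits19.py and using a stand-in case listing (the real call sites are not part of this diff):

# hedged usage sketch: the import path and the glob-based case listing are assumptions;
# only the constants and function signatures come from the diff above
from extra.datasets.kits19 import BASEDIR, VAL_PREPROCESSED_DIR, preprocess_dataset, iterate

val_files = sorted(str(p) for p in BASEDIR.glob("case_*"))[:2]  # stand-in for the real validation split

# the caller now picks the split folder; preprocess_dataset no longer appends "val"/"train"
preprocess_dataset(val_files, VAL_PREPROCESSED_DIR, val=True)

# iterate() looks for "<case>_x.npy"/"<case>_y.npy" directly in the directory it is given
for batch in iterate(val_files, preprocessed_dir=VAL_PREPROCESSED_DIR, val=True, bs=1):
  break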


@@ -66,9 +66,9 @@ class ExternalTestDatasets(unittest.TestCase):
       np.testing.assert_equal(tinygrad_sample[1][:, 0].numpy(), ref_sample[1])

   def test_kits19_validation_set(self):
-    _, preproc_img_pths, preproc_lbl_pths = self._create_samples(True)
+    preproc_pth, preproc_img_pths, preproc_lbl_pths = self._create_samples(True)
     ref_dataset = self._create_kits19_ref_dataloader(preproc_img_pths, preproc_lbl_pths, True)
-    tinygrad_dataset = self._create_kits19_tinygrad_dataloader(Path(tempfile.gettempdir()), True, use_old_dataloader=True)
+    tinygrad_dataset = self._create_kits19_tinygrad_dataloader(preproc_pth, True, use_old_dataloader=True)

     for ref_sample, tinygrad_sample in zip(ref_dataset, tinygrad_dataset):
       np.testing.assert_equal(tinygrad_sample[0][:, 0], ref_sample[0])
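The old test handed the tinygrad dataloader the bare system temp directory and discarded the preprocessed path that _create_samples returns; it now passes that path through, so the dataloader reads from the same folder the samples were created in. A standalone illustration of the cache lookup this relies on (the directory and case names here are hypothetical, not taken from _create_samples):

import tempfile
import numpy as np
from pathlib import Path

preproc_pth = Path(tempfile.mkdtemp())  # stand-in for the path _create_samples returns
case = "case_00000"
np.save(preproc_pth / f"{case}_x.npy", np.zeros((1, 8, 8, 8), dtype=np.float32), allow_pickle=False)
np.save(preproc_pth / f"{case}_y.npy", np.zeros((1, 8, 8, 8), dtype=np.uint8), allow_pickle=False)

# mirrors the check in iterate(): a hit when handed the folder the samples were written to,
# a miss when handed the bare temp dir the old test used
for d in (preproc_pth, Path(tempfile.gettempdir())):
  x, y = d / f"{case}_x.npy", d / f"{case}_y.npy"
  print(d, x.exists() and y.exists())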