mirror of https://github.com/tinygrad/tinygrad.git
clean up how preprocessed folder is defined (#5813)
@@ -12,7 +12,8 @@ from tinygrad.tensor import Tensor
 from tinygrad.helpers import fetch

 BASEDIR = Path(__file__).parent / "kits19" / "data"
-PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed"
+TRAIN_PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed" / "train"
+VAL_PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed" / "val"

 """
 To download the dataset:
@@ -73,24 +74,22 @@ def preprocess(file_path):
   return image, label

 def preprocess_dataset(filenames, preprocessed_dir, val):
-  preprocessed_dataset_dir = (preprocessed_dir / ("val" if val else "train")) if preprocessed_dir is not None else None
-  if not preprocessed_dataset_dir.is_dir(): os.makedirs(preprocessed_dataset_dir)
+  if not preprocessed_dir.is_dir(): os.makedirs(preprocessed_dir)
   for fn in tqdm(filenames, desc=f"preprocessing {'validation' if val else 'training'}"):
     case = os.path.basename(fn)
     image, label = preprocess(fn)
     image, label = image.astype(np.float32), label.astype(np.uint8)
-    np.save(preprocessed_dataset_dir / f"{case}_x.npy", image, allow_pickle=False)
-    np.save(preprocessed_dataset_dir / f"{case}_y.npy", label, allow_pickle=False)
+    np.save(preprocessed_dir / f"{case}_x.npy", image, allow_pickle=False)
+    np.save(preprocessed_dir / f"{case}_y.npy", label, allow_pickle=False)

 def iterate(files, preprocessed_dir=None, val=True, shuffle=False, bs=1):
   order = list(range(0, len(files)))
-  preprocessed_dataset_dir = (preprocessed_dir / ("val" if val else "train")) if preprocessed_dir is not None else None
   if shuffle: random.shuffle(order)
   for i in range(0, len(files), bs):
     samples = []
     for i in order[i:i+bs]:
-      if preprocessed_dataset_dir is not None:
-        x_cached_path, y_cached_path = preprocessed_dataset_dir / f"{os.path.basename(files[i])}_x.npy", preprocessed_dataset_dir / f"{os.path.basename(files[i])}_y.npy"
+      if preprocessed_dir is not None:
+        x_cached_path, y_cached_path = preprocessed_dir / f"{os.path.basename(files[i])}_x.npy", preprocessed_dir / f"{os.path.basename(files[i])}_y.npy"
         if x_cached_path.exists() and y_cached_path.exists():
           samples += [(np.load(x_cached_path), np.load(y_cached_path))]
         else: samples += [preprocess(files[i])]
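This commit replaces the single PREPROCESSED_DIR constant with TRAIN_PREPROCESSED_DIR and VAL_PREPROCESSED_DIR, so preprocess_dataset() and iterate() now receive the concrete train or val directory instead of deriving a "train"/"val" subfolder internally. Below is a minimal usage sketch of the updated interface; the import path, the case-folder glob, and how the batches are consumed are assumptions, not something this commit specifies.

# hedged usage sketch -- assumes the module in this diff is importable as extra.datasets.kits19
from extra.datasets.kits19 import BASEDIR, VAL_PREPROCESSED_DIR, preprocess_dataset, iterate

# hypothetical: a few raw KiTS19 case folders that live under BASEDIR
val_cases = sorted(BASEDIR.glob("case_00*"))[:3]

# the caller now picks the concrete output folder; no "val" subfolder is derived inside
preprocess_dataset(val_cases, VAL_PREPROCESSED_DIR, val=True)

# iterate() looks up the cached <case>_x.npy / <case>_y.npy pairs in that same directory
for batch in iterate(val_cases, preprocessed_dir=VAL_PREPROCESSED_DIR, val=True, shuffle=False, bs=1):
  pass  # consume one preprocessed batch per step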
test/external/external_test_datasets.py
@@ -66,9 +66,9 @@ class ExternalTestDatasets(unittest.TestCase):
       np.testing.assert_equal(tinygrad_sample[1][:, 0].numpy(), ref_sample[1])

   def test_kits19_validation_set(self):
-    _, preproc_img_pths, preproc_lbl_pths = self._create_samples(True)
+    preproc_pth, preproc_img_pths, preproc_lbl_pths = self._create_samples(True)
     ref_dataset = self._create_kits19_ref_dataloader(preproc_img_pths, preproc_lbl_pths, True)
-    tinygrad_dataset = self._create_kits19_tinygrad_dataloader(Path(tempfile.gettempdir()), True, use_old_dataloader=True)
+    tinygrad_dataset = self._create_kits19_tinygrad_dataloader(preproc_pth, True, use_old_dataloader=True)

     for ref_sample, tinygrad_sample in zip(ref_dataset, tinygrad_dataset):
       np.testing.assert_equal(tinygrad_sample[0][:, 0], ref_sample[0])
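The test fix follows the same contract: the first value _create_samples(True) returns is evidently the directory its preprocessed samples were written to, and that exact path must be handed to the tinygrad dataloader, since iterate() resolves <case>_x.npy / <case>_y.npy directly inside the directory it is given rather than inside a guessed tempfile.gettempdir(). A small sketch of that cache layout, with made-up paths and shapes:

# hedged sketch of the cache layout the dataloader expects (paths and shapes are made up)
import tempfile
from pathlib import Path
import numpy as np

preproc_pth = Path(tempfile.mkdtemp())  # stands in for the path _create_samples returns
case = "case_00000"                     # hypothetical case name
np.save(preproc_pth / f"{case}_x.npy", np.zeros((1, 64, 64, 64), dtype=np.float32), allow_pickle=False)
np.save(preproc_pth / f"{case}_y.npy", np.zeros((1, 64, 64, 64), dtype=np.uint8), allow_pickle=False)
# passing preproc_pth (not tempfile.gettempdir()) to the dataloader hits this cache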