update dataloader script example (#5818)

This commit is contained in:
Francis Lata
2024-07-30 15:18:29 -04:00
committed by GitHub
parent eebb1b9922
commit a0baff7a3d

View File

@@ -357,13 +357,13 @@ if __name__ == "__main__":
def load_unet3d(val):
assert not val, "validation set is not supported due to different sizes on inputs"
from extra.datasets.kits19 import get_train_files, get_val_files, preprocess_dataset, BASEDIR
preprocessed_dataset_dir = (BASEDIR / ".." / "preprocessed" / ("val" if val else "train"))
from extra.datasets.kits19 import get_train_files, get_val_files, preprocess_dataset, TRAIN_PREPROCESSED_DIR, VAL_PREPROCESSED_DIR
preprocessed_dir = VAL_PREPROCESSED_DIR if val else TRAIN_PREPROCESSED_DIR
files = get_val_files() if val else get_train_files()
if not preprocessed_dataset_dir.exists(): preprocess_dataset(files, preprocessed_dataset_dir, val)
if not preprocessed_dir.exists(): preprocess_dataset(files, preprocessed_dir, val)
with tqdm(total=len(files)) as pbar:
for x, _, _ in batch_load_unet3d(preprocessed_dataset_dir, val=val):
for x, _, _ in batch_load_unet3d(preprocessed_dir, val=val):
pbar.update(x.shape[0])
def load_resnet(val):