mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
update dataloader script example (#5818)
This commit is contained in:
@@ -357,13 +357,13 @@ if __name__ == "__main__":
|
||||
def load_unet3d(val):
|
||||
assert not val, "validation set is not supported due to different sizes on inputs"
|
||||
|
||||
from extra.datasets.kits19 import get_train_files, get_val_files, preprocess_dataset, BASEDIR
|
||||
preprocessed_dataset_dir = (BASEDIR / ".." / "preprocessed" / ("val" if val else "train"))
|
||||
from extra.datasets.kits19 import get_train_files, get_val_files, preprocess_dataset, TRAIN_PREPROCESSED_DIR, VAL_PREPROCESSED_DIR
|
||||
preprocessed_dir = VAL_PREPROCESSED_DIR if val else TRAIN_PREPROCESSED_DIR
|
||||
files = get_val_files() if val else get_train_files()
|
||||
|
||||
if not preprocessed_dataset_dir.exists(): preprocess_dataset(files, preprocessed_dataset_dir, val)
|
||||
if not preprocessed_dir.exists(): preprocess_dataset(files, preprocessed_dir, val)
|
||||
with tqdm(total=len(files)) as pbar:
|
||||
for x, _, _ in batch_load_unet3d(preprocessed_dataset_dir, val=val):
|
||||
for x, _, _ in batch_load_unet3d(preprocessed_dir, val=val):
|
||||
pbar.update(x.shape[0])
|
||||
|
||||
def load_resnet(val):
|
||||
|
||||
Reference in New Issue
Block a user