UNet3D MLPerf (#3470)
* add training set transforms
* add DICE cross entropy loss
* convert pred and label to Tensor when calculating DICE score
* cleanups and allow train dataset batching
* fix DICE CE loss calculation
* jitted training step
* clean up DICE CE loss calculation
* initial support for sharding
* Revert "initial support for sharding" (this reverts commit e3670813b8)
* minor updates
* cleanup imports
* add support for sharding
* apply temp patch to try to avoid OOM
* revert cstyle changes
* add gradient acc
* hotfix
* add FP16 support
* add ability to train on smaller image sizes
* add support for saving and loading checkpoints + cleanup some various modes
* fix issue with using smaller patch size + update W&B logging
* disable LR_WARMUP_EPOCHS
* updates
* minor cleanups
* cleanup
* update order of transformations
* more cleanups
* realize loss
* cleanup
* more cleanup
* some cleanups
* add RAM usage
* minor cleanups
* add support for gradient accumulation
* cleanup imports
* minor updates to not use GA_STEPS
* remove FP16 option since it's available now globally
* update multi-GPU setup
* add timing logs for training loop
* go back to using existing dataloader and add ability to preprocess data to save time
* clean up optimization and re-enable JIT and multi-GPU support for training and evaluation
* free train and eval steps memory
* cleanups and scale batch size based on the number of GPUs
* fix GlobalCounters import
* fix seed
* fix W&B setup
* update batch size default size
* add back metric divergence check
* put back JIT on UNet3D eval
* move dataset preprocessing inside training code
* add test for dice_loss
* add config logging support to W&B and other cleanups
* change how default float is retrieved
* remove TinyJit import duplicate
* update config logging to W&B and remove JIT on eval_step
* no need for caching preprocessed data anymore
* fix how evaluation is run and how often
* add support for LR scaling
* fix issue with gaussian being moved to scipy.signal.windows
* remove DICE loss unit test
* fix issue where loss isn't compatible with multi-GPU
* add individual BEAM control for train and eval steps
* fix ndimage scipy import
* add BENCHMARK
* cleanups on BENCHMARK + fix on rand_flip augmentation during training
* cleanup train and eval BEAM envs
* add checkpointing support after every eval
* cleanup model_eval
* disable grad during eval
* use new preprocessing dataset mechanism
* remove unused import
* use training and inference_mode contexts
* start eval after benchmarking
* add data fetching time
* cleanup decorators
* more cleanups on training script
* add message during benchmarking mode
* realize when reassigning LR on scheduler and update default number of epochs
* add JIT on eval step
* remove JIT on eval_step
* add train dataloader for unet3d
* move checkpointing to be done after every epoch
* revert removal of JIT on unet3d inference
* save checkpoint if metric is not successful
* Revert "add train dataloader for unet3d" (this reverts commit c166d129df)
* Revert "Revert "add train dataloader for unet3d"" (this reverts commit 36366c65d2)
* hotfix: seed was defaulting to a value of 0
* fix SEED value
* remove the usage of context managers for setting BEAM and going from training to inference
* support new stack API for calculating eval loss and metric
* Revert "remove the usage of context managers for setting BEAM and going from training to inference" (this reverts commit 2c0ba8d322)
* check training and test preprocessed folders separately
* clean up imports and log FUSE_CONV_BW
* use train and val preprocessing constants
* add kits19 dataset setup script
* update to use the new test decorator for disabling grad
* update kits19 dataset setup script
* add docs on how to train the model
* set default value for BASEDIR
* add detailed instruction about BASEDIR usage

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
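The objective this PR trains with combines a soft Dice term with cross-entropy. As a rough sketch of that idea only (the helper name `dice_ce_loss_sketch`, the one-hot label layout, and the smoothing constant are illustrative assumptions; the actual `dice_ce_loss` in `examples/mlperf/losses.py` may differ in reduction and label handling):

```python
from tinygrad import Tensor

def dice_ce_loss_sketch(pred: Tensor, label: Tensor, smooth: float = 1e-6) -> Tensor:
  # pred: (N, C, D, H, W) raw logits; label: (N, C, D, H, W) one-hot ground truth (an assumption of this sketch)
  probs = pred.softmax(axis=1)
  spatial = (2, 3, 4)                                                       # reduce over the three spatial axes
  intersection = (probs * label).sum(axis=spatial)
  denom = probs.sum(axis=spatial) + label.sum(axis=spatial)
  dice_loss = 1 - ((2 * intersection + smooth) / (denom + smooth)).mean()   # soft Dice, averaged over batch and classes
  ce_loss = -(label * probs.log()).sum(axis=1).mean()                       # cross-entropy against the one-hot target
  return (dice_loss + ce_loss) / 2                                          # equal weighting of the two terms
```

The `dice_score` metric used at evaluation time reports the per-class Dice ratio itself rather than a loss, which is what the eval loop below averages into `mean_dice`.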
examples/mlperf/model_train.py
@@ -4,7 +4,7 @@ from tqdm import tqdm
import multiprocessing

from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
-from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear
+from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, FUSE_CONV_BW
from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup
@@ -346,8 +346,225 @@ def train_retinanet():
  pass

def train_unet3d():
-  # TODO: Unet3d
-  pass
  """
  Trains the UNet3D model.

  Instructions:
  1) Run the following script from the root folder of `tinygrad`:
  ```./examples/mlperf/scripts/setup_kits19_dataset.sh```

  Optionally, `BASEDIR` can be set to download and process the dataset at a specific location:
  ```BASEDIR=<folder_path> ./examples/mlperf/scripts/setup_kits19_dataset.sh```

  2) To start training the model, run the following:
  ```time PYTHONPATH=. WANDB=1 TRAIN_BEAM=3 FUSE_CONV_BW=1 GPUS=6 BS=6 MODEL=unet3d python3 examples/mlperf/model_train.py```
  """
  from examples.mlperf.losses import dice_ce_loss
  from examples.mlperf.metrics import dice_score
  from examples.mlperf.dataloader import batch_load_unet3d
  from extra.models.unet3d import UNet3D
  from extra.datasets.kits19 import iterate, get_train_files, get_val_files, sliding_window_inference, preprocess_dataset, TRAIN_PREPROCESSED_DIR, VAL_PREPROCESSED_DIR
  from tinygrad import Context
  from tinygrad.nn.optim import SGD
  from math import ceil

  GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
  for x in GPUS: Device[x]

  TARGET_METRIC = 0.908
  NUM_EPOCHS = getenv("NUM_EPOCHS", 4000)
  BS = getenv("BS", 1 * len(GPUS))
  LR = getenv("LR", 2.0 * (BS / 28))
  LR_WARMUP_EPOCHS = getenv("LR_WARMUP_EPOCHS", 1000)
  LR_WARMUP_INIT_LR = getenv("LR_WARMUP_INIT_LR", 0.0001)
  WANDB = getenv("WANDB")
  PROJ_NAME = getenv("PROJ_NAME", "tinygrad_unet3d_mlperf")
  SEED = getenv("SEED", -1) if getenv("SEED", -1) >= 0 else None
  TRAIN_DATASET_SIZE, VAL_DATASET_SIZE = len(get_train_files()), len(get_val_files())
  SAMPLES_PER_EPOCH = TRAIN_DATASET_SIZE // BS
  START_EVAL_AT = getenv("START_EVAL_AT", ceil(1000 * TRAIN_DATASET_SIZE / (SAMPLES_PER_EPOCH * BS)))
  EVALUATE_EVERY = getenv("EVALUATE_EVERY", ceil(20 * TRAIN_DATASET_SIZE / (SAMPLES_PER_EPOCH * BS)))
  TRAIN_BEAM, EVAL_BEAM = getenv("TRAIN_BEAM", BEAM.value), getenv("EVAL_BEAM", BEAM.value)
  BENCHMARK = getenv("BENCHMARK")
  CKPT = getenv("CKPT")
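  # LR defaults to linear batch-size scaling of the reference learning rate (reference batch size 28), e.g. BS=6 gives 2.0 * 6 / 28 ≈ 0.43;
  # START_EVAL_AT and EVALUATE_EVERY rescale the reference eval schedule (first eval around epoch 1000, then roughly every 20 epochs) by the samples actually consumed per epoch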

  config = {
    "num_epochs": NUM_EPOCHS,
    "batch_size": BS,
    "learning_rate": LR,
    "learning_rate_warmup_epochs": LR_WARMUP_EPOCHS,
    "learning_rate_warmup_init": LR_WARMUP_INIT_LR,
    "start_eval_at": START_EVAL_AT,
    "evaluate_every": EVALUATE_EVERY,
    "train_beam": TRAIN_BEAM,
    "eval_beam": EVAL_BEAM,
    "wino": WINO.value,
    "fuse_conv_bw": FUSE_CONV_BW.value,
    "gpus": GPUS,
    "default_float": dtypes.default_float.name
  }

  if WANDB:
    try:
      import wandb
    except ImportError:
raise "Need to install wandb to use it"

  if SEED is not None:
    config["seed"] = SEED
    Tensor.manual_seed(SEED)

  model = UNet3D()
  params = get_parameters(model)

  for p in params: p.realize().to_(GPUS)

  optim = SGD(params, lr=LR, momentum=0.9, nesterov=True)

  def lr_warm_up(optim, init_lr, lr, current_epoch, warmup_epochs):
    scale = current_epoch / warmup_epochs
    optim.lr.assign(Tensor([init_lr + (lr - init_lr) * scale], device=GPUS)).realize()
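    # linear warmup: e.g. with LR_WARMUP_EPOCHS=1000, epoch 250 trains at LR_WARMUP_INIT_LR + 0.25 * (LR - LR_WARMUP_INIT_LR)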

  def save_checkpoint(state_dict, fn):
    if not os.path.exists("./ckpts"): os.mkdir("./ckpts")
    print(f"saving checkpoint to {fn}")
    safe_save(state_dict, fn)

  def data_get(it):
    x, y, cookie = next(it)
    return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie

  @TinyJit
  @Tensor.train()
  def train_step(model, x, y):
    optim.zero_grad()

    y_hat = model(x)
    loss = dice_ce_loss(y_hat, y)

    loss.backward()
    optim.step()
    return loss.realize()

  @Tensor.train(mode=False)
  @Tensor.test()
  def eval_step(model, x, y):
    y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS)
    y_hat, y = Tensor(y_hat), Tensor(y, requires_grad=False)
    loss = dice_ce_loss(y_hat, y)
    score = dice_score(y_hat, y)
    return loss.realize(), score.realize()

  if WANDB: wandb.init(config=config, project=PROJ_NAME)

  step_times, start_epoch = [], 1
  is_successful, diverged = False, False
  start_eval_at, evaluate_every = 1 if BENCHMARK else START_EVAL_AT, 1 if BENCHMARK else EVALUATE_EVERY
  next_eval_at = start_eval_at

  print(f"Training on {GPUS}")

  if BENCHMARK: print("Benchmarking UNet3D")
  else: print(f"Start evaluation at epoch {start_eval_at} and every {evaluate_every} epoch(s) afterwards")

  if not TRAIN_PREPROCESSED_DIR.exists(): preprocess_dataset(get_train_files(), TRAIN_PREPROCESSED_DIR, False)
  if not VAL_PREPROCESSED_DIR.exists(): preprocess_dataset(get_val_files(), VAL_PREPROCESSED_DIR, True)

  for epoch in range(1, NUM_EPOCHS + 1):
    with Context(BEAM=TRAIN_BEAM):
      if epoch <= LR_WARMUP_EPOCHS and LR_WARMUP_EPOCHS > 0:
        lr_warm_up(optim, LR_WARMUP_INIT_LR, LR, epoch, LR_WARMUP_EPOCHS)

      train_dataloader = batch_load_unet3d(TRAIN_PREPROCESSED_DIR, batch_size=BS, val=False, shuffle=True, seed=SEED)
      it = iter(tqdm(train_dataloader, total=SAMPLES_PER_EPOCH, desc=f"epoch {epoch}", disable=BENCHMARK))
      i, proc = 0, data_get(it)

      prev_cookies = []
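      # cookies returned by the dataloader keep the previous batch's buffers alive; they are held here until the next
      # step's GPU work has been enqueued, then dropped (see the STORE_COOKIES check below)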
      st = time.perf_counter()

      while proc is not None:
        GlobalCounters.reset()

        loss, proc = train_step(model, proc[0], proc[1]), proc[2]

        pt = time.perf_counter()

        if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = []  # free previous cookies after gpu work has been enqueued
        try:
          next_proc = data_get(it)
        except StopIteration:
          next_proc = None

        dt = time.perf_counter()

        device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
        loss = loss.numpy().item()

        cl = time.perf_counter()

        if BENCHMARK: step_times.append(cl - st)

        tqdm.write(
          f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
          f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optim.lr.numpy()[0]:.6f} LR, "
          f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS"
        )

        if WANDB:
          wandb.log({"lr": optim.lr.numpy(), "train/loss": loss, "train/step_time": cl - st, "train/python_time": pt - st, "train/data_time": dt - pt,
                     "train/cl_time": cl - dt, "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": epoch + (i + 1) / SAMPLES_PER_EPOCH})

        st = cl
        prev_cookies.append(proc)
        proc, next_proc = next_proc, None  # return old cookie
        i += 1

        if i == BENCHMARK:
          median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2]  # in seconds
          estimated_total_minutes = int(median_step_time * SAMPLES_PER_EPOCH * NUM_EPOCHS / 60)
          print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
          if (TRAIN_BEAM or EVAL_BEAM) and epoch == start_epoch: break
          return

    with Context(BEAM=EVAL_BEAM):
      if epoch == next_eval_at:
        next_eval_at += evaluate_every
        eval_loss = []
        scores = []

        for x, y in tqdm(iterate(get_val_files(), preprocessed_dir=VAL_PREPROCESSED_DIR), total=VAL_DATASET_SIZE):
          eval_loss_value, score = eval_step(model, x, y)
          eval_loss.append(eval_loss_value)
          scores.append(score)

        scores = Tensor.mean(Tensor.stack(*scores, dim=0), axis=0).numpy()
        eval_loss = Tensor.mean(Tensor.stack(*eval_loss, dim=0), axis=0).numpy()

        l1_dice, l2_dice = scores[0][-2], scores[0][-1]
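        # the last two entries are the per-class dice scores (kidney and tumour in KiTS19); the quality target is their mean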
        mean_dice = (l2_dice + l1_dice) / 2

        tqdm.write(f"{l1_dice} L1 dice, {l2_dice} L2 dice, {mean_dice:.3f} mean_dice, {eval_loss:5.2f} eval_loss")

        if WANDB:
          wandb.log({"eval/loss": eval_loss, "eval/mean_dice": mean_dice, "epoch": epoch})

        if mean_dice >= TARGET_METRIC:
          is_successful = True
          save_checkpoint(get_state_dict(model), f"./ckpts/unet3d.safe")
        elif mean_dice < 1e-6:
          print("Model diverging. Aborting.")
          diverged = True

        if not is_successful and CKPT:
          if WANDB and wandb.run is not None:
            fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}_e{epoch}.safe"
          else:
            fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_e{epoch}.safe"

          save_checkpoint(get_state_dict(model), fn)

        if is_successful or diverged:
          break

def train_rnnt():
  # TODO: RNN-T
examples/mlperf/scripts/setup_kits19_dataset.sh (new executable file)
@@ -0,0 +1,19 @@
#!/bin/bash

if [ -z "$BASEDIR" ]; then
  export BASEDIR="./extra/datasets/"
fi
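# the default BASEDIR of ./extra/datasets/ lines up with extra/datasets/kits19.py, which expects the imaging data under extra/datasets/kits19/data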

cd "$BASEDIR"
if [ -d "kits19" ]; then
  echo "kits19 dataset is already available"
else
  echo "Downloading and preparing kits19 dataset at $BASEDIR"

  git clone https://github.com/neheller/kits19
  cd kits19
  pip3 install -r requirements.txt
  python3 -m starter_code.get_imaging

  echo "Done"
fi
extra/datasets/kits19.py
@@ -15,18 +15,6 @@ BASEDIR = Path(__file__).parent / "kits19" / "data"
TRAIN_PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed" / "train"
VAL_PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed" / "val"

-"""
-To download the dataset:
-```sh
-git clone https://github.com/neheller/kits19
-cd kits19
-pip3 install -r requirements.txt
-python3 -m starter_code.get_imaging
-cd ..
-mv kits19 extra/datasets
-```
-"""

@functools.lru_cache(None)
def get_train_files():
  return sorted([x for x in BASEDIR.iterdir() if x.stem.startswith("case") and int(x.stem.split("_")[-1]) < 210 and x not in get_val_files()])
@@ -123,7 +111,7 @@ def pad_input(volume, roi_shape, strides, padding_mode="constant", padding_val=-
  paddings = [bounds[2]//2, bounds[2]-bounds[2]//2, bounds[1]//2, bounds[1]-bounds[1]//2, bounds[0]//2, bounds[0]-bounds[0]//2, 0, 0, 0, 0]
  return F.pad(torch.from_numpy(volume), paddings, mode=padding_mode, value=padding_val).numpy(), paddings

-def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), overlap=0.5):
+def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), overlap=0.5, gpus=None):
  from tinygrad.engine.jit import TinyJit
  mdl_run = TinyJit(lambda x: model(x).realize())
  image_shape, dim = list(inputs.shape[2:]), len(inputs.shape[2:])
@@ -152,7 +140,7 @@ def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), o
  for i in range(0, strides[0] * size[0], strides[0]):
    for j in range(0, strides[1] * size[1], strides[1]):
      for k in range(0, strides[2] * size[2], strides[2]):
-        out = mdl_run(Tensor(inputs[..., i:roi_shape[0]+i,j:roi_shape[1]+j, k:roi_shape[2]+k])).numpy()
+        out = mdl_run(Tensor(inputs[..., i:roi_shape[0]+i,j:roi_shape[1]+j, k:roi_shape[2]+k], device=gpus)).numpy()
        result[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += out * norm_patch
        norm_map[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += norm_patch
  result /= norm_map
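The sliding-window evaluation scores each overlapping ROI patch, accumulates the weighted prediction (`out * norm_patch`) into `result`, and divides by the accumulated weights in `norm_map` at the end, so overlapping regions become a weighted average of every patch that covered them. A minimal 1D NumPy analogue of that accumulate-and-normalize pattern (the function name, window size, stride, and triangular weighting below are illustrative assumptions, not the values used by `sliding_window_inference`):

```python
import numpy as np

def sliding_window_1d(f, signal, window=8, stride=4):
  # f scores one window at a time; overlapping scores are blended by a triangular importance weight
  weight = np.bartlett(window) + 1e-3                      # avoid exact zeros at the window edges
  result = np.zeros_like(signal, dtype=np.float64)
  norm = np.zeros_like(signal, dtype=np.float64)
  for start in range(0, len(signal) - window + 1, stride):
    out = f(signal[start:start + window])                  # per-window prediction, shape (window,)
    result[start:start + window] += out * weight
    norm[start:start + window] += weight
  return result / np.maximum(norm, 1e-8)                   # normalize by the total weight seen at each position

# example: smooth a noisy signal by "predicting" each window as its mean
signal = np.sin(np.linspace(0, 6, 64)) + 0.1 * np.random.randn(64)
smoothed = sliding_window_1d(lambda w: np.full_like(w, w.mean()), signal)
```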