UNet3D MLPerf (#3470)

* add training set transforms

* add DICE cross entropy loss

* convert pred and label to Tensor when calculating DICE score

* cleanups and allow train dataset batching

* fix DICE CE loss calculation

* jitted training step

* clean up DICE CE loss calculation

* initial support for sharding

* Revert "initial support for sharding"

This reverts commit e3670813b8.

* minor updates

* cleanup imports

* add support for sharding

* apply temp patch to try to avoid OOM

* revert cstyle changes

* add gradient acc

* hotfix

* add FP16 support

* add ability to train on smaller image sizes

* add support for saving and loading checkpoints + cleanup some various modes

* fix issue with using smaller patch size + update W&B logging

* disable LR_WARMUP_EPOCHS

* updates

* minor cleanups

* cleanup

* update order of transformations

* more cleanups

* realize loss

* cleanup

* more cleanup

* some cleanups

* add RAM usage

* minor cleanups

* add support for gradient accumulation

* cleanup imports

* minor updates to not use GA_STEPS

* remove FP16 option since it's available now globally

* update multi-GPU setup

* add timing logs for training loop

* go back to using existing dataloader and add ability to preprocess data to save time

* clean up optimization and re-enable JIT and multi-GPU support for training and evaluation

* free train and eval steps memory

* cleanups and scale batch size based on the number of GPUs

* fix GlobalCounters import

* fix seed

* fix W&B setup

* update batch size default size

* add back metric divergence check

* put back JIT on UNet3d eval

* move dataset preprocessing inside training code

* add test for dice_loss

* add config logging support to W&B and other cleanups

* change how default float is getting retrieved

* remove TinyJit import duplicate

* update config logging to W&B and remove JIT on eval_step

* no need for caching preprocessed data anymore

* fix how evaluation is run and how often

* add support for LR scaling

* fix issue with gaussian being moved to scipy.signal.windows

* remove DICE loss unit test

* fix issue where loss isn't compatible with multiGPU

* add individual BEAM control for train and eval steps

* fix ndimage scipy import

* add BENCHMARK

* cleanups on BENCHMARK + fix on rand_flip augmentation during training

* cleanup train and eval BEAM envs

* add checkpointing support after every eval

* cleanup model_eval

* disable grad during eval

* use new preprocessing dataset mechanism

* remove unused import

* use training and inference_mode contexts

* start eval after benchmarking

* add data fetching time

* cleanup decorators

* more cleanups on training script

* add message during benchmarking mode

* realize when reassigning LR on scheduler and update default number of epochs

* add JIT on eval step

* remove JIT on eval_step

* add train dataloader for unet3d

* move checkpointing to be done after every epoch

* revert removal of JIT on unet3d inference

* save checkpoint if metric is not successful

* Revert "add train dataloader for unet3d"

This reverts commit c166d129df.

* Revert "Revert "add train dataloader for unet3d""

This reverts commit 36366c65d2.

* hotfix: seed was defaulting to a value of 0

* fix SEED value

* remove the usage of context managers for setting BEAM and going from training to inference

* support new stack API for calculating eval loss and metric

* Revert "remove the usage of context managers for setting BEAM and going from training to inference"

This reverts commit 2c0ba8d322.

* check training and test preprocessed folders separately

* clean up imports and log FUSE_CONV_BW

* use train and val preprocessing constants

* add kits19 dataset setup script

* update to use the new test decorator for disabling grad

* update kits19 dataset setup script

* add docs on how to train the model

* set default value for BASEDIR

* add detailed instructions about BASEDIR usage

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
Francis Lata
2024-09-10 04:37:28 -04:00
committed by GitHub
parent f4f705a07c
commit b7ce9a1530
3 changed files with 241 additions and 17 deletions

View File

@@ -4,7 +4,7 @@ from tqdm import tqdm
import multiprocessing
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, FUSE_CONV_BW
from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup
@@ -346,8 +346,225 @@ def train_retinanet():
pass
def train_unet3d():
  """
  Trains the UNet3D model.

  Instructions:
  1) Run the following script from the root folder of `tinygrad`:
  ```./examples/mlperf/scripts/setup_kits19_dataset.sh```

  Optionally, `BASEDIR` can be set to download and process the dataset at a specific location:
  ```BASEDIR=<folder_path> ./examples/mlperf/scripts/setup_kits19_dataset.sh```

  2) To start training the model, run the following:
  ```time PYTHONPATH=. WANDB=1 TRAIN_BEAM=3 FUSE_CONV_BW=1 GPUS=6 BS=6 MODEL=unet3d python3 examples/mlperf/model_train.py```

  Configuration is read from environment variables through `getenv`: GPUS, NUM_EPOCHS, BS, LR, LR_WARMUP_EPOCHS,
  LR_WARMUP_INIT_LR, WANDB, PROJ_NAME, SEED, START_EVAL_AT, EVALUATE_EVERY, TRAIN_BEAM, EVAL_BEAM, BENCHMARK, CKPT
  and STORE_COOKIES.

  The epoch loop ends when the mean dice score reaches TARGET_METRIC, when it falls below 1e-6 (treated as
  divergence), or after NUM_EPOCHS epochs. With BENCHMARK set, the function returns after printing an estimated
  total training time.

  Raises:
    ImportError: if WANDB is set but the `wandb` package is not installed.
  """
  from examples.mlperf.losses import dice_ce_loss
  from examples.mlperf.metrics import dice_score
  from examples.mlperf.dataloader import batch_load_unet3d
  from extra.models.unet3d import UNet3D
  from extra.datasets.kits19 import iterate, get_train_files, get_val_files, sliding_window_inference, preprocess_dataset, TRAIN_PREPROCESSED_DIR, VAL_PREPROCESSED_DIR
  from tinygrad import Context
  from tinygrad.nn.optim import SGD
  from math import ceil

  GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
  for x in GPUS: Device[x]  # NOTE(review): presumably opens every device up front -- confirm

  TARGET_METRIC = 0.908
  NUM_EPOCHS = getenv("NUM_EPOCHS", 4000)
  BS = getenv("BS", 1 * len(GPUS))
  LR = getenv("LR", 2.0 * (BS / 28))
  LR_WARMUP_EPOCHS = getenv("LR_WARMUP_EPOCHS", 1000)
  LR_WARMUP_INIT_LR = getenv("LR_WARMUP_INIT_LR", 0.0001)
  WANDB = getenv("WANDB")
  PROJ_NAME = getenv("PROJ_NAME", "tinygrad_unet3d_mlperf")
  # a negative (or unset) SEED means "do not seed"
  SEED = getenv("SEED", -1) if getenv("SEED", -1) >= 0 else None
  TRAIN_DATASET_SIZE, VAL_DATASET_SIZE = len(get_train_files()), len(get_val_files())
  SAMPLES_PER_EPOCH = TRAIN_DATASET_SIZE // BS
  START_EVAL_AT = getenv("START_EVAL_AT", ceil(1000 * TRAIN_DATASET_SIZE / (SAMPLES_PER_EPOCH * BS)))
  EVALUATE_EVERY = getenv("EVALUATE_EVERY", ceil(20 * TRAIN_DATASET_SIZE / (SAMPLES_PER_EPOCH * BS)))
  TRAIN_BEAM, EVAL_BEAM = getenv("TRAIN_BEAM", BEAM.value), getenv("EVAL_BEAM", BEAM.value)
  BENCHMARK = getenv("BENCHMARK")
  CKPT = getenv("CKPT")

  # logged to W&B as the run configuration when WANDB is enabled
  config = {
    "num_epochs": NUM_EPOCHS,
    "batch_size": BS,
    "learning_rate": LR,
    "learning_rate_warmup_epochs": LR_WARMUP_EPOCHS,
    "learning_rate_warmup_init": LR_WARMUP_INIT_LR,
    "start_eval_at": START_EVAL_AT,
    "evaluate_every": EVALUATE_EVERY,
    "train_beam": TRAIN_BEAM,
    "eval_beam": EVAL_BEAM,
    "wino": WINO.value,
    "fuse_conv_bw": FUSE_CONV_BW.value,
    "gpus": GPUS,
    "default_float": dtypes.default_float.name
  }

  if WANDB:
    try:
      import wandb
    except ImportError as e:
      # raising a plain string is a TypeError in Python 3, so raise a real exception with the same message
      raise ImportError("Need to install wandb to use it") from e

  if SEED is not None:
    config["seed"] = SEED
    Tensor.manual_seed(SEED)

  model = UNet3D()
  params = get_parameters(model)

  for p in params: p.realize().to_(GPUS)

  optim = SGD(params, lr=LR, momentum=0.9, nesterov=True)

  def lr_warm_up(optim, init_lr, lr, current_epoch, warmup_epochs):
    # linear ramp from init_lr towards lr; reaches lr when current_epoch == warmup_epochs
    scale = current_epoch / warmup_epochs
    optim.lr.assign(Tensor([init_lr + (lr - init_lr) * scale], device=GPUS)).realize()

  def save_checkpoint(state_dict, fn):
    # every caller writes below ./ckpts; exist_ok avoids the check-then-create race
    os.makedirs("./ckpts", exist_ok=True)
    print(f"saving checkpoint to {fn}")
    safe_save(state_dict, fn)

  def data_get(it):
    # pulls the next batch and shards it over GPUS along axis 0; raises StopIteration when `it` is exhausted
    x, y, cookie = next(it)
    return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie

  @TinyJit
  @Tensor.train()
  def train_step(model, x, y):
    optim.zero_grad()

    y_hat = model(x)
    loss = dice_ce_loss(y_hat, y)

    loss.backward()
    optim.step()
    return loss.realize()

  @Tensor.train(mode=False)
  @Tensor.test()
  def eval_step(model, x, y):
    y_hat, y = sliding_window_inference(model, x, y, gpus=GPUS)
    y_hat, y = Tensor(y_hat), Tensor(y, requires_grad=False)
    loss = dice_ce_loss(y_hat, y)
    score = dice_score(y_hat, y)
    return loss.realize(), score.realize()

  if WANDB: wandb.init(config=config, project=PROJ_NAME)

  step_times, start_epoch = [], 1
  is_successful, diverged = False, False
  # in BENCHMARK mode evaluation runs from the first epoch and after every epoch
  start_eval_at, evaluate_every = 1 if BENCHMARK else START_EVAL_AT, 1 if BENCHMARK else EVALUATE_EVERY
  next_eval_at = start_eval_at

  print(f"Training on {GPUS}")

  if BENCHMARK: print("Benchmarking UNet3D")
  else: print(f"Start evaluation at epoch {start_eval_at} and every {evaluate_every} epoch(s) afterwards")

  # train and validation preprocessed folders are checked (and built) independently
  if not TRAIN_PREPROCESSED_DIR.exists(): preprocess_dataset(get_train_files(), TRAIN_PREPROCESSED_DIR, False)
  if not VAL_PREPROCESSED_DIR.exists(): preprocess_dataset(get_val_files(), VAL_PREPROCESSED_DIR, True)

  for epoch in range(1, NUM_EPOCHS + 1):
    with Context(BEAM=TRAIN_BEAM):
      if epoch <= LR_WARMUP_EPOCHS and LR_WARMUP_EPOCHS > 0:
        lr_warm_up(optim, LR_WARMUP_INIT_LR, LR, epoch, LR_WARMUP_EPOCHS)

      train_dataloader = batch_load_unet3d(TRAIN_PREPROCESSED_DIR, batch_size=BS, val=False, shuffle=True, seed=SEED)
      it = iter(tqdm(train_dataloader, total=SAMPLES_PER_EPOCH, desc=f"epoch {epoch}", disable=BENCHMARK))
      i, proc = 0, data_get(it)

      prev_cookies = []
      st = time.perf_counter()

      while proc is not None:
        GlobalCounters.reset()

        # after this line `proc` only holds the batch cookie
        loss, proc = train_step(model, proc[0], proc[1]), proc[2]

        pt = time.perf_counter()

        if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = []  # free previous cookies after gpu work has been enqueued
        try:
          next_proc = data_get(it)
        except StopIteration:
          next_proc = None

        dt = time.perf_counter()

        device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
        loss = loss.numpy().item()

        cl = time.perf_counter()

        if BENCHMARK: step_times.append(cl - st)

        tqdm.write(
          f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
          f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optim.lr.numpy()[0]:.6f} LR, "
          f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS"
        )

        if WANDB:
          wandb.log({"lr": optim.lr.numpy(), "train/loss": loss, "train/step_time": cl - st, "train/python_time": pt - st, "train/data_time": dt - pt,
                     "train/cl_time": cl - dt, "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": epoch + (i + 1) / SAMPLES_PER_EPOCH})

        st = cl
        prev_cookies.append(proc)
        proc, next_proc = next_proc, None  # return old cookie
        i += 1

        if i == BENCHMARK:
          # clamp the index: with BENCHMARK=1 and a single recorded step, (BENCHMARK + 1) // 2 would be out of range
          median_idx = min((BENCHMARK + 1) // 2, len(step_times) - 1)
          median_step_time = sorted(step_times)[median_idx]  # in seconds
          estimated_total_minutes = int(median_step_time * SAMPLES_PER_EPOCH * NUM_EPOCHS / 60)
          print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
          # with BEAM enabled, the first epoch only leaves the step loop so that evaluation and a second
          # epoch still run; step_times keeps accumulating and the function returns on that second pass
          if (TRAIN_BEAM or EVAL_BEAM) and epoch == start_epoch: break
          return

    with Context(BEAM=EVAL_BEAM):
      if epoch == next_eval_at:
        next_eval_at += evaluate_every
        eval_loss = []
        scores = []

        for x, y in tqdm(iterate(get_val_files(), preprocessed_dir=VAL_PREPROCESSED_DIR), total=VAL_DATASET_SIZE):
          eval_loss_value, score = eval_step(model, x, y)
          eval_loss.append(eval_loss_value)
          scores.append(score)

        scores = Tensor.mean(Tensor.stack(*scores, dim=0), axis=0).numpy()
        eval_loss = Tensor.mean(Tensor.stack(*eval_loss, dim=0), axis=0).numpy()

        # the last two entries of the averaged score are reported as the L1 and L2 dice
        l1_dice, l2_dice = scores[0][-2], scores[0][-1]
        mean_dice = (l2_dice + l1_dice) / 2

        tqdm.write(f"{l1_dice} L1 dice, {l2_dice} L2 dice, {mean_dice:.3f} mean_dice, {eval_loss:5.2f} eval_loss")

        if WANDB:
          wandb.log({"eval/loss": eval_loss, "eval/mean_dice": mean_dice, "epoch": epoch})

        if mean_dice >= TARGET_METRIC:
          is_successful = True
          save_checkpoint(get_state_dict(model), "./ckpts/unet3d.safe")
        elif mean_dice < 1e-6:
          print("Model diverging. Aborting.")
          diverged = True

    # per-epoch checkpoint while the target has not been reached (only when CKPT is set)
    if not is_successful and CKPT:
      if WANDB and wandb.run is not None:
        fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}_e{epoch}.safe"
      else:
        fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_e{epoch}.safe"

      save_checkpoint(get_state_dict(model), fn)

    if is_successful or diverged:
      break
def train_rnnt():
# TODO: RNN-T

View File

@@ -0,0 +1,19 @@
#!/bin/bash
# Downloads the kits19 dataset (repository clone + imaging data) into $BASEDIR.
# BASEDIR defaults to ./extra/datasets/ when it is unset or empty.

if [ -z "$BASEDIR" ]; then
  export BASEDIR="./extra/datasets/"
fi

# stop here if the folder is not reachable; otherwise the clone below would land in the current directory
cd "$BASEDIR" || exit 1

if [ -d "kits19" ]; then
  echo "kits19 dataset is already available"
else
  echo "Downloading and preparing kits19 dataset at $BASEDIR"

  # without a successful clone the following steps would run against the wrong directory
  git clone https://github.com/neheller/kits19 || exit 1
  cd kits19 || exit 1

  pip3 install -r requirements.txt
  python3 -m starter_code.get_imaging

  echo "Done"
fi

View File

@@ -15,18 +15,6 @@ BASEDIR = Path(__file__).parent / "kits19" / "data"
TRAIN_PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed" / "train"
VAL_PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed" / "val"
"""
To download the dataset:
```sh
git clone https://github.com/neheller/kits19
cd kits19
pip3 install -r requirements.txt
python3 -m starter_code.get_imaging
cd ..
mv kits19 extra/datasets
```
"""
@functools.lru_cache(None)
def get_train_files():
  """Return the sorted BASEDIR entries that are training cases.

  An entry qualifies when its stem starts with "case", the number after the last
  underscore is below 210, and it is not one of the entries returned by get_val_files().
  The result is memoized by lru_cache.
  """
  train_cases = []
  for entry in BASEDIR.iterdir():
    # checks run in the same order as the original condition, so the later ones are skipped on an early miss
    if not entry.stem.startswith("case"): continue
    if not int(entry.stem.split("_")[-1]) < 210: continue
    if entry in get_val_files(): continue
    train_cases.append(entry)
  return sorted(train_cases)
@@ -123,7 +111,7 @@ def pad_input(volume, roi_shape, strides, padding_mode="constant", padding_val=-
paddings = [bounds[2]//2, bounds[2]-bounds[2]//2, bounds[1]//2, bounds[1]-bounds[1]//2, bounds[0]//2, bounds[0]-bounds[0]//2, 0, 0, 0, 0]
return F.pad(torch.from_numpy(volume), paddings, mode=padding_mode, value=padding_val).numpy(), paddings
def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), overlap=0.5):
def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), overlap=0.5, gpus=None):
from tinygrad.engine.jit import TinyJit
mdl_run = TinyJit(lambda x: model(x).realize())
image_shape, dim = list(inputs.shape[2:]), len(inputs.shape[2:])
@@ -152,7 +140,7 @@ def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), o
for i in range(0, strides[0] * size[0], strides[0]):
for j in range(0, strides[1] * size[1], strides[1]):
for k in range(0, strides[2] * size[2], strides[2]):
out = mdl_run(Tensor(inputs[..., i:roi_shape[0]+i,j:roi_shape[1]+j, k:roi_shape[2]+k])).numpy()
out = mdl_run(Tensor(inputs[..., i:roi_shape[0]+i,j:roi_shape[1]+j, k:roi_shape[2]+k], device=gpus)).numpy()
result[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += out * norm_patch
norm_map[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += norm_patch
result /= norm_map