more cleanups on training script

This commit is contained in:
Francis Lata
2024-05-16 19:49:55 +00:00
parent 3b78aa7acc
commit 4d79bf0f34

View File

@@ -350,7 +350,7 @@ def train_unet3d():
from examples.mlperf.metrics import dice_score
from extra.models.unet3d import UNet3D
from extra.datasets.kits19 import iterate, get_train_files, get_val_files, sliding_window_inference, preprocess_dataset, BASEDIR
from tinygrad import Device, Tensor, GlobalCounters
from tinygrad import Device, Tensor, GlobalCounters, Context
from tinygrad.nn.optim import SGD
from math import ceil
@@ -360,10 +360,9 @@ def train_unet3d():
GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
for x in GPUS: Device[x]
print(f"Training on {GPUS}")
TARGET_METRIC = 0.908
NUM_EPOCHS = getenv("NUM_EPOCHS", 4000)
NUM_EPOCHS = getenv("NUM_EPOCHS", 2500)
BS = getenv("BS", 1 * len(GPUS))
LR = getenv("LR", 2.0 * (BS / 28))
LR_WARMUP_EPOCHS = getenv("LR_WARMUP_EPOCHS", 1000)
@@ -371,14 +370,12 @@ def train_unet3d():
WANDB = getenv("WANDB")
PROJ_NAME = getenv("PROJ_NAME", "tinygrad_unet3d_mlperf")
SEED = getenv("SEED")
TRAIN_DATASET_SIZE = len(get_train_files())
VAL_DATASET_SIZE = len(get_val_files())
TRAIN_DATASET_SIZE, VAL_DATASET_SIZE = len(get_train_files()), len(get_val_files())
SAMPLES_PER_EPOCH = TRAIN_DATASET_SIZE // BS
START_EVAL_AT = getenv("START_EVAL_AT", ceil(1000 * TRAIN_DATASET_SIZE / (SAMPLES_PER_EPOCH * BS)))
EVALUATE_EVERY = getenv("EVALUATE_EVERY", ceil(20 * TRAIN_DATASET_SIZE / (SAMPLES_PER_EPOCH * BS)))
PREPROCESSED_DIR = BASEDIR / ".." / "preprocessed"
TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)
TRAIN_BEAM, EVAL_BEAM = getenv("TRAIN_BEAM", BEAM.value), getenv("EVAL_BEAM", BEAM.value)
BENCHMARK = getenv("BENCHMARK")
CKPT = getenv("CKPT")
@@ -453,93 +450,92 @@ def train_unet3d():
is_successful, diverged = False, False
start_eval_at, evaluate_every = 1 if BENCHMARK else START_EVAL_AT, 1 if BENCHMARK else EVALUATE_EVERY
next_eval_at = start_eval_at
print(f"Eval starts at epoch {start_eval_at} and every {evaluate_every} epochs afterwards")
print(f"Training on {GPUS}")
print(f"Start evaluation at epoch {start_eval_at} and every {evaluate_every} epoch(s) afterwards")
if not PREPROCESSED_DIR.exists():
preprocess_dataset(get_train_files(), PREPROCESSED_DIR, False)
preprocess_dataset(get_val_files(), PREPROCESSED_DIR, True)
for epoch in range(1, NUM_EPOCHS + 1):
BEAM.value = TRAIN_BEAM
with Context(BEAM=TRAIN_BEAM):
if epoch <= LR_WARMUP_EPOCHS and LR_WARMUP_EPOCHS > 0:
lr_warm_up(optim, LR_WARMUP_INIT_LR, LR, epoch, LR_WARMUP_EPOCHS)
if epoch <= LR_WARMUP_EPOCHS and LR_WARMUP_EPOCHS > 0:
lr_warm_up(optim, LR_WARMUP_INIT_LR, LR, epoch, LR_WARMUP_EPOCHS)
st = time.perf_counter()
st = time.perf_counter()
for i, (x, y) in enumerate(tqdm(iterate(get_train_files(), preprocessed_dir=PREPROCESSED_DIR, val=False, shuffle=True, bs=BS), total=SAMPLES_PER_EPOCH, desc=f"epoch {epoch}", disable=BENCHMARK), start=1):
dt = time.perf_counter()
for i, (x, y) in enumerate(tqdm(iterate(get_train_files(), preprocessed_dir=PREPROCESSED_DIR, val=False, shuffle=True, bs=BS), total=SAMPLES_PER_EPOCH, desc=f"epoch {epoch}", disable=BENCHMARK), start=1):
dt = time.perf_counter()
GlobalCounters.reset()
GlobalCounters.reset()
x, y = Tensor(x).realize().shard(GPUS, axis=0), Tensor(y, requires_grad=False).shard(GPUS, axis=0)
x, y = Tensor(x).realize().shard(GPUS, axis=0), Tensor(y, requires_grad=False).shard(GPUS, axis=0)
loss = train_step(model, x, y)
pt = time.perf_counter()
loss = train_step(model, x, y)
pt = time.perf_counter()
loss = loss.numpy().item()
cl = time.perf_counter()
loss = loss.numpy().item()
cl = time.perf_counter()
if BENCHMARK: step_times.append(cl - st)
if BENCHMARK:
step_times.append(cl - st)
tqdm.write(
f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - st) * 1000.0:6.2f} ms fetch data, "
f"{loss:5.3f} loss, {optim.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS"
)
tqdm.write(
f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - st) * 1000.0:6.2f} ms fetch data, "
f"{loss:5.3f} loss, {optim.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS"
)
if WANDB:
wandb.log({"lr": optim.lr.numpy(), "train/loss": loss, "train/step_time": cl - st, "train/python_time": pt - st, "train/data_time": dt - st,
"train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": epoch + (i + 1) / SAMPLES_PER_EPOCH})
if WANDB:
wandb.log({"lr": optim.lr.numpy(), "train/loss": loss, "train/step_time": cl - st, "train/python_time": pt - st, "train/data_time": dt - st,
"train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": epoch + (i + 1) / SAMPLES_PER_EPOCH})
st = cl
st = cl
if i == BENCHMARK:
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
estimated_total_minutes = int(median_step_time * SAMPLES_PER_EPOCH * NUM_EPOCHS / 60)
print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
if (TRAIN_BEAM or EVAL_BEAM) and epoch == start_epoch: break
return
if i == BENCHMARK:
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
estimated_total_minutes = int(median_step_time * SAMPLES_PER_EPOCH * NUM_EPOCHS / 60)
print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
if (TRAIN_BEAM or EVAL_BEAM) and epoch == start_epoch: break
return
with Context(BEAM=EVAL_BEAM):
if epoch == next_eval_at:
train_step.reset()
if epoch == next_eval_at:
train_step.reset()
next_eval_at += evaluate_every
eval_loss = []
scores = []
BEAM.value = EVAL_BEAM
for x, y in tqdm(iterate(get_val_files(), preprocessed_dir=PREPROCESSED_DIR), total=VAL_DATASET_SIZE):
eval_loss_value, score = eval_step(model, x, y)
eval_loss.append(eval_loss_value)
scores.append(score)
next_eval_at += evaluate_every
eval_loss = []
scores = []
scores = Tensor.mean(Tensor.stack(scores, dim=0), axis=0).numpy()
eval_loss = Tensor.mean(Tensor.stack(eval_loss, dim=0), axis=0).numpy()
for x, y in tqdm(iterate(get_val_files(), preprocessed_dir=PREPROCESSED_DIR), total=VAL_DATASET_SIZE):
eval_loss_value, score = eval_step(model, x, y)
eval_loss.append(eval_loss_value)
scores.append(score)
l1_dice, l2_dice = scores[0][-2], scores[0][-1]
mean_dice = (l2_dice + l1_dice) / 2
scores = Tensor.mean(Tensor.stack(scores, dim=0), axis=0).numpy()
eval_loss = Tensor.mean(Tensor.stack(eval_loss, dim=0), axis=0).numpy()
tqdm.write(f"{l1_dice} L1 dice, {l2_dice} L2 dice, {mean_dice:.3f} mean_dice, {eval_loss:5.2f} eval_loss")
l1_dice, l2_dice = scores[0][-2], scores[0][-1]
mean_dice = (l2_dice + l1_dice) / 2
if WANDB:
wandb.log({"eval/loss": eval_loss, "eval/mean_dice": mean_dice, "epoch": epoch})
tqdm.write(f"{l1_dice} L1 dice, {l2_dice} L2 dice, {mean_dice:.3f} mean_dice, {eval_loss:5.2f} eval_loss")
if mean_dice >= TARGET_METRIC:
is_successful = True
save_checkpoint(get_state_dict(model), f"./ckpts/unet3d.safe")
elif mean_dice < 1e-6:
print("Model diverging. Aborting.")
diverged = True
if WANDB:
wandb.log({"eval/loss": eval_loss, "eval/mean_dice": mean_dice, "epoch": epoch})
if CKPT:
if WANDB and wandb.run is not None:
fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}_e{epoch}.safe"
else:
fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_e{epoch}.safe"
if mean_dice >= TARGET_METRIC:
is_successful = True
save_checkpoint(get_state_dict(model), f"./ckpts/unet3d.safe")
elif mean_dice < 1e-6:
print("Model diverging. Aborting.")
diverged = True
if CKPT:
if WANDB and wandb.run is not None:
fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}_e{epoch}.safe"
else:
fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_e{epoch}.safe"
save_checkpoint(get_state_dict(model), fn)
save_checkpoint(get_state_dict(model), fn)
if is_successful or diverged:
break