mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
more cleanups on training script
This commit is contained in:
@@ -350,7 +350,7 @@ def train_unet3d():
|
||||
from examples.mlperf.metrics import dice_score
|
||||
from extra.models.unet3d import UNet3D
|
||||
from extra.datasets.kits19 import iterate, get_train_files, get_val_files, sliding_window_inference, preprocess_dataset, BASEDIR
|
||||
from tinygrad import Device, Tensor, GlobalCounters
|
||||
from tinygrad import Device, Tensor, GlobalCounters, Context
|
||||
from tinygrad.nn.optim import SGD
|
||||
from math import ceil
|
||||
|
||||
@@ -360,10 +360,9 @@ def train_unet3d():
|
||||
|
||||
GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
|
||||
for x in GPUS: Device[x]
|
||||
print(f"Training on {GPUS}")
|
||||
|
||||
TARGET_METRIC = 0.908
|
||||
NUM_EPOCHS = getenv("NUM_EPOCHS", 4000)
|
||||
NUM_EPOCHS = getenv("NUM_EPOCHS", 2500)
|
||||
BS = getenv("BS", 1 * len(GPUS))
|
||||
LR = getenv("LR", 2.0 * (BS / 28))
|
||||
LR_WARMUP_EPOCHS = getenv("LR_WARMUP_EPOCHS", 1000)
|
||||
@@ -371,14 +370,12 @@ def train_unet3d():
|
||||
WANDB = getenv("WANDB")
|
||||
PROJ_NAME = getenv("PROJ_NAME", "tinygrad_unet3d_mlperf")
|
||||
SEED = getenv("SEED")
|
||||
TRAIN_DATASET_SIZE = len(get_train_files())
|
||||
VAL_DATASET_SIZE = len(get_val_files())
|
||||
TRAIN_DATASET_SIZE, VAL_DATASET_SIZE = len(get_train_files()), len(get_val_files())
|
||||
SAMPLES_PER_EPOCH = TRAIN_DATASET_SIZE // BS
|
||||
START_EVAL_AT = getenv("START_EVAL_AT", ceil(1000 * TRAIN_DATASET_SIZE / (SAMPLES_PER_EPOCH * BS)))
|
||||
EVALUATE_EVERY = getenv("EVALUATE_EVERY", ceil(20 * TRAIN_DATASET_SIZE / (SAMPLES_PER_EPOCH * BS)))
|
||||
PREPROCESSED_DIR = BASEDIR / ".." / "preprocessed"
|
||||
TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
|
||||
EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)
|
||||
TRAIN_BEAM, EVAL_BEAM = getenv("TRAIN_BEAM", BEAM.value), getenv("EVAL_BEAM", BEAM.value)
|
||||
BENCHMARK = getenv("BENCHMARK")
|
||||
CKPT = getenv("CKPT")
|
||||
|
||||
@@ -453,93 +450,92 @@ def train_unet3d():
|
||||
is_successful, diverged = False, False
|
||||
start_eval_at, evaluate_every = 1 if BENCHMARK else START_EVAL_AT, 1 if BENCHMARK else EVALUATE_EVERY
|
||||
next_eval_at = start_eval_at
|
||||
print(f"Eval starts at epoch {start_eval_at} and every {evaluate_every} epochs afterwards")
|
||||
|
||||
print(f"Training on {GPUS}")
|
||||
print(f"Start evaluation at epoch {start_eval_at} and every {evaluate_every} epoch(s) afterwards")
|
||||
|
||||
if not PREPROCESSED_DIR.exists():
|
||||
preprocess_dataset(get_train_files(), PREPROCESSED_DIR, False)
|
||||
preprocess_dataset(get_val_files(), PREPROCESSED_DIR, True)
|
||||
|
||||
for epoch in range(1, NUM_EPOCHS + 1):
|
||||
BEAM.value = TRAIN_BEAM
|
||||
with Context(BEAM=TRAIN_BEAM):
|
||||
if epoch <= LR_WARMUP_EPOCHS and LR_WARMUP_EPOCHS > 0:
|
||||
lr_warm_up(optim, LR_WARMUP_INIT_LR, LR, epoch, LR_WARMUP_EPOCHS)
|
||||
|
||||
if epoch <= LR_WARMUP_EPOCHS and LR_WARMUP_EPOCHS > 0:
|
||||
lr_warm_up(optim, LR_WARMUP_INIT_LR, LR, epoch, LR_WARMUP_EPOCHS)
|
||||
st = time.perf_counter()
|
||||
|
||||
st = time.perf_counter()
|
||||
for i, (x, y) in enumerate(tqdm(iterate(get_train_files(), preprocessed_dir=PREPROCESSED_DIR, val=False, shuffle=True, bs=BS), total=SAMPLES_PER_EPOCH, desc=f"epoch {epoch}", disable=BENCHMARK), start=1):
|
||||
dt = time.perf_counter()
|
||||
|
||||
for i, (x, y) in enumerate(tqdm(iterate(get_train_files(), preprocessed_dir=PREPROCESSED_DIR, val=False, shuffle=True, bs=BS), total=SAMPLES_PER_EPOCH, desc=f"epoch {epoch}", disable=BENCHMARK), start=1):
|
||||
dt = time.perf_counter()
|
||||
GlobalCounters.reset()
|
||||
|
||||
GlobalCounters.reset()
|
||||
x, y = Tensor(x).realize().shard(GPUS, axis=0), Tensor(y, requires_grad=False).shard(GPUS, axis=0)
|
||||
|
||||
x, y = Tensor(x).realize().shard(GPUS, axis=0), Tensor(y, requires_grad=False).shard(GPUS, axis=0)
|
||||
loss = train_step(model, x, y)
|
||||
pt = time.perf_counter()
|
||||
|
||||
loss = train_step(model, x, y)
|
||||
pt = time.perf_counter()
|
||||
loss = loss.numpy().item()
|
||||
cl = time.perf_counter()
|
||||
|
||||
loss = loss.numpy().item()
|
||||
cl = time.perf_counter()
|
||||
if BENCHMARK: step_times.append(cl - st)
|
||||
|
||||
if BENCHMARK:
|
||||
step_times.append(cl - st)
|
||||
tqdm.write(
|
||||
f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - st) * 1000.0:6.2f} ms fetch data, "
|
||||
f"{loss:5.3f} loss, {optim.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS"
|
||||
)
|
||||
|
||||
tqdm.write(
|
||||
f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - st) * 1000.0:6.2f} ms fetch data, "
|
||||
f"{loss:5.3f} loss, {optim.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS"
|
||||
)
|
||||
if WANDB:
|
||||
wandb.log({"lr": optim.lr.numpy(), "train/loss": loss, "train/step_time": cl - st, "train/python_time": pt - st, "train/data_time": dt - st,
|
||||
"train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": epoch + (i + 1) / SAMPLES_PER_EPOCH})
|
||||
|
||||
if WANDB:
|
||||
wandb.log({"lr": optim.lr.numpy(), "train/loss": loss, "train/step_time": cl - st, "train/python_time": pt - st, "train/data_time": dt - st,
|
||||
"train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": epoch + (i + 1) / SAMPLES_PER_EPOCH})
|
||||
st = cl
|
||||
|
||||
st = cl
|
||||
if i == BENCHMARK:
|
||||
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
|
||||
estimated_total_minutes = int(median_step_time * SAMPLES_PER_EPOCH * NUM_EPOCHS / 60)
|
||||
print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
|
||||
if (TRAIN_BEAM or EVAL_BEAM) and epoch == start_epoch: break
|
||||
return
|
||||
|
||||
if i == BENCHMARK:
|
||||
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
|
||||
estimated_total_minutes = int(median_step_time * SAMPLES_PER_EPOCH * NUM_EPOCHS / 60)
|
||||
print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
|
||||
if (TRAIN_BEAM or EVAL_BEAM) and epoch == start_epoch: break
|
||||
return
|
||||
with Context(BEAM=EVAL_BEAM):
|
||||
if epoch == next_eval_at:
|
||||
train_step.reset()
|
||||
|
||||
if epoch == next_eval_at:
|
||||
train_step.reset()
|
||||
next_eval_at += evaluate_every
|
||||
eval_loss = []
|
||||
scores = []
|
||||
|
||||
BEAM.value = EVAL_BEAM
|
||||
for x, y in tqdm(iterate(get_val_files(), preprocessed_dir=PREPROCESSED_DIR), total=VAL_DATASET_SIZE):
|
||||
eval_loss_value, score = eval_step(model, x, y)
|
||||
eval_loss.append(eval_loss_value)
|
||||
scores.append(score)
|
||||
|
||||
next_eval_at += evaluate_every
|
||||
eval_loss = []
|
||||
scores = []
|
||||
scores = Tensor.mean(Tensor.stack(scores, dim=0), axis=0).numpy()
|
||||
eval_loss = Tensor.mean(Tensor.stack(eval_loss, dim=0), axis=0).numpy()
|
||||
|
||||
for x, y in tqdm(iterate(get_val_files(), preprocessed_dir=PREPROCESSED_DIR), total=VAL_DATASET_SIZE):
|
||||
eval_loss_value, score = eval_step(model, x, y)
|
||||
eval_loss.append(eval_loss_value)
|
||||
scores.append(score)
|
||||
l1_dice, l2_dice = scores[0][-2], scores[0][-1]
|
||||
mean_dice = (l2_dice + l1_dice) / 2
|
||||
|
||||
scores = Tensor.mean(Tensor.stack(scores, dim=0), axis=0).numpy()
|
||||
eval_loss = Tensor.mean(Tensor.stack(eval_loss, dim=0), axis=0).numpy()
|
||||
tqdm.write(f"{l1_dice} L1 dice, {l2_dice} L2 dice, {mean_dice:.3f} mean_dice, {eval_loss:5.2f} eval_loss")
|
||||
|
||||
l1_dice, l2_dice = scores[0][-2], scores[0][-1]
|
||||
mean_dice = (l2_dice + l1_dice) / 2
|
||||
if WANDB:
|
||||
wandb.log({"eval/loss": eval_loss, "eval/mean_dice": mean_dice, "epoch": epoch})
|
||||
|
||||
tqdm.write(f"{l1_dice} L1 dice, {l2_dice} L2 dice, {mean_dice:.3f} mean_dice, {eval_loss:5.2f} eval_loss")
|
||||
if mean_dice >= TARGET_METRIC:
|
||||
is_successful = True
|
||||
save_checkpoint(get_state_dict(model), f"./ckpts/unet3d.safe")
|
||||
elif mean_dice < 1e-6:
|
||||
print("Model diverging. Aborting.")
|
||||
diverged = True
|
||||
|
||||
if WANDB:
|
||||
wandb.log({"eval/loss": eval_loss, "eval/mean_dice": mean_dice, "epoch": epoch})
|
||||
if CKPT:
|
||||
if WANDB and wandb.run is not None:
|
||||
fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}_e{epoch}.safe"
|
||||
else:
|
||||
fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_e{epoch}.safe"
|
||||
|
||||
if mean_dice >= TARGET_METRIC:
|
||||
is_successful = True
|
||||
save_checkpoint(get_state_dict(model), f"./ckpts/unet3d.safe")
|
||||
elif mean_dice < 1e-6:
|
||||
print("Model diverging. Aborting.")
|
||||
diverged = True
|
||||
|
||||
if CKPT:
|
||||
if WANDB and wandb.run is not None:
|
||||
fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}_e{epoch}.safe"
|
||||
else:
|
||||
fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_e{epoch}.safe"
|
||||
|
||||
save_checkpoint(get_state_dict(model), fn)
|
||||
save_checkpoint(get_state_dict(model), fn)
|
||||
|
||||
if is_successful or diverged:
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user