# Files
# tinygrad/examples/mlperf/model_train.py
# 2024-04-29 15:47:21 -04:00
#
# 520 lines
# 23 KiB
# Python
import os, time, math, functools
from pathlib import Path
from tqdm import tqdm
import multiprocessing
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
from tinygrad.helpers import getenv, BEAM, WINO
from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup
from extra.lr_scheduler import LRSchedulerGroup
from examples.mlperf.helpers import get_training_state, load_training_state
def train_resnet():
  """Train ResNet-50 on ImageNet following the MLPerf reference recipe.

  All configuration comes from environment variables (SEED, GPUS, BS, EPOCHS,
  LR, TARGET, WANDB, CKPT, RESUME, BENCHMARK, ...). Trains for `epochs`
  epochs, evaluating every `eval_epochs`, and stops early once eval top-1
  accuracy reaches `target`. No return value; side effects are stdout logs,
  optional wandb logs, and optional checkpoints under ./ckpts.
  """
  from extra.models import resnet
  from examples.mlperf.dataloader import batch_load_resnet
  from extra.datasets.imagenet import get_train_files, get_val_files
  from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup
  from examples.mlperf.initializers import Conv2dHeNormal, Linear
  from examples.hlb_cifar10 import UnsyncedBatchNorm

  config = {}
  seed = config["seed"] = getenv("SEED", 42)
  Tensor.manual_seed(seed)  # seed for weight initialization

  GPUS = config["GPUS"] = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
  print(f"Training on {GPUS}")
  for x in GPUS: Device[x]  # touch each device name so it is instantiated up front

  # separate BEAM (kernel search) settings for the jitted train and eval steps
  TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
  EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)

  # ** model definition and initializers **
  num_classes = 1000
  # swap in MLPerf-style initializers before the model is constructed
  resnet.Conv2d = Conv2dHeNormal
  resnet.Linear = Linear
  if not getenv("SYNCBN"): resnet.BatchNorm = functools.partial(UnsyncedBatchNorm, num_devices=len(GPUS))
  model = resnet.ResNet50(num_classes)

  # shard weights and initialize in order
  for k, x in get_state_dict(model).items():
    if not getenv("SYNCBN") and ("running_mean" in k or "running_var" in k):
      # UnsyncedBatchNorm keeps per-device running stats, so shard them on axis 0
      x.realize().shard_(GPUS, axis=0)
    else:
      # everything else is replicated on all devices
      x.realize().to_(GPUS)
  parameters = get_parameters(model)

  # ** hyperparameters **
  epochs = config["epochs"] = getenv("EPOCHS", 37)
  BS = config["BS"] = getenv("BS", 104 * len(GPUS))  # fp32 GPUS<=6 7900xtx can fit BS=112
  EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", BS)
  base_lr = config["base_lr"] = getenv("LR", 7 * (BS/1536))  # LR scales linearly with global batch size
  lr_warmup_epochs = config["lr_warmup_epochs"] = getenv("WARMUP_EPOCHS", 2)
  decay = config["decay"] = getenv("DECAY", 5e-5)
  # loss scaling guards fp16 gradients against underflow; 1.0 is a no-op for fp32
  loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 128.0 if dtypes.default_float == dtypes.float16 else 1.0)
  target, achieved = getenv("TARGET", 0.759), False
  eval_start_epoch = getenv("EVAL_START_EPOCH", 0)
  eval_epochs = getenv("EVAL_EPOCHS", 1)

  steps_in_train_epoch = config["steps_in_train_epoch"] = (len(get_train_files()) // BS)
  steps_in_val_epoch = config["steps_in_val_epoch"] = (len(get_val_files()) // EVAL_BS)
  # record environment knobs into config for logging
  config["DEFAULT_FLOAT"] = dtypes.default_float.name
  config["BEAM"] = BEAM.value
  config["TRAIN_BEAM"] = TRAIN_BEAM
  config["EVAL_BEAM"] = EVAL_BEAM
  config["WINO"] = WINO.value
  config["SYNCBN"] = getenv("SYNCBN")

  # ** Optimizer **
  # biases, BN params and downsample BN train with SGD and no weight decay;
  # everything else uses LARS with decay
  skip_list = [v for k, v in get_state_dict(model).items() if "bn" in k or "bias" in k or "downsample.1" in k]
  parameters = [x for x in parameters if x not in set(skip_list)]
  optimizer = LARS(parameters, base_lr, momentum=.9, weight_decay=decay)
  optimizer_skip = SGD(skip_list, base_lr, momentum=.9, weight_decay=0.0, classic=True)
  optimizer_group = OptimizerGroup(optimizer, optimizer_skip)

  # ** LR scheduler ** (one scheduler per optimizer, stepped together)
  scheduler = PolynomialDecayWithWarmup(optimizer, initial_lr=base_lr, end_lr=1e-4,
                                        train_steps=epochs * steps_in_train_epoch,
                                        warmup=lr_warmup_epochs * steps_in_train_epoch)
  scheduler_skip = PolynomialDecayWithWarmup(optimizer_skip, initial_lr=base_lr, end_lr=1e-4,
                                             train_steps=epochs * steps_in_train_epoch,
                                             warmup=lr_warmup_epochs * steps_in_train_epoch)
  scheduler_group = LRSchedulerGroup(scheduler, scheduler_skip)
  print(f"training with batch size {BS} for {epochs} epochs")

  # ** resume from checkpointing **
  start_epoch = 0
  if ckpt:=getenv("RESUME", ""):
    load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt))
    # epoch is recovered from the scheduler's global step counter
    start_epoch = int(scheduler.epoch_counter.numpy().item() / steps_in_train_epoch)
    print(f"resuming from {ckpt} at epoch {start_epoch}")

  # ** init wandb **
  WANDB = getenv("WANDB")
  if WANDB:
    import wandb
    wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
    wandb.init(config=config, **wandb_args)

  BENCHMARK = getenv("BENCHMARK")

  # ** jitted steps **
  input_mean = Tensor([123.68, 116.78, 103.94], device=GPUS, dtype=dtypes.float32).reshape(1, -1, 1, 1)
  # mlperf reference resnet does not divide by input_std for some reason
  # input_std = Tensor([0.229, 0.224, 0.225], device=GPUS, dtype=dtypes.float32).reshape(1, -1, 1, 1)
  def normalize(x): return (x.permute([0, 3, 1, 2]) - input_mean).cast(dtypes.default_float)

  @TinyJit
  def train_step(X, Y):
    """One optimizer step; returns (loss, correct-top-1 count), both realized."""
    optimizer_group.zero_grad()
    X = normalize(X)
    out = model.forward(X)
    loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
    top_1 = (out.argmax(-1) == Y).sum()
    # scale loss up before backward, then scale gradients back down (fp16 safety)
    (loss * loss_scaler).backward()
    for t in optimizer_group.params: t.grad = t.grad.contiguous() / loss_scaler
    optimizer_group.step()
    scheduler_group.step()
    return loss.realize(), top_1.realize()

  @TinyJit
  def eval_step(X, Y):
    """Forward-only step; returns (loss, correct-top-1 count), both realized."""
    X = normalize(X)
    out = model.forward(X)
    loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
    top_1 = (out.argmax(-1) == Y).sum()
    return loss.realize(), top_1.realize()

  def data_get(it):
    # cookie is the loader's third element; it is held until after the step
    # ("return old cookie" below) — presumably it pins the loader buffer, TODO confirm
    x, y, cookie = next(it)
    return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), cookie

  # ** epoch loop **
  step_times = []
  for e in range(start_epoch, epochs):
    # ** train loop **
    Tensor.training = True
    BEAM.value = TRAIN_BEAM
    # per-epoch shuffle seed: deterministic given SEED, different each epoch
    batch_loader = batch_load_resnet(batch_size=BS, val=False, shuffle=True, seed=seed*epochs + e)
    it = iter(tqdm(batch_loader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
    i, proc = 0, data_get(it)
    st = time.perf_counter()
    while proc is not None:
      GlobalCounters.reset()
      # run the step on the current batch, keep only its cookie
      (loss, top_1_acc), proc = train_step(proc[0], proc[1]), proc[2]

      pt = time.perf_counter()

      # prefetch the next batch while the device works on this one
      try:
        next_proc = data_get(it)
      except StopIteration:
        next_proc = None

      dt = time.perf_counter()

      device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
      # .numpy() blocks until the device finishes the step
      loss, top_1_acc = loss.numpy().item(), top_1_acc.numpy().item() / BS

      cl = time.perf_counter()
      if BENCHMARK:
        step_times.append(cl - st)

      tqdm.write(
        f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
        f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {top_1_acc:3.2f} acc, {optimizer.lr.numpy()[0]:.6f} LR, "
        f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
      if WANDB:
        wandb.log({"lr": optimizer.lr.numpy(), "train/loss": loss, "train/top_1_acc": top_1_acc, "train/step_time": cl - st,
                   "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
                   "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": e + (i + 1) / steps_in_train_epoch})

      st = cl
      proc, next_proc = next_proc, None  # return old cookie
      i += 1

      if i == BENCHMARK:
        # benchmark mode: report timing estimates and exit instead of training
        median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2]  # in seconds
        estimated_total_minutes = int(median_step_time * steps_in_train_epoch * epochs / 60)
        print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
        print(f"epoch global_ops: {steps_in_train_epoch * GlobalCounters.global_ops:_}, "
              f"epoch global_mem: {steps_in_train_epoch * GlobalCounters.global_mem:_}")
        # if we are doing beam search, run the first eval too
        if (TRAIN_BEAM or EVAL_BEAM) and e == start_epoch: break
        return

    # ** eval loop **
    if (e + 1 - eval_start_epoch) % eval_epochs == 0 and steps_in_val_epoch > 0:
      if getenv("RESET_STEP", 1): train_step.reset()  # free the train step memory :(
      eval_loss = []
      eval_times = []
      eval_top_1_acc = []
      Tensor.training = False
      BEAM.value = EVAL_BEAM

      it = iter(tqdm(batch_load_resnet(batch_size=EVAL_BS, val=True, shuffle=False), total=steps_in_val_epoch))
      i, proc = 0, data_get(it)
      while proc is not None:
        GlobalCounters.reset()
        st = time.time()

        (loss, top_1_acc), proc = eval_step(proc[0], proc[1]), proc[2]  # drop inputs, keep cookie

        # prefetch, same pattern as the train loop
        try:
          next_proc = data_get(it)
        except StopIteration:
          next_proc = None

        loss, top_1_acc = loss.numpy().item(), top_1_acc.numpy().item() / EVAL_BS
        eval_loss.append(loss)
        eval_top_1_acc.append(top_1_acc)
        proc, next_proc = next_proc, None  # return old cookie
        i += 1
        if i == BENCHMARK: return

        et = time.time()
        eval_times.append(et - st)

      if getenv("RESET_STEP", 1): eval_step.reset()
      total_loss = sum(eval_loss) / len(eval_loss)
      total_top_1 = sum(eval_top_1_acc) / len(eval_top_1_acc)
      total_fw_time = sum(eval_times) / len(eval_times)
      tqdm.write(f"eval loss: {total_loss:.2f}, eval time: {total_fw_time:.2f}, eval top 1 acc: {total_top_1:.3f}")
      if WANDB:
        wandb.log({"eval/loss": total_loss, "eval/top_1_acc": total_top_1, "eval/forward_time": total_fw_time, "epoch": e + 1})

      # save model if achieved target
      if not achieved and total_top_1 >= target:
        if not os.path.exists("./ckpts"): os.mkdir("./ckpts")
        fn = f"./ckpts/resnet50.safe"
        safe_save(get_state_dict(model), fn)
        print(f" *** Model saved to {fn} ***")
        achieved = True
        # stop once achieve the target
        break

      # checkpoint every time we eval
      if getenv("CKPT"):
        if not os.path.exists("./ckpts"): os.mkdir("./ckpts")
        if WANDB and wandb.run is not None:
          fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}_e{e}.safe"
        else:
          fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_e{e}.safe"
        print(f"saving ckpt to {fn}")
        safe_save(get_training_state(model, optimizer_group, scheduler_group), fn)
def train_retinanet():
  """RetinaNet MLPerf training — not implemented yet.

  Exists so the MODEL env-var dispatch in __main__ resolves the name.
  TODO: Retinanet.
  """
def train_unet3d():
  """UNet3D MLPerf training — not implemented yet.

  Exists so the MODEL env-var dispatch in __main__ resolves the name.
  TODO: Unet3d.
  """
def train_rnnt():
  """RNN-T MLPerf training — not implemented yet.

  Exists so the MODEL env-var dispatch in __main__ resolves the name.
  TODO: RNN-T.
  """
def train_bert():
  """Pretrain BERT-large (masked-LM + next-sentence) per the MLPerf reference.

  Runs up to `train_steps` steps over the wiki dataset, evaluating every
  `eval_step_freq` steps and stopping early once masked-LM accuracy reaches
  `target`. Configuration comes from environment variables (BS, TRAIN_STEPS,
  INIT_CKPT_DIR, TARGET, WANDB, CKPT, RESUME, BENCHMARK, ...). No return
  value; side effects are logs, eval_log.txt appends, and checkpoints.
  """
  # NOTE: pip install tensorflow, wandb required
  from examples.mlperf.dataloader import batch_load_train_bert, batch_load_val_bert
  from examples.mlperf.helpers import get_mlperf_bert_model, load_from_tf2_ckpt
  from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup

  config = {}
  BASEDIR = getenv("BASEDIR", Path(__file__).parent.parents[1] / "extra" / "datasets" / "wiki")

  GPUS = config["GPUS"] = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
  print(f"Training on {GPUS}")
  for x in GPUS: Device[x]  # touch each device name so it is instantiated up front
  seed = config["seed"] = getenv("SEED", 12345)

  # ** hyperparameters **
  BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 4 * len(GPUS))  # FP32 4090: 6 GPUS -> BS24
  EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 4 * len(GPUS))
  max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.000004166 * BS)  # scales linearly with BS

  train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 4800000 // BS)
  warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", train_steps // 10)
  max_eval_steps = config["MAX_EVAL_STEPS"] = getenv("MAX_EVAL_STEPS", (10000 + EVAL_BS - 1) // EVAL_BS)  # EVAL_BS * MAX_EVAL_STEPS >= 10000
  eval_step_freq = config["EVAL_STEP_FREQ"] = int((math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000) / BS)  # Round down
  save_ckpt_freq = config["SAVE_CKPT_FREQ"] = getenv("SAVE_CKPT_FREQ", 1000)
  keep_ckpt_amount = config["KEEP_CKPT_AMOUNT"] = getenv("KEEP_CKPT_AMOUNT", 5)
  init_ckpt = config["INIT_CKPT_DIR"] = getenv("INIT_CKPT_DIR", BASEDIR)  # TF2 checkpoint dir for weight init

  decay = config["decay"] = getenv("DECAY", 0.01)
  poly_power = config["poly_power"] = getenv("POLY_POWER", 1.0)

  target, achieved = getenv("TARGET", 0.72), False

  # record environment knobs into config for logging
  config["DEFAULT_FLOAT"] = dtypes.default_float.name
  config["TRAIN_BEAM"] = TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
  config["EVAL_BEAM"] = EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)

  Tensor.manual_seed(seed)  # seed for weight initialization

  model = get_mlperf_bert_model(BASEDIR / "bert_config.json")

  # shard weights and initialize in order
  for tinygrad_key, x in get_state_dict(model).items():
    if init_ckpt and not tinygrad_key.endswith("lm_output.weight"):  # lm_output.weight already is word embedding
      t = load_from_tf2_ckpt(key=tinygrad_key, ckpt_dir=init_ckpt)
      # the TF2 checkpoint stores some dense kernels transposed and/or
      # flattened relative to the tinygrad layout; fix them up per key
      if any(k in tinygrad_key for k in ["intermediate.dense.weight", "output.dense.weight", "clsf_output.weight"]) and "attention" not in tinygrad_key:
        t = t.transpose()
      elif any(k in tinygrad_key for k in ["self", "output.dense", "clsf_pooler", "lm_transform"]) and "weight" in tinygrad_key:
        t = t.reshape(*x.shape).transpose()
      elif all(k in tinygrad_key for k in ["self", "bias"]):
        t = t.reshape(*x.shape)
      x.assign(t).realize().to_(GPUS)
    # replicate on all devices (also covers the no-init-ckpt / lm_output case)
    x.realize().to_(GPUS)
  parameters = get_parameters(model)

  assert 10000 <= (EVAL_BS * max_eval_steps), "Evaluation batchsize * max_eval_steps must greater or equal 10000 to iterate over full eval dataset"

  # ** Log hparams **
  for key, value in config.items():
    print(f'HParam: "{key}": {value}')

  # ** Optimizer **
  # biases and LayerNorm params train without weight decay
  skip_list = [v for k, v in get_state_dict(model).items() if "bias" in k or "LayerNorm" in k]
  parameters = [x for x in parameters if x not in set(skip_list)]
  optimizer = LAMB(parameters, 1 / warmup_steps, eps=1e-6, wd=decay, adam=False)
  optimizer_skip = LAMB(skip_list, 1 / warmup_steps, eps=1e-6, wd=0.0, adam=False)
  optimizer_group = OptimizerGroup(optimizer, optimizer_skip)

  # ** LR scheduler **
  # NOTE(review): only `optimizer` gets a scheduler; optimizer_skip stays at
  # its initial LR — confirm this matches the reference
  scheduler = PolynomialDecayWithWarmup(optimizer, max_lr, 0, train_steps, warmup_steps, power=poly_power)
  print(f"Training with batch size {BS} for one epoch with {train_steps} steps")

  # ** resume from checkpointing **
  start_step = 0
  if ckpt:=getenv("RESUME", ""):
    load_training_state(model, optimizer_group, scheduler, safe_load(ckpt))
    start_step = scheduler.epoch_counter.numpy().item()
    print(f"resuming from {ckpt} at step {start_step}")

  # ** init wandb **
  WANDB = getenv("WANDB")
  if WANDB:
    import wandb
    wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
    wandb.init(config=config, **wandb_args, project="MLPerf-BERT")

  BENCHMARK = getenv("BENCHMARK")

  @TinyJit
  def train_step(input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
    """One step: combined MLM + next-sentence loss, backward, optimizer + LR step."""
    lm_logits, clsf_logits = model(input_ids, segment_ids, attention_mask, masked_positions)
    # NOTE(review): ignore_index is passed the weights tensor — presumably
    # zero-weight (padding) positions are ignored; confirm against the
    # sparse_categorical_crossentropy implementation
    lm_loss = lm_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
    clsf_loss = clsf_logits.binary_crossentropy_logits(next_sentence_labels)
    loss = lm_loss + clsf_loss

    # DISABLE_BACKWARD turns this into a forward-only benchmark step
    if not getenv('DISABLE_BACKWARD', 0):
      optimizer_group.zero_grad()
      loss.backward()
      optimizer_group.step()
      scheduler.step()
    return loss.realize()

  @TinyJit
  def eval_step(input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
    """Forward-only step; returns realized loss/accuracy scalars in a dict."""
    lm_logits, clsf_logits = model(input_ids, segment_ids, attention_mask, masked_positions)

    clsf_predictions = clsf_logits.log_softmax().argmax(-1)
    clsf_accuracy = (clsf_predictions == next_sentence_labels).float().mean()

    mlm_predictions = lm_logits.log_softmax().argmax(-1)
    mask = (masked_lm_weights == 1.0)  # only genuinely masked positions count toward accuracy
    mlm_accuracy = (mlm_predictions == masked_lm_ids).where(mask, 0).sum() / mask.float().sum()

    lm_loss = lm_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
    clsf_loss = clsf_logits.binary_crossentropy_logits(next_sentence_labels)
    return {
      "masked_lm_accuracy": mlm_accuracy.realize(),
      "masked_lm_loss": lm_loss.realize(),
      "next_sentence_accuracy": clsf_accuracy.realize(),
      "next_sentence_loss": clsf_loss.realize()
    }

  def data_get(it):
    # shard every tensor of the batch dict across GPUS, in place
    data: dict[str, Tensor] = next(it)
    for key in data.keys(): data[key].shard_(GPUS, axis=0)
    return data

  eval_it = iter(batch_load_val_bert(EVAL_BS))
  train_it = iter(tqdm(batch_load_train_bert(BS), total=train_steps, disable=BENCHMARK))

  step_times = []
  # ** train loop **
  wc_start = time.perf_counter()  # wall clock for the convergence report
  Tensor.training = True
  BEAM.value = TRAIN_BEAM
  i, train_data = 0, data_get(train_it)
  while train_data is not None and i < train_steps and not achieved:
    st = time.perf_counter()
    GlobalCounters.reset()
    loss = train_step(train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
                      train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"])

    pt = time.perf_counter()

    # prefetch the next batch while the device works on this one
    try:
      next_data = data_get(train_it)
    except StopIteration:
      next_data = None

    dt = time.perf_counter()

    device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
    loss = loss.numpy().item()  # blocks until the device finishes the step

    cl = time.perf_counter()
    if BENCHMARK: step_times.append(cl - st)

    tqdm.write(
      f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
      f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer.lr.numpy()[0]:.6f} LR, "
      f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
    if WANDB:
      wandb.log({"lr": optimizer.lr.numpy(), "train/loss": loss, "train/step_time": cl - st,
                 "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
                 "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st)})

    train_data, next_data = next_data, None
    i += 1

    if i == BENCHMARK:
      # benchmark mode: report timing estimates and exit instead of training
      median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2]  # in seconds
      estimated_total_minutes = int(median_step_time * train_steps / 60)
      print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
      print(f"epoch global_ops: {train_steps * GlobalCounters.global_ops:_}, "
            f"epoch global_mem: {train_steps * GlobalCounters.global_mem:_}")
      return

    # ** eval loop ** (also runs once at step 1 as a sanity check)
    if i % eval_step_freq == 0 or i == 1:
      train_step.reset()  # free the train step memory :(
      eval_loss = []
      eval_accuracy = []
      eval_times = []
      Tensor.training = False
      BEAM.value = EVAL_BEAM

      for _ in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps, disable=BENCHMARK):
        eval_data = data_get(eval_it)
        GlobalCounters.reset()
        st = time.time()
        eval_result: dict[str, Tensor] = eval_step(eval_data["input_ids"], eval_data["segment_ids"], eval_data["input_mask"], eval_data["masked_lm_positions"], \
                                                   eval_data["masked_lm_ids"], eval_data["masked_lm_weights"], eval_data["next_sentence_labels"])

        lm_loss, clsf_loss = eval_result["masked_lm_loss"].numpy().item(), eval_result["next_sentence_loss"].numpy().item()
        mlm_accuracy, clsf_accuracy = eval_result["masked_lm_accuracy"].numpy().item(), eval_result["next_sentence_accuracy"].numpy().item()
        eval_loss.append([lm_loss, clsf_loss])
        eval_accuracy.append([mlm_accuracy, clsf_accuracy])

        et = time.time()
        eval_times.append(et - st)

      eval_step.reset()  # free the eval step memory before resuming training
      Tensor.training = True
      total_lm_loss = sum(pair[0] for pair in eval_loss) / len(eval_loss)
      total_clsf_loss = sum(pair[1] for pair in eval_loss) / len(eval_loss)
      total_lm_accuracy = sum(pair[0] for pair in eval_accuracy) / len(eval_accuracy)
      total_clsf_accuracy = sum(pair[1] for pair in eval_accuracy) / len(eval_accuracy)
      total_fw_time = sum(eval_times) / len(eval_times)
      results = f"eval lm loss: {total_lm_loss:.2f}, eval clsf loss: {total_clsf_loss:.2f}, eval lm accuracy: {total_lm_accuracy:.6f}, \
eval clsf accuracy: {total_clsf_accuracy:.2f}, avg eval step time: {total_fw_time:.2f}"
      tqdm.write(results)
      with open(getenv("EVAL_LOG", "./eval_log.txt"), "a") as file: file.write(results + "\n")

      if WANDB:
        wandb.log({"eval/lm_loss": total_lm_loss, "eval/clsf_loss": total_clsf_loss, "eval/lm_accuracy": total_lm_accuracy, \
                   "eval/clsf_accuracy": total_clsf_accuracy, "eval/forward_time": total_fw_time})

      # save model if achieved target
      if not achieved and total_lm_accuracy >= target:
        wc_end = time.perf_counter()
        if not os.path.exists(ckpt_dir := getenv('CKPT_DIR', "./ckpts")): os.mkdir(ckpt_dir)
        fn = f"{ckpt_dir}/bert-large.safe"
        safe_save(get_state_dict(model), fn)
        print(f" *** Model saved to {fn} ***")

        total_seconds = wc_end - wc_start
        hours = int(total_seconds // 3600)
        minutes = int((total_seconds % 3600) // 60)
        seconds = total_seconds % 60
        print(f"Reference Convergence point reached after {i * BS} datasamples and {hours}h{minutes}m{seconds:.2f}s.")
        achieved = True  # the while condition ends the loop after this iteration

    if getenv("CKPT") and i % save_ckpt_freq == 0:
      if not os.path.exists(ckpt_dir := getenv('CKPT_DIR', "./ckpts")): os.mkdir(ckpt_dir)
      if WANDB and wandb.run is not None:
        fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}.safe"
      else:
        fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}.safe"
      print(f"saving ckpt to {fn}")
      safe_save(get_training_state(model, optimizer_group, scheduler), fn)
      # prune oldest files in ckpt_dir down to keep_ckpt_amount (by mtime)
      ckpt_files = [f for f in os.listdir(ckpt_dir) if os.path.isfile(os.path.join(ckpt_dir, f))]
      ckpt_files.sort(key=lambda x: os.path.getmtime(os.path.join(ckpt_dir, x)))
      while len(ckpt_files) > keep_ckpt_amount:
        last = ckpt_files.pop(0)
        print(f"Removing old ckpt {last}")
        os.remove(os.path.join(ckpt_dir, last))
def train_maskrcnn():
  """Mask R-CNN MLPerf training — not implemented yet.

  Exists so the MODEL env-var dispatch in __main__ resolves the name.
  TODO: Mask RCNN.
  """
if __name__ == "__main__":
  # NOTE(review): presumably 'spawn' is used so dataloader worker processes do
  # not inherit device/JIT state from a fork — confirm
  multiprocessing.set_start_method('spawn')
  with Tensor.train():
    # MODEL selects which train_* functions to run, in the given order;
    # unknown names are silently skipped
    for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
      nm = f"train_{m}"
      if nm in globals():
        print(f"training {m}")
        globals()[nm]()