From d16cc6c0123c81db77e1f7c1032ade88edfa5f05 Mon Sep 17 00:00:00 2001
From: wozeparrot
Date: Tue, 2 Sep 2025 15:47:48 -0700
Subject: [PATCH] feat: resume ckpt (#11970)

---
 examples/mlperf/model_train.py | 23 ++++++++++++++++++-----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index 2164a7f8c5..88d1cc9c79 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -4,7 +4,7 @@
 import multiprocessing
 
 from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
 from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, FUSE_CONV_BW, Profiling
-from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
+from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save
 from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW
 from extra.lr_scheduler import LRSchedulerGroup
@@ -1356,6 +1356,15 @@
       b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay)
   scheduler = CosineAnnealingLRWithWarmup(optim, opt_base_learning_rate, opt_end_learning_rate, opt_learning_rate_warmup_steps, opt_learning_rate_decay_steps)
 
+  if resume_ckpt := getenv("RESUME_CKPT"):
+    fn = f"./ckpts/llama3_{resume_ckpt}.safe"
+    print(f"loading initial checkpoint from {fn}")
+    load_state_dict(model, safe_load(fn), realize=False)
+
+    fn = f"./ckpts/llama3_{resume_ckpt}_optim.safe"
+    print(f"loading optim checkpoint from {fn}")
+    load_state_dict(scheduler, safe_load(fn), realize=False)
+
   @TinyJit
   @Tensor.train()
   def train_step(model, tokens:Tensor, grad_acc:int):
@@ -1431,26 +1440,30 @@
     return batch_load_llama3(EVAL_BS, 5760, SEQLEN, BASEDIR, val=True)
 
   iter = get_train_iter()
-  i, sequences_seen = 0, 0
+  i, sequences_seen = resume_ckpt, 0
   for tokens in tqdm(iter, total=SAMPLES//GBS):
     t = time.perf_counter()
     GlobalCounters.reset()
     loss, lr = train_step(model, tokens, grad_acc)
     loss = loss.float().item()
+    i += 1
+    sequences_seen += tokens.shape[0]
+
     tqdm.write(f"{loss:.4f} loss, {lr.item():.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {time.perf_counter()-t:.2f} s")
     if (fname:=getenv("LOSS_FILE", "")):
       with open(fname, "a") as f: f.write(f"{i} {loss:.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
 
-    if getenv("CKPT") and (i % 200 == 0 or i == 10):
+    if (ckpt_freq := getenv("CKPT")) and (i % ckpt_freq == 0 and (i != 1 or ckpt_freq == 1)):
       tqdm.write("saving checkpoint")
       if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir)
       fn = f"{ckpt_dir}/llama3_{i}.safe"
       safe_save(get_state_dict(model), fn)
-    i += 1
-    sequences_seen += tokens.shape[0]
+      tqdm.write("saving optim checkpoint")
+      fn = f"{ckpt_dir}/llama3_{i}_optim.safe"
+      safe_save(get_state_dict(scheduler), fn)
 
     if sequences_seen % EVAL_FREQ == 0 and (i != 1 or EVAL_FREQ == 1):
       tqdm.write(f"evaluating after {sequences_seen} sequences")