feat: resume ckpt (#11970)

This commit is contained in:
wozeparrot
2025-09-02 15:47:48 -07:00
committed by GitHub
parent 1b73993521
commit d16cc6c012

View File

@@ -4,7 +4,7 @@ import multiprocessing
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, FUSE_CONV_BW, Profiling
from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW
from extra.lr_scheduler import LRSchedulerGroup
@@ -1356,6 +1356,15 @@ def train_llama3():
b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay)
scheduler = CosineAnnealingLRWithWarmup(optim, opt_base_learning_rate, opt_end_learning_rate, opt_learning_rate_warmup_steps, opt_learning_rate_decay_steps)
if resume_ckpt := getenv("RESUME_CKPT"):
fn = f"./ckpts/llama3_{resume_ckpt}.safe"
print(f"loading initial checkpoint from {fn}")
load_state_dict(model, safe_load(fn), realize=False)
fn = f"./ckpts/llama3_{resume_ckpt}_optim.safe"
print(f"loading optim checkpoint from {fn}")
load_state_dict(scheduler, safe_load(fn), realize=False)
@TinyJit
@Tensor.train()
def train_step(model, tokens:Tensor, grad_acc:int):
@@ -1431,26 +1440,30 @@ def train_llama3():
return batch_load_llama3(EVAL_BS, 5760, SEQLEN, BASEDIR, val=True)
iter = get_train_iter()
i, sequences_seen = 0, 0
i, sequences_seen = resume_ckpt, 0
for tokens in tqdm(iter, total=SAMPLES//GBS):
t = time.perf_counter()
GlobalCounters.reset()
loss, lr = train_step(model, tokens, grad_acc)
loss = loss.float().item()
i += 1
sequences_seen += tokens.shape[0]
tqdm.write(f"{loss:.4f} loss, {lr.item():.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {time.perf_counter()-t:.2f} s")
if (fname:=getenv("LOSS_FILE", "")):
with open(fname, "a") as f:
f.write(f"{i} {loss:.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
if getenv("CKPT") and (i % 200 == 0 or i == 10):
if (ckpt_freq := getenv("CKPT")) and (i % ckpt_freq == 0 and (i != 1 or ckpt_freq == 1)):
tqdm.write("saving checkpoint")
if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir)
fn = f"{ckpt_dir}/llama3_{i}.safe"
safe_save(get_state_dict(model), fn)
i += 1
sequences_seen += tokens.shape[0]
tqdm.write("saving optim checkpoint")
fn = f"{ckpt_dir}/llama3_{i}_optim.safe"
safe_save(get_state_dict(scheduler), fn)
if sequences_seen % EVAL_FREQ == 0 and (i != 1 or EVAL_FREQ == 1):
tqdm.write(f"evaluating after {sequences_seen} sequences")