mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
feat: resume ckpt (#11970)
This commit is contained in:
@@ -4,7 +4,7 @@ import multiprocessing
|
||||
|
||||
from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
|
||||
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, FUSE_CONV_BW, Profiling
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_load, safe_save
|
||||
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam, AdamW
|
||||
|
||||
from extra.lr_scheduler import LRSchedulerGroup
|
||||
@@ -1356,6 +1356,15 @@ def train_llama3():
|
||||
b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay)
|
||||
scheduler = CosineAnnealingLRWithWarmup(optim, opt_base_learning_rate, opt_end_learning_rate, opt_learning_rate_warmup_steps, opt_learning_rate_decay_steps)
|
||||
|
||||
if resume_ckpt := getenv("RESUME_CKPT"):
|
||||
fn = f"./ckpts/llama3_{resume_ckpt}.safe"
|
||||
print(f"loading initial checkpoint from {fn}")
|
||||
load_state_dict(model, safe_load(fn), realize=False)
|
||||
|
||||
fn = f"./ckpts/llama3_{resume_ckpt}_optim.safe"
|
||||
print(f"loading optim checkpoint from {fn}")
|
||||
load_state_dict(scheduler, safe_load(fn), realize=False)
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
def train_step(model, tokens:Tensor, grad_acc:int):
|
||||
@@ -1431,26 +1440,30 @@ def train_llama3():
|
||||
return batch_load_llama3(EVAL_BS, 5760, SEQLEN, BASEDIR, val=True)
|
||||
|
||||
iter = get_train_iter()
|
||||
i, sequences_seen = 0, 0
|
||||
i, sequences_seen = resume_ckpt, 0
|
||||
for tokens in tqdm(iter, total=SAMPLES//GBS):
|
||||
t = time.perf_counter()
|
||||
GlobalCounters.reset()
|
||||
loss, lr = train_step(model, tokens, grad_acc)
|
||||
loss = loss.float().item()
|
||||
|
||||
i += 1
|
||||
sequences_seen += tokens.shape[0]
|
||||
|
||||
tqdm.write(f"{loss:.4f} loss, {lr.item():.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {time.perf_counter()-t:.2f} s")
|
||||
if (fname:=getenv("LOSS_FILE", "")):
|
||||
with open(fname, "a") as f:
|
||||
f.write(f"{i} {loss:.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
|
||||
|
||||
if getenv("CKPT") and (i % 200 == 0 or i == 10):
|
||||
if (ckpt_freq := getenv("CKPT")) and (i % ckpt_freq == 0 and (i != 1 or ckpt_freq == 1)):
|
||||
tqdm.write("saving checkpoint")
|
||||
if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir)
|
||||
fn = f"{ckpt_dir}/llama3_{i}.safe"
|
||||
safe_save(get_state_dict(model), fn)
|
||||
|
||||
i += 1
|
||||
sequences_seen += tokens.shape[0]
|
||||
tqdm.write("saving optim checkpoint")
|
||||
fn = f"{ckpt_dir}/llama3_{i}_optim.safe"
|
||||
safe_save(get_state_dict(scheduler), fn)
|
||||
|
||||
if sequences_seen % EVAL_FREQ == 0 and (i != 1 or EVAL_FREQ == 1):
|
||||
tqdm.write(f"evaluating after {sequences_seen} sequences")
|
||||
|
||||
Reference in New Issue
Block a user