TRAIN=0 to only eval llama (#13804)

This commit is contained in:
chenyu
2025-12-22 11:55:46 -05:00
committed by GitHub
parent 39d962106f
commit 27d899ce97

View File

@@ -1438,33 +1438,34 @@ def train_llama3():
iter = get_train_iter()
i, sequences_seen = resume_ckpt, 0
for tokens in tqdm(iter, total=SAMPLES//GBS):
t = time.perf_counter()
GlobalCounters.reset()
loss, lr = train_step(model, tokens)
loss = loss.float().item()
lr = lr.item()
if getenv("TRAIN", 1):
t = time.perf_counter()
loss, lr = train_step(model, tokens)
loss = loss.float().item()
lr = lr.item()
i += 1
sequences_seen += tokens.shape[0]
i += 1
sequences_seen += tokens.shape[0]
sec = time.perf_counter()-t
tqdm.write(
f"{i:5} {sec:.2f} s run, {loss:.4f} loss, {lr:.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, "
f"{GlobalCounters.global_ops * 1e-9 / sec:9.2f} GFLOPS")
sec = time.perf_counter()-t
tqdm.write(
f"{i:5} {sec:.2f} s run, {loss:.4f} loss, {lr:.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, "
f"{GlobalCounters.global_ops * 1e-9 / sec:9.2f} GFLOPS")
if (fname:=getenv("LOSS_FILE", "")):
with open(fname, "a") as f:
f.write(f"{i} {loss:.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
if (fname:=getenv("LOSS_FILE", "")):
with open(fname, "a") as f:
f.write(f"{i} {loss:.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
if (ckpt_freq := getenv("CKPT")) and (i % ckpt_freq == 0 and (i != 1 or ckpt_freq == 1)):
tqdm.write("saving checkpoint")
if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir)
fn = f"{ckpt_dir}/llama3_{i}.safe"
safe_save(get_state_dict(model), fn)
if (ckpt_freq := getenv("CKPT")) and (i % ckpt_freq == 0 and (i != 1 or ckpt_freq == 1)):
tqdm.write("saving checkpoint")
if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir)
fn = f"{ckpt_dir}/llama3_{i}.safe"
safe_save(get_state_dict(model), fn)
tqdm.write("saving optim checkpoint")
fn = f"{ckpt_dir}/llama3_{i}_optim.safe"
safe_save(get_state_dict(scheduler), fn)
tqdm.write("saving optim checkpoint")
fn = f"{ckpt_dir}/llama3_{i}_optim.safe"
safe_save(get_state_dict(scheduler), fn)
if sequences_seen % EVAL_FREQ == 0 and (i != 1 or EVAL_FREQ == 1):
tqdm.write(f"evaluating after {sequences_seen} sequences")