From 27d899ce9791ed3912b95dfbdc9cff5499e60e8f Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 22 Dec 2025 11:55:46 -0500 Subject: [PATCH] TRAIN=0 to only eval llama (#13804) --- examples/mlperf/model_train.py | 43 +++++++++++++++++----------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 7c2611f7e6..62f0804510 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -1438,33 +1438,34 @@ def train_llama3(): iter = get_train_iter() i, sequences_seen = resume_ckpt, 0 for tokens in tqdm(iter, total=SAMPLES//GBS): - t = time.perf_counter() GlobalCounters.reset() - loss, lr = train_step(model, tokens) - loss = loss.float().item() - lr = lr.item() + if getenv("TRAIN", 1): + t = time.perf_counter() + loss, lr = train_step(model, tokens) + loss = loss.float().item() + lr = lr.item() - i += 1 - sequences_seen += tokens.shape[0] + i += 1 + sequences_seen += tokens.shape[0] - sec = time.perf_counter()-t - tqdm.write( - f"{i:5} {sec:.2f} s run, {loss:.4f} loss, {lr:.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, " - f"{GlobalCounters.global_ops * 1e-9 / sec:9.2f} GFLOPS") + sec = time.perf_counter()-t + tqdm.write( + f"{i:5} {sec:.2f} s run, {loss:.4f} loss, {lr:.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, " + f"{GlobalCounters.global_ops * 1e-9 / sec:9.2f} GFLOPS") - if (fname:=getenv("LOSS_FILE", "")): - with open(fname, "a") as f: - f.write(f"{i} {loss:.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n") + if (fname:=getenv("LOSS_FILE", "")): + with open(fname, "a") as f: + f.write(f"{i} {loss:.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n") - if (ckpt_freq := getenv("CKPT")) and (i % ckpt_freq == 0 and (i != 1 or ckpt_freq == 1)): - tqdm.write("saving checkpoint") - if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir) - fn = f"{ckpt_dir}/llama3_{i}.safe" - safe_save(get_state_dict(model), fn) + if (ckpt_freq := getenv("CKPT")) and (i % ckpt_freq == 0 and (i != 1 or ckpt_freq == 1)): + tqdm.write("saving checkpoint") + if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir) + fn = f"{ckpt_dir}/llama3_{i}.safe" + safe_save(get_state_dict(model), fn) - tqdm.write("saving optim checkpoint") - fn = f"{ckpt_dir}/llama3_{i}_optim.safe" - safe_save(get_state_dict(scheduler), fn) + tqdm.write("saving optim checkpoint") + fn = f"{ckpt_dir}/llama3_{i}_optim.safe" + safe_save(get_state_dict(scheduler), fn) if sequences_seen % EVAL_FREQ == 0 and (i != 1 or EVAL_FREQ == 1): tqdm.write(f"evaluating after {sequences_seen} sequences")