llama wandb logging (#13822)

Author: chenyu
Date: 2025-12-24 10:24:59 -05:00
Committed by: GitHub
Parent: e3a646dce3
Commit: 903753c60c

@@ -1314,6 +1314,13 @@ def train_llama3():
   opt_base_learning_rate = getenv("LR", 8e-5 * GBS / 1152) # NOTE: cannot change for benchmark
   opt_end_learning_rate = getenv("END_LR", 8e-7)
+  # ** init wandb **
+  WANDB = getenv("WANDB")
+  if WANDB:
+    import wandb
+    wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
+    wandb.init(config=config, **wandb_args, project="MLPerf-LLaMA3")
   model_params = MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]
   # vocab_size from the mixtral tokenizer
   if not SMALL: model_params |= {"vocab_size": 32000}
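Taken out of context, the gating pattern this hunk adds looks like the sketch below: WANDB=1 enables logging, and WANDB_RESUME=<run_id> reattaches to an existing run with resume="must". This is a standalone approximation, not the training script itself: tinygrad's getenv is stood in for by os.environ, and the config payload is illustrative.

# standalone sketch of the env-gated wandb init (getenv approximated, config illustrative)
import os

WANDB = int(os.environ.get("WANDB", "0"))
if WANDB:
  import wandb
  config = {"opt_base_learning_rate": 8e-5}  # illustrative config dict
  # reattach to a prior run only when WANDB_RESUME carries its run id
  wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := os.environ.get("WANDB_RESUME", "")) else {}
  wandb.init(config=config, **wandb_args, project="MLPerf-LLaMA3")

An invocation would then look something like WANDB=1 WANDB_RESUME=<run_id> python examples/mlperf/model_train.py (entry-point path assumed).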
@@ -1449,13 +1456,17 @@ def train_llama3():
+    sequences_seen += tokens.shape[0]
     sec = time.perf_counter()-t
+    mem_gb = GlobalCounters.mem_used / 1e9
+    gflops = GlobalCounters.global_ops / 1e9 / sec
     tqdm.write(
-      f"{i:5} {sec:.2f} s run, {loss:.4f} loss, {lr:.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, "
-      f"{GlobalCounters.global_ops * 1e-9 / sec:9.2f} GFLOPS")
+      f"{i:5} {sec:.2f} s run, {loss:.4f} loss, {lr:.12f} LR, {mem_gb:.2f} GB used, {gflops:9.2f} GFLOPS")
     if (fname:=getenv("LOSS_FILE", "")):
       with open(fname, "a") as f:
-        f.write(f"{i} {loss:.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
+        f.write(f"{i} {loss:.4f} {lr:.12f} {mem_gb:.2f}\n")
+    if WANDB:
+      wandb.log({"lr": lr, "train/loss": loss, "train/step_time": sec, "train/GFLOPS": gflops, "train/sequences_seen": sequences_seen})
     if (ckpt_freq := getenv("CKPT")) and (i % ckpt_freq == 0 and (i != 1 or ckpt_freq == 1)):
       tqdm.write("saving checkpoint")
@@ -1481,6 +1492,9 @@ def train_llama3():
tqdm.write(f"eval log perplexity: {log_perplexity:.4f}")
if WANDB:
wandb.log({"eval/log_perplexity": log_perplexity, "eval/sequences_seen": sequences_seen})
if log_perplexity < EVAL_TARGET:
tqdm.write(f"target achieved after {sequences_seen} sequences")
if getenv("CKPT"):
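Since both the train/ and eval/ metrics are keyed by sequences_seen, convergence toward EVAL_TARGET can be inspected after the fact from the logged run. A hypothetical readback via the wandb public API, where the entity and run id are placeholders:

# hypothetical post-hoc readback of the eval curve; not part of this commit
import wandb

api = wandb.Api()
run = api.run("my-entity/MLPerf-LLaMA3/<run_id>")  # placeholder run path
history = run.history(keys=["eval/sequences_seen", "eval/log_perplexity"])
print(history.tail())  # pandas DataFrame of the last few eval points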