llama wandb logging (#13822)
@@ -1314,6 +1314,13 @@ def train_llama3():
   opt_base_learning_rate = getenv("LR", 8e-5 * GBS / 1152) # NOTE: cannot change for benchmark
   opt_end_learning_rate = getenv("END_LR", 8e-7)
 
+  # ** init wandb **
+  WANDB = getenv("WANDB")
+  if WANDB:
+    import wandb
+    wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
+    wandb.init(config=config, **wandb_args, project="MLPerf-LLaMA3")
+
   model_params = MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]
   # vocab_size from the mixtral tokenizer
   if not SMALL: model_params |= {"vocab_size": 32000}
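
Note on the init block added above: passing an explicit run id with resume="must" makes wandb refuse to start a fresh run, so a restarted job appends to the existing run instead of forking a new one. A minimal standalone sketch of the same pattern, assuming wandb is installed; the config dict here is a placeholder for the real training config:

    import os
    import wandb

    # resume the run whose id is in WANDB_RESUME, otherwise start a fresh one
    run_id = os.getenv("WANDB_RESUME", "")
    wandb_args = {"id": run_id, "resume": "must"} if run_id else {}
    wandb.init(config={"opt_base_learning_rate": 8e-5}, **wandb_args, project="MLPerf-LLaMA3")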
@@ -1449,13 +1456,17 @@ def train_llama3():
     sequences_seen += tokens.shape[0]
 
     sec = time.perf_counter()-t
+    mem_gb = GlobalCounters.mem_used / 1e9
+    gflops = GlobalCounters.global_ops / 1e9 / sec
     tqdm.write(
-      f"{i:5} {sec:.2f} s run, {loss:.4f} loss, {lr:.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, "
-      f"{GlobalCounters.global_ops * 1e-9 / sec:9.2f} GFLOPS")
+      f"{i:5} {sec:.2f} s run, {loss:.4f} loss, {lr:.12f} LR, {mem_gb:.2f} GB used, {gflops:9.2f} GFLOPS")
 
     if (fname:=getenv("LOSS_FILE", "")):
       with open(fname, "a") as f:
-        f.write(f"{i} {loss:.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
+        f.write(f"{i} {loss:.4f} {lr:.12f} {mem_gb:.2f}\n")
+
+    if WANDB:
+      wandb.log({"lr": lr, "train/loss": loss, "train/step_time": sec, "train/GFLOPS": gflops, "train/sequences_seen": sequences_seen})
 
     if (ckpt_freq := getenv("CKPT")) and (i % ckpt_freq == 0 and (i != 1 or ckpt_freq == 1)):
       tqdm.write("saving checkpoint")
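
The hoisting of mem_gb and gflops above lets the console line, the LOSS_FILE row, and the wandb call report identical numbers instead of recomputing them three times. For the units, a toy calculation with made-up counter readings (the real ones come from GlobalCounters):

    # hypothetical readings for one training step
    global_ops = 3.0e12              # ops executed this step
    mem_used = 68.5e9                # bytes currently allocated
    sec = 2.0                        # step wall time in seconds

    mem_gb = mem_used / 1e9          # 68.50 GB
    gflops = global_ops / 1e9 / sec  # 1500.00 GFLOPS, i.e. 1.5e12 ops/s
    print(f"{mem_gb:.2f} GB used, {gflops:9.2f} GFLOPS")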
@@ -1481,6 +1492,9 @@ def train_llama3():
 
     tqdm.write(f"eval log perplexity: {log_perplexity:.4f}")
 
+    if WANDB:
+      wandb.log({"eval/log_perplexity": log_perplexity, "eval/sequences_seen": sequences_seen})
+
     if log_perplexity < EVAL_TARGET:
       tqdm.write(f"target achieved after {sequences_seen} sequences")
       if getenv("CKPT"):
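
A note on the metric names in the two wandb.log calls: the train/ and eval/ prefixes group the charts into separate sections in the wandb UI, and sequences_seen is logged under both so each section carries the shared progress counter. Not part of this commit, but if the charts should be plotted against sequences rather than wandb's default step index, wandb's define_metric API can do that; a hedged sketch:

    import wandb

    run = wandb.init(project="MLPerf-LLaMA3")
    # plot every train/* chart against train/sequences_seen and every
    # eval/* chart against eval/sequences_seen instead of the step index
    run.define_metric("train/sequences_seen")
    run.define_metric("eval/sequences_seen")
    run.define_metric("train/*", step_metric="train/sequences_seen")
    run.define_metric("eval/*", step_metric="eval/sequences_seen")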