update llama3 (#11446)

`LR=1e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 FUSE_ARANGE=1 JITBEAM=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=512 PYTHONPATH=. AMD=1 AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py` trained to a loss of 7.
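The WARMUP_STEPS/DECAY_STEPS flags suggest a warmup-then-decay learning-rate schedule. As a rough illustration only (the actual schedule lives in model_train.py and is not shown in this diff; the function name and the linear shape here are assumptions):

```python
# hypothetical sketch of a linear warmup + linear decay schedule matching
# LR=1e-4, WARMUP_STEPS=36, DECAY_STEPS=360; model_train.py's real schedule may differ
def lr_at(step: int, lr: float = 1e-4, warmup: int = 36, decay: int = 360) -> float:
  if step < warmup: return lr * (step + 1) / warmup    # linear ramp up to peak LR
  return lr * max(0.0, 1.0 - (step - warmup) / decay)  # linear ramp down to 0

print(lr_at(0), lr_at(35), lr_at(216))  # start of warmup, peak, halfway through decay
```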
Author: chenyu (committed via GitHub)
Date: 2025-07-30 16:34:21 -07:00
Parent: 5fb975351a
Commit: e300451f3a


@@ -1294,9 +1294,12 @@ def train_llama3():
   grad_acc = config["GRADIENT_ACC_STEPS"] = getenv("GRADIENT_ACC_STEPS", 1)
   GBS = config["GLOBAL_BATCH_SIZE"] = BS * grad_acc
   SEED = config["SEED"] = getenv("SEED", 5760)
-  SAMPLES = config["SAMPLES"] = getenv("SAMPLES", 1_200_000)
   SEQLEN = config["SEQLEN"] = getenv("SEQLEN", 8192)
+  TRAIN_ON_VAL = config["TRAIN_ON_VAL"] = getenv("TRAIN_ON_VAL", 0)
+  SAMPLES = config["SAMPLES"] = getenv("SAMPLES", 5_760 if TRAIN_ON_VAL else 1_200_000)
+  # LR=1e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 FUSE_ARANGE=1 JITBEAM=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=512 PYTHONPATH=. AMD=1 AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py
+  # trains to 7
   opt_adamw_beta_1 = 0.9
   opt_adamw_beta_2 = 0.95
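The beta values in this hunk (0.9, 0.95) are common LLM-pretraining settings. A minimal sketch of passing them to tinygrad's AdamW; the `b1`/`b2` keyword names are assumed from tinygrad.nn.optim, not taken from this diff:

```python
from tinygrad import Tensor
from tinygrad.nn.optim import AdamW

# stand-in parameter list; the real call passes the llama3 model's parameters
params = [Tensor.uniform(4, 4, requires_grad=True)]
opt = AdamW(params, lr=1e-4, b1=0.9, b2=0.95)  # betas from opt_adamw_beta_1/2 above
```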
@@ -1350,17 +1353,13 @@ def train_llama3():
   for tokens in tqdm(iter, total=SAMPLES//BS):
     GlobalCounters.reset()
     loss, lr = train_step(model, tokens)
-    # BS=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=8B WARMUP_STEPS=2 DECAY_STEPS=300 PYTHONPATH=. AMD=1 MODEL=llama3 python3 examples/mlperf/model_train.py
-    # uses 43% ~= 83GB
-    # 8B bf16 = 16GB. model + grad + optim m and v = 64GB
-    # TODO: this OOM
-    # BS=1 SEQLEN=4000 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=8B WARMUP_STEPS=2 DECAY_STEPS=300 PYTHONPATH=. AMD=1 MODEL=llama3 python3 examples/mlperf/model_train.py
-    # above as tqdm.write f-string
     tqdm.write(f"{loss.item():.4f} loss, {lr.item():.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used")
-    with open("loss.txt", "a") as f:
-      f.write(f"{i} {loss.item():.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
+    if (fname:=getenv("LOSS_FILE", "")):
+      with open(fname, "a") as f:
+        f.write(f"{i} {loss.item():.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
-    if i % 200 == 0 or i == 10:
+    if getenv("CKPT") and (i % 200 == 0 or i == 10):
       tqdm.write("saving checkpoint")
       if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir)
       fn = f"{ckpt_dir}/{i}.safe"
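The removed memory comment's arithmetic checks out; a back-of-envelope sketch (activations and JIT scratch buffers are not counted, and the "43% ~= 83GB" line implies a device with roughly 192 GB):

```python
# back-of-envelope for "8B bf16 = 16GB. model + grad + optim m and v = 64GB"
params = 8e9                               # 8B parameters
weights = params * 2                       # bfloat16 = 2 bytes -> 16 GB
grads = weights                            # gradients, same dtype -> 16 GB
adam_m = adam_v = weights                  # AdamW moments with OPTIM_DTYPE=bfloat16 -> 16 GB each
total = weights + grads + adam_m + adam_v
print(f"{total / 1e9:.0f} GB")             # 64 GB before activations
```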
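With this change, loss logging and checkpointing are both opt-in via environment variables. An example invocation (the LOSS_FILE name is hypothetical; all other flags are from the commit message):

```
CKPT=1 LOSS_FILE=run1_loss.txt LR=1e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 FUSE_ARANGE=1 JITBEAM=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=512 PYTHONPATH=. AMD=1 AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py
```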