mirror of https://github.com/tinygrad/tinygrad.git
update llama3 (#11446)
`LR=1e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 FUSE_ARANGE=1 JITBEAM=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=512 PYTHONPATH=. AMD=1 AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py` trained to 7
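Every flag in that command is read through tinygrad's `getenv` helper, which casts the environment value to the type of its default, so the whole run is configured from the shell. A minimal sketch of the pattern (not the actual `model_train.py` code; only the defaults visible in the diff below are taken from the source):

```python
# Sketch of the env-var config pattern used in examples/mlperf/model_train.py.
# getenv casts the env value to the type of the default (int here).
from tinygrad.helpers import getenv

config = {}
SEQLEN       = config["SEQLEN"]       = getenv("SEQLEN", 8192)
TRAIN_ON_VAL = config["TRAIN_ON_VAL"] = getenv("TRAIN_ON_VAL", 0)
SAMPLES      = config["SAMPLES"]      = getenv("SAMPLES", 5_760 if TRAIN_ON_VAL else 1_200_000)
print(config)
# TRAIN_ON_VAL=1 SEQLEN=512 python3 sketch.py
# -> {'SEQLEN': 512, 'TRAIN_ON_VAL': 1, 'SAMPLES': 5760}
```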
```diff
@@ -1294,9 +1294,12 @@ def train_llama3():
   grad_acc = config["GRADIENT_ACC_STEPS"] = getenv("GRADIENT_ACC_STEPS", 1)
   GBS = config["GLOBAL_BATCH_SIZE"] = BS * grad_acc
   SEED = config["SEED"] = getenv("SEED", 5760)
-  SAMPLES = config["SAMPLES"] = getenv("SAMPLES", 1_200_000)
   SEQLEN = config["SEQLEN"] = getenv("SEQLEN", 8192)
+  TRAIN_ON_VAL = config["TRAIN_ON_VAL"] = getenv("TRAIN_ON_VAL", 0)
+  SAMPLES = config["SAMPLES"] = getenv("SAMPLES", 5_760 if TRAIN_ON_VAL else 1_200_000)

+  # LR=1e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 FUSE_ARANGE=1 JITBEAM=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=512 PYTHONPATH=. AMD=1 AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py
+  # trains to 7

   opt_adamw_beta_1 = 0.9
   opt_adamw_beta_2 = 0.95
```
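Some quick arithmetic on what the new `TRAIN_ON_VAL` default means for the 1B run in the commit message (the batch size is an assumption here, not taken from the diff; the real default is set elsewhere in `model_train.py`):

```python
# Rough size of the TRAIN_ON_VAL=1 SEQLEN=512 run, using the new SAMPLES default above.
BS      = 8              # assumed for illustration, not from the diff
SAMPLES = 5_760          # default when TRAIN_ON_VAL=1
SEQLEN  = 512            # from the command in the commit message
print(SAMPLES // BS)     # 720 loop iterations (the tqdm total in the next hunk)
print(SAMPLES * SEQLEN)  # 2_949_120 tokens seen over the run
```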
```diff
@@ -1350,17 +1353,13 @@ def train_llama3():
   for tokens in tqdm(iter, total=SAMPLES//BS):
     GlobalCounters.reset()
     loss, lr = train_step(model, tokens)
-    # BS=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=8B WARMUP_STEPS=2 DECAY_STEPS=300 PYTHONPATH=. AMD=1 MODEL=llama3 python3 examples/mlperf/model_train.py
-    # uses 43% ~= 83GB
-    # 8B bf16 = 16GB. model + grad + optim m and v = 64GB
-    # TODO: this OOM
-    # BS=1 SEQLEN=4000 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=8B WARMUP_STEPS=2 DECAY_STEPS=300 PYTHONPATH=. AMD=1 MODEL=llama3 python3 examples/mlperf/model_train.py
-    # above as tqdm.write f-string
     tqdm.write(f"{loss.item():.4f} loss, {lr.item():.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used")
-    with open("loss.txt", "a") as f:
-      f.write(f"{i} {loss.item():.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
-    if i % 200 == 0 or i == 10:
+    if (fname:=getenv("LOSS_FILE", "")):
+      with open(fname, "a") as f:
+        f.write(f"{i} {loss.item():.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
+
+    if getenv("CKPT") and (i % 200 == 0 or i == 10):
       tqdm.write("saving checkpoint")
       if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir)
       fn = f"{ckpt_dir}/{i}.safe"

```
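The memory comment deleted above is straightforward arithmetic; a back-of-the-envelope check (parameter count is approximate, and activations plus temporary buffers presumably account for the gap up to the quoted ~83GB):

```python
# Back-of-the-envelope check of the deleted "# 8B bf16 = 16GB ..." comment.
params  = 8e9            # Llama 3 8B parameter count (approximate)
bf16    = 2              # bytes per bfloat16 value
weights = params * bf16  # ~16 GB of weights
grads   = params * bf16  # ~16 GB of gradients
adam_m  = params * bf16  # AdamW first moment (OPTIM_DTYPE=bfloat16)
adam_v  = params * bf16  # AdamW second moment
print(f"{(weights + grads + adam_m + adam_v) / 1e9:.0f} GB")  # 64 GB, matching the comment
```

The other two changes in this hunk make side effects opt-in: per-step loss logging now only happens when `LOSS_FILE` is set (e.g. `LOSS_FILE=loss.txt`), and checkpoints to `./ckpts/{i}.safe` are only written when `CKPT` is set.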