llama3 eval train (#11706)

This commit is contained in:
wozeparrot
2025-08-20 19:56:35 -04:00
committed by GitHub
parent dbd3b67657
commit b979162c5d

View File

@@ -1297,6 +1297,9 @@ def train_llama3():
SEQLEN = config["SEQLEN"] = getenv("SEQLEN", 8192)
TRAIN_ON_VAL = config["TRAIN_ON_VAL"] = getenv("TRAIN_ON_VAL", 0)
SAMPLES = config["SAMPLES"] = getenv("SAMPLES", 5_760 if TRAIN_ON_VAL else 1_200_000 * 1152)
EVAL_FREQ = config["EVAL_FREQ"] = getenv("EVAL_FREQ", 46080)
EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 16)
EVAL_TARGET = config["EVAL_TARGET"] = getenv("EVAL_TARGET", 5.6)
# LR=1e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 FUSE_ARANGE=1 JITBEAM=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=512 PYTHONPATH=. AMD=1 AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py
# trains to 7
@@ -1384,16 +1387,40 @@ def train_llama3():
loss.realize(lr)
return loss, lr
if getenv("FAKEDATA", 0):
def fake_data():
for _ in range(SAMPLES // GBS):
yield Tensor.randint(GBS, SEQLEN + 1, low=0, high=32000, dtype=dtypes.int32, device=Device.DEFAULT)
iter = fake_data()
else:
from examples.mlperf.dataloader import batch_load_llama3
iter = batch_load_llama3(GBS, SAMPLES, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=bool(TRAIN_ON_VAL))
@TinyJit
@Tensor.train(False)
def eval_step(model, tokens:Tensor):
  """One jitted evaluation step: next-token cross-entropy on a batch.

  tokens: 2-D int tensor of token ids (batch, seqlen+1) — position t is the
  input and position t+1 the target, per the [:, :-1] / [:, 1:] slices below.
  Returns the loss flattened to a 1-D float tensor.
  """
  # Data-parallel: shard the batch axis (axis 0) across DP devices.
  if (DP := getenv("DP", 1)) > 1:
    device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
    tokens = tokens.shard(device, 0)
  # Model-parallel: shard with no axis given, i.e. replicate across MP devices.
  if (MP := getenv("MP", 1)) > 1:
    device = tuple(f"{Device.DEFAULT}:{i}" for i in range(MP))
    tokens = tokens.shard(device)
  # temperature=math.nan — presumably a sentinel asking the model for raw
  # logits instead of sampling; NOTE(review): confirm against the model's forward.
  logits:Tensor = model(tokens[:, :-1], start_pos=0, temperature=math.nan)
  loss = logits.sparse_categorical_crossentropy(tokens[:, 1:])
  return loss.flatten().float()
i = 0
# ** data iters **
def fake_data(bs, samples):
  """Yield `samples // bs` batches of random token ids, shape (bs, SEQLEN + 1).

  Lets the training/eval pipeline run end-to-end without a real dataset.
  """
  remaining = samples // bs
  while remaining > 0:
    remaining -= 1
    yield Tensor.randint(bs, SEQLEN + 1, low=0, high=32000, dtype=dtypes.int32, device=Device.DEFAULT)
def get_train_iter():
  """Return the training batch iterator.

  FAKEDATA=1 selects synthetic random batches; otherwise batches are loaded
  from the C4 dataset under BASEDIR (val split when TRAIN_ON_VAL is set).
  """
  # Guard clause instead of if/else: synthetic data short-circuits early.
  if getenv("FAKEDATA", 0):
    return fake_data(GBS, SAMPLES)
  from examples.mlperf.dataloader import batch_load_llama3
  basedir = Path(getenv("BASEDIR", "/raid/datasets/c4/"))
  return batch_load_llama3(GBS, SAMPLES, SEQLEN, basedir, seed=SEED, val=bool(TRAIN_ON_VAL))
def get_eval_iter():
  """Return the evaluation batch iterator (fixed 5760 sequences).

  FAKEDATA=1 selects synthetic random batches; otherwise the C4 validation
  split (val=True) is loaded from BASEDIR.
  """
  # Guard clause instead of if/else: synthetic data short-circuits early.
  if getenv("FAKEDATA", 0):
    return fake_data(EVAL_BS, 5760)
  from examples.mlperf.dataloader import batch_load_llama3
  basedir = Path(getenv("BASEDIR", "/raid/datasets/c4/"))
  return batch_load_llama3(EVAL_BS, 5760, SEQLEN, basedir, seed=SEED, val=True)
iter = get_train_iter()
i, sequences_seen = 0, 0
for tokens in tqdm(iter, total=SAMPLES//GBS):
t = time.perf_counter()
GlobalCounters.reset()
@@ -1408,9 +1435,33 @@ def train_llama3():
if getenv("CKPT") and (i % 200 == 0 or i == 10):
tqdm.write("saving checkpoint")
if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir)
fn = f"{ckpt_dir}/{i}.safe"
fn = f"{ckpt_dir}/llama3_{i}.safe"
safe_save(get_state_dict(model), fn)
i += 1
sequences_seen += tokens.shape[0]
if sequences_seen % EVAL_FREQ == 0 and (i != 1 or EVAL_FREQ == 1):
tqdm.write(f"evaluating after {sequences_seen} sequences")
# run eval
eval_losses = []
eval_iter = get_eval_iter()
tqdm.write(f"evaluating {5760//EVAL_BS} batches of {EVAL_BS} sequences")
for tokens in tqdm(eval_iter, total=5760//EVAL_BS):
eval_losses += eval_step(model, tokens).tolist()
log_perplexity = Tensor(eval_losses).mean().float().item()
tqdm.write(f"eval log perplexity: {log_perplexity:.4f}")
if log_perplexity < EVAL_TARGET:
tqdm.write(f"target achieved after {sequences_seen} sequences")
if getenv("CKPT"):
if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir)
fn = f"{ckpt_dir}/llama3.safe"
safe_save(get_state_dict(model), fn)
break
if __name__ == "__main__":
multiprocessing.set_start_method('spawn')