mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
llama3 eval train (#11706)
This commit is contained in:
@@ -1297,6 +1297,9 @@ def train_llama3():
|
||||
SEQLEN = config["SEQLEN"] = getenv("SEQLEN", 8192)
|
||||
TRAIN_ON_VAL = config["TRAIN_ON_VAL"] = getenv("TRAIN_ON_VAL", 0)
|
||||
SAMPLES = config["SAMPLES"] = getenv("SAMPLES", 5_760 if TRAIN_ON_VAL else 1_200_000 * 1152)
|
||||
EVAL_FREQ = config["EVAL_FREQ"] = getenv("EVAL_FREQ", 46080)
|
||||
EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 16)
|
||||
EVAL_TARGET = config["EVAL_TARGET"] = getenv("EVAL_TARGET", 5.6)
|
||||
|
||||
# LR=1e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 FUSE_ARANGE=1 JITBEAM=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=512 PYTHONPATH=. AMD=1 AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py
|
||||
# trains to 7
|
||||
@@ -1384,16 +1387,40 @@ def train_llama3():
|
||||
loss.realize(lr)
|
||||
return loss, lr
|
||||
|
||||
if getenv("FAKEDATA", 0):
|
||||
def fake_data():
|
||||
for _ in range(SAMPLES // GBS):
|
||||
yield Tensor.randint(GBS, SEQLEN + 1, low=0, high=32000, dtype=dtypes.int32, device=Device.DEFAULT)
|
||||
iter = fake_data()
|
||||
else:
|
||||
from examples.mlperf.dataloader import batch_load_llama3
|
||||
iter = batch_load_llama3(GBS, SAMPLES, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=bool(TRAIN_ON_VAL))
|
||||
@TinyJit
@Tensor.train(False)
def eval_step(model, tokens:Tensor):
  """Run one evaluation forward pass and return the loss as a flattened float tensor.

  `tokens` is sliced so that the model sees every position but the last
  (`tokens[:, :-1]`) and the loss is taken against every position but the
  first (`tokens[:, 1:]`). The `Tensor.train(False)` decorator is applied
  to the whole step.
  """
  # optionally shard the batch across DP devices along axis 0
  dp = getenv("DP", 1)
  if dp > 1:
    dp_devices = tuple(f"{Device.DEFAULT}:{idx}" for idx in range(dp))
    tokens = tokens.shard(dp_devices, 0)

  # optionally place the batch on all MP devices (no shard axis is passed here)
  mp = getenv("MP", 1)
  if mp > 1:
    mp_devices = tuple(f"{Device.DEFAULT}:{idx}" for idx in range(mp))
    tokens = tokens.shard(mp_devices)

  # NOTE(review): temperature=nan presumably makes the model return raw logits instead of sampling -- confirm
  inputs = tokens[:, :-1]
  logits:Tensor = model(inputs, start_pos=0, temperature=math.nan)
  return logits.sparse_categorical_crossentropy(tokens[:, 1:]).flatten().float()
|
||||
|
||||
i = 0
|
||||
# ** data iters **
|
||||
def fake_data(bs, samples):
  """Generate `samples // bs` synthetic batches for runs without a real dataset.

  Each yielded item is `Tensor.randint(bs, SEQLEN + 1, low=0, high=32000)`
  with dtype int32 on `Device.DEFAULT`. Nothing is yielded when
  `samples < bs`.
  """
  num_batches = samples // bs
  yield from (
    Tensor.randint(bs, SEQLEN + 1, low=0, high=32000, dtype=dtypes.int32, device=Device.DEFAULT)
    for _ in range(num_batches)
  )
|
||||
|
||||
def get_train_iter():
  """Return a fresh iterator over training batches.

  When the FAKEDATA env var is truthy, batches come from `fake_data`;
  otherwise they come from `batch_load_llama3`, reading from BASEDIR
  (default "/raid/datasets/c4/"). TRAIN_ON_VAL selects the `val` flag
  passed to the loader.
  """
  if getenv("FAKEDATA", 0):
    return fake_data(GBS, SAMPLES)

  # imported lazily so the dataloader module is only loaded when real data is requested
  from examples.mlperf.dataloader import batch_load_llama3
  dataset_dir = Path(getenv("BASEDIR", "/raid/datasets/c4/"))
  use_val_split = bool(TRAIN_ON_VAL)
  return batch_load_llama3(GBS, SAMPLES, SEQLEN, dataset_dir, seed=SEED, val=use_val_split)
|
||||
|
||||
def get_eval_iter():
  """Return a fresh iterator over evaluation batches of size EVAL_BS.

  The sample count is fixed at 5760. When the FAKEDATA env var is truthy,
  batches come from `fake_data`; otherwise they come from
  `batch_load_llama3` with `val=True`, reading from BASEDIR
  (default "/raid/datasets/c4/").
  """
  if getenv("FAKEDATA", 0):
    return fake_data(EVAL_BS, 5760)

  # imported lazily so the dataloader module is only loaded when real data is requested
  from examples.mlperf.dataloader import batch_load_llama3
  dataset_dir = Path(getenv("BASEDIR", "/raid/datasets/c4/"))
  return batch_load_llama3(EVAL_BS, 5760, SEQLEN, dataset_dir, seed=SEED, val=True)
|
||||
|
||||
iter = get_train_iter()
|
||||
i, sequences_seen = 0, 0
|
||||
for tokens in tqdm(iter, total=SAMPLES//GBS):
|
||||
t = time.perf_counter()
|
||||
GlobalCounters.reset()
|
||||
@@ -1408,9 +1435,33 @@ def train_llama3():
|
||||
if getenv("CKPT") and (i % 200 == 0 or i == 10):
|
||||
tqdm.write("saving checkpoint")
|
||||
if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir)
|
||||
fn = f"{ckpt_dir}/{i}.safe"
|
||||
fn = f"{ckpt_dir}/llama3_{i}.safe"
|
||||
safe_save(get_state_dict(model), fn)
|
||||
|
||||
i += 1
|
||||
sequences_seen += tokens.shape[0]
|
||||
|
||||
if sequences_seen % EVAL_FREQ == 0 and (i != 1 or EVAL_FREQ == 1):
|
||||
tqdm.write(f"evaluating after {sequences_seen} sequences")
|
||||
|
||||
# run eval
|
||||
eval_losses = []
|
||||
eval_iter = get_eval_iter()
|
||||
tqdm.write(f"evaluating {5760//EVAL_BS} batches of {EVAL_BS} sequences")
|
||||
|
||||
for tokens in tqdm(eval_iter, total=5760//EVAL_BS):
|
||||
eval_losses += eval_step(model, tokens).tolist()
|
||||
log_perplexity = Tensor(eval_losses).mean().float().item()
|
||||
|
||||
tqdm.write(f"eval log perplexity: {log_perplexity:.4f}")
|
||||
|
||||
if log_perplexity < EVAL_TARGET:
|
||||
tqdm.write(f"target achieved after {sequences_seen} sequences")
|
||||
if getenv("CKPT"):
|
||||
if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir)
|
||||
fn = f"{ckpt_dir}/llama3.safe"
|
||||
safe_save(get_state_dict(model), fn)
|
||||
break
|
||||
|
||||
if __name__ == "__main__":
|
||||
multiprocessing.set_start_method('spawn')
|
||||
|
||||
Reference in New Issue
Block a user