From 7e68045fb2129f56b4e25e28ce2882ca28874976 Mon Sep 17 00:00:00 2001
From: wozeparrot
Date: Sun, 31 Aug 2025 13:41:47 -0700
Subject: [PATCH] feat: small llama3 training (#11829)

---
 examples/mlperf/dataloader.py  | 21 +++++++++++++++++++
 examples/mlperf/model_eval.py  | 38 ++++++++++++++++++++++++++++----------
 examples/mlperf/model_train.py | 25 ++++++++++++++++-------
 3 files changed, 67 insertions(+), 17 deletions(-)

diff --git a/examples/mlperf/dataloader.py b/examples/mlperf/dataloader.py
index c82c42410d..09fb191539 100644
--- a/examples/mlperf/dataloader.py
+++ b/examples/mlperf/dataloader.py
@@ -758,6 +758,27 @@ def batch_load_llama3(bs:int, samples:int, seqlen:int, base_dir:Path, seed:int=0
       batch.append(tokens)
     yield Tensor.stack(batch, dim=0)
 
+def batch_load_llama3_small(bs:int, samples:int, seqlen:int, base_dir:Path, seed:int=0, val:bool=True):
+  if val:
+    dataset = BlendedGPTDataset([
+      base_dir / "c4-validation-91205-samples.en_text_document",
+    ], [
+      1.0
+    ], samples, seqlen, seed, False)
+  else:
+    dataset = BlendedGPTDataset([
+      base_dir / "c4-train.en_6_text_document",
+    ], [
+      1.0
+    ], samples, seqlen, seed, True)
+
+  for b in range(math.ceil(samples / bs)):
+    batch = []
+    for i in range(bs):
+      tokens = dataset.get(b * bs + i)
+      batch.append(tokens)
+    yield Tensor.stack(batch, dim=0)
+
 if __name__ == "__main__":
   def load_unet3d(val):
     assert not val, "validation set is not supported due to different sizes on inputs"
diff --git a/examples/mlperf/model_eval.py b/examples/mlperf/model_eval.py
index 091f9456ec..b71c290a08 100644
--- a/examples/mlperf/model_eval.py
+++ b/examples/mlperf/model_eval.py
@@ -243,31 +243,49 @@ def eval_mrcnn():
 
 def eval_llama3():
   from extra.models.llama import Transformer
-  from examples.llama3 import MODEL_PARAMS
+  from examples.llama3 import MODEL_PARAMS, load, convert_from_huggingface
   from tinygrad.helpers import tqdm
 
-  bs = 4
-  sequence_length = 512
+  BASEDIR = Path(getenv("BASEDIR", "/raid/datasets/c4/"))
+  BS = getenv("BS", 4)
+  SMALL = getenv("SMALL", 0)
+  SEQLEN = getenv("SEQLEN", 8192)
+  MODEL_PATH = Path(getenv("MODEL_PATH", "/raid/weights/llama31_8b/"))
 
-  model = Transformer(**(MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]|{"vocab_size": 32000}), max_context=sequence_length, jit=False, disable_kv_cache=True)
+  params = MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]
+  params = params | {"vocab_size": 32000} if not SMALL else params
+  if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: params['n_layers'] = llama_layers
+  model = Transformer(**params, max_context=SEQLEN, jit=False, disable_kv_cache=True)
+
+  # load weights
+  weights = load(str(MODEL_PATH / "model.safetensors.index.json"))
+  if "model.embed_tokens.weight" in weights:
+    print("converting from huggingface format")
+    weights = convert_from_huggingface(weights, params["n_layers"], params["n_heads"], params["n_kv_heads"])
+
+  load_state_dict(model, weights, strict=False, consume=True)
 
   @TinyJit
   def eval_step(model, tokens):
     logits:Tensor = model(tokens[:, :-1], start_pos=0, temperature=math.nan)
     loss = logits.sparse_categorical_crossentropy(tokens[:, 1:])
-    return loss.flatten()
+    return loss.flatten().float()
 
-  from examples.mlperf.dataloader import batch_load_llama3
-  iter = batch_load_llama3(bs, 5760, sequence_length, Path(getenv("BASEDIR", "/raid/datasets/c4/")), True)
+  if SMALL:
+    from examples.mlperf.dataloader import batch_load_llama3_small
+    iter = batch_load_llama3_small(BS, 5760, SEQLEN, BASEDIR, val=True)
+  else:
+    from examples.mlperf.dataloader import batch_load_llama3
+    iter = batch_load_llama3(BS, 5760, SEQLEN, BASEDIR, val=True)
 
   losses = []
-  for tokens in tqdm(iter, total=5760//bs):
+  for tokens in tqdm(iter, total=5760//BS):
     GlobalCounters.reset()
     losses += eval_step(model, tokens).tolist()
     tqdm.write(f"loss: {np.mean(losses)}")
 
-  log_perplexity = Tensor(losses).mean()
-  print(f"Log Perplexity: {log_perplexity.item()}")
+  log_perplexity = np.mean(losses)
+  print(f"Log Perplexity: {log_perplexity}")
 
 if __name__ == "__main__":
   # inference only
diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index f3bab81fba..586e93e827 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -1290,12 +1290,14 @@ def train_llama3():
   from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup
 
   config = {}
+  BASEDIR = config["BASEDIR"] = Path(getenv("BASEDIR", "/raid/datasets/c4/"))
   BS = config["BS"] = getenv("BS", 16)
   grad_acc = config["GRADIENT_ACC_STEPS"] = getenv("GRADIENT_ACC_STEPS", 1)
   GBS = config["GLOBAL_BATCH_SIZE"] = BS * grad_acc
   SEED = config["SEED"] = getenv("SEED", 5760)
   SEQLEN = config["SEQLEN"] = getenv("SEQLEN", 8192)
   TRAIN_ON_VAL = config["TRAIN_ON_VAL"] = getenv("TRAIN_ON_VAL", 0)
+  SMALL = config["SMALL"] = getenv("SMALL", 0)
   SAMPLES = config["SAMPLES"] = getenv("SAMPLES", 5_760 if TRAIN_ON_VAL else 1_200_000 * 1152)
   EVAL_FREQ = config["EVAL_FREQ"] = getenv("EVAL_FREQ", 46080)
   EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 16)
@@ -1317,7 +1319,8 @@ def train_llama3():
 
   # TODO: confirm weights are in bf16
   # vocab_size from the mixtral tokenizer
-  params = MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]|{"vocab_size": 32000}
+  params = MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]
+  params = params | {"vocab_size": 32000} if not SMALL else params
   if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: params['n_layers'] = llama_layers
   model = Transformer(**params, max_context=SEQLEN, jit=False, disable_kv_cache=True)
 
@@ -1403,21 +1406,29 @@ def train_llama3():
   # ** data iters **
   def fake_data(bs, samples):
     for _ in range(samples // bs):
-      yield Tensor.randint(bs, SEQLEN + 1, low=0, high=32000, dtype=dtypes.int32, device=Device.DEFAULT)
+      yield Tensor.randint(bs, SEQLEN + 1, low=0, high=params["vocab_size"], dtype=dtypes.int32, device=Device.DEFAULT)
 
   def get_train_iter():
     if getenv("FAKEDATA", 0):
       return fake_data(GBS, SAMPLES)
     else:
-      from examples.mlperf.dataloader import batch_load_llama3
-      return batch_load_llama3(GBS, SAMPLES, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=bool(TRAIN_ON_VAL))
+      if SMALL:
+        from examples.mlperf.dataloader import batch_load_llama3_small
+        return batch_load_llama3_small(GBS, SAMPLES, SEQLEN, BASEDIR, seed=SEED, val=bool(TRAIN_ON_VAL))
+      else:
+        from examples.mlperf.dataloader import batch_load_llama3
+        return batch_load_llama3(GBS, SAMPLES, SEQLEN, BASEDIR, seed=SEED, val=bool(TRAIN_ON_VAL))
 
   def get_eval_iter():
     if getenv("FAKEDATA", 0):
       return fake_data(EVAL_BS, 5760)
     else:
-      from examples.mlperf.dataloader import batch_load_llama3
-      return batch_load_llama3(EVAL_BS, 5760, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=True)
+      if SMALL:
+        from examples.mlperf.dataloader import batch_load_llama3_small
+        return batch_load_llama3_small(EVAL_BS, 5760, SEQLEN, BASEDIR, val=True)
+      else:
+        from examples.mlperf.dataloader import batch_load_llama3
+        return batch_load_llama3(EVAL_BS, 5760, SEQLEN, BASEDIR, val=True)
 
   iter = get_train_iter()
   i, sequences_seen = 0, 0
@@ -1426,7 +1437,7 @@ def train_llama3():
     GlobalCounters.reset()
     loss, lr = train_step(model, tokens, grad_acc)
     loss = loss.float().item()
-    # above as tqdm.write f-string
+    tqdm.write(f"{loss:.4f} loss, {lr.item():.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {time.perf_counter()-t:.2f} s")
 
     if (fname:=getenv("LOSS_FILE", "")):
       with open(fname, "a") as f:
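
Below is a minimal, hypothetical smoke test for the new batch_load_llama3_small generator, not part of the patch itself. It assumes the preprocessed c4 *_text_document files named above exist under BASEDIR, and that each yielded batch is a stacked tensor of token ids covering seqlen+1 positions, matching the tokens[:, :-1] / tokens[:, 1:] split used by eval_step and train_step.

# hypothetical standalone check for batch_load_llama3_small (assumption: c4 files exist under BASEDIR)
from pathlib import Path

from tinygrad.helpers import getenv
from examples.mlperf.dataloader import batch_load_llama3_small

if __name__ == "__main__":
  BASEDIR = Path(getenv("BASEDIR", "/raid/datasets/c4/"))
  BS, SAMPLES, SEQLEN = 4, 16, 512
  for i, tokens in enumerate(batch_load_llama3_small(BS, SAMPLES, SEQLEN, BASEDIR, seed=0, val=True)):
    # each batch is one stacked Tensor of token ids; the train/eval code slices it
    # into inputs (tokens[:, :-1]) and next-token targets (tokens[:, 1:])
    print(i, tokens.shape, tokens.dtype)

Setting SMALL=1 in the environment is what routes model_eval.py and model_train.py through this loader instead of batch_load_llama3, per the diffs above.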