From 825b6a25050554d43bef7448f460758a12f3c7eb Mon Sep 17 00:00:00 2001
From: wozeparrot
Date: Wed, 30 Jul 2025 13:27:55 -0700
Subject: [PATCH] feat: llama3 dataloader (#11340)

---
 examples/mlperf/dataloader.py  | 213 ++++++++++++++++++++++++++++++++-
 examples/mlperf/model_eval.py  |  30 ++++-
 examples/mlperf/model_train.py |  35 ++++--
 extra/models/llama.py          |   4 +-
 4 files changed, 267 insertions(+), 15 deletions(-)

diff --git a/examples/mlperf/dataloader.py b/examples/mlperf/dataloader.py
index c01ab48a56..13fe3b9405 100644
--- a/examples/mlperf/dataloader.py
+++ b/examples/mlperf/dataloader.py
@@ -1,4 +1,6 @@
-import os, random, pickle, queue
+import functools
+import hashlib
+import os, random, pickle, queue, struct, math
 from typing import List
 from pathlib import Path
 from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count
@@ -6,6 +8,7 @@ from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu
 import numpy as np
 from tinygrad import dtypes, Tensor
 from tinygrad.helpers import getenv, prod, Context, round_up, tqdm, OSX
+from tinygrad.nn.state import TensorIO
 
 ### ResNet
 
@@ -510,6 +513,202 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
     # happens with BENCHMARK set
     pass
 
+# llama3
+
+class BinIdxDataset:
+  def __init__(self, base_path:Path):
+    self.idx_t = Tensor(base_path.with_name(f"{base_path.name}.idx"))
+    self.idx = TensorIO(self.idx_t)
+
+    # parse idx file
+    magic = self.idx.read(9)
+    assert magic == b"MMIDIDX\x00\x00", "invalid index file format"
+    version, = struct.unpack("<Q", self.idx.read(8))
+    assert version == 1, "invalid index file version"
+    dtype_code, = struct.unpack("<B", self.idx.read(1))
+    self.dtype = {1: dtypes.uint8, 2: dtypes.int8, 3: dtypes.int16, 4: dtypes.int32, 5: dtypes.int64, 6: dtypes.float64, 7: dtypes.float32, 8: dtypes.uint16}[dtype_code]
+    self.count, = struct.unpack("<Q", self.idx.read(8))
+    self.doc_count, = struct.unpack("<Q", self.idx.read(8))
+
+    # the int32 sizes and int64 byte pointers follow the header back to back (the trailing document index is not needed here)
+    self.sizes = np.frombuffer(self.idx.read(self.count * 4), dtype=np.int32)
+    self.pointers = np.frombuffer(self.idx.read(self.count * 8), dtype=np.int64)
+
+    self.bin_t = Tensor(base_path.with_name(f"{base_path.name}.bin"))
+
+  def _index(self, idx) -> tuple[int, int]:
+    return self.pointers[idx].item(), self.sizes[idx].item()
+
+  def get(self, idx, offset:int=0, length:int|None=None):
+    ptr, size = self._index(idx)
+    if length is None: length = size - offset
+    ptr += offset * self.dtype.itemsize
+    return self.bin_t[ptr:ptr+length*self.dtype.itemsize].bitcast(self.dtype).to(None)
+
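BinIdxDataset reads a Megatron-style .bin/.idx pair: the .idx header is the 9-byte magic, a little-endian u64 version, a u8 dtype code, a u64 sequence count, and a u64 document-index length, followed by the int32 sizes and int64 byte pointers that get() uses to slice the raw .bin. A minimal sketch of a writer for that layout, to make the parse above concrete (toy file names and token values are made up, not part of the patch):

import struct
import numpy as np

docs = [np.array([1, 2, 3, 4], dtype=np.uint16), np.array([5, 6, 7], dtype=np.uint16)]
sizes = np.array([len(d) for d in docs], dtype=np.int32)  # token count per sequence
pointers = np.zeros(len(docs), dtype=np.int64)            # byte offset of each sequence in the .bin
pointers[1:] = np.cumsum([d.nbytes for d in docs[:-1]])

with open("toy_text_document.bin", "wb") as f:
  for d in docs: f.write(d.tobytes())

with open("toy_text_document.idx", "wb") as f:
  f.write(b"MMIDIDX\x00\x00")                             # 9-byte magic
  f.write(struct.pack("<Q", 1))                           # version
  f.write(struct.pack("<B", 8))                           # dtype code, 8 = uint16 in the table above
  f.write(struct.pack("<Q", len(docs)))                   # sequence count
  f.write(struct.pack("<Q", len(docs) + 1))               # document index length
  f.write(sizes.tobytes())                                # int32 sizes
  f.write(pointers.tobytes())                             # int64 byte pointers
  f.write(np.arange(len(docs) + 1, dtype=np.int64).tobytes())  # trailing document index (unread by BinIdxDataset)

With this pair on disk, BinIdxDataset(Path("toy_text_document")).get(1) resolves to bytes 8..14 of the .bin, i.e. tokens [5, 6, 7].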
+# https://docs.nvidia.com/megatron-core/developer-guide/latest/api-guide/datasets.html
+class GPTDataset:
+  def __init__(self, base_path:Path, samples:int, seqlen:int, seed:int, shuffle:bool):
+    self.samples, self.seqlen = samples, seqlen
+    self.shuffle = shuffle
+    self.rng = np.random.RandomState(seed)
+
+    self.indexed_dataset = BinIdxDataset(base_path)
+
+    # check for cache
+    cache_hash = hashlib.sha256(f"{samples}:{seqlen}:{seed}:{shuffle}".encode()).hexdigest()
+    cache_path = base_path.with_name(f"{base_path.name}.{cache_hash}.index_cache")
+    if cache_path.exists():
+      with open(cache_path, "rb") as f:
+        self.doc_idx, self.sample_idx, self.shuffle_idx = pickle.load(f)
+    else:
+      self.doc_idx = self._build_doc_idx()
+      self.sample_idx = self._build_sample_idx()
+      self.shuffle_idx = self._build_shuffle_idx()
+      # save cache
+      with open(cache_path, "wb") as f:
+        pickle.dump((self.doc_idx, self.sample_idx, self.shuffle_idx), f)
+
+  def __getitem__(self, idx):
+    if idx is None:
+      text = self._get(0)
+    else:
+      text = self._get(idx)
+
+    return text
+
+  def _get(self, idx):
+    idx = self.shuffle_idx[idx]
+
+    doc_idx_beg, doc_idx_beg_offset = self.sample_idx[idx]
+    doc_idx_end, doc_idx_end_offset = self.sample_idx[idx + 1]
+
+    doc_ids, sample_parts = [], []
+
+    if doc_idx_beg == doc_idx_end:
+      doc_ids.append(self.doc_idx[doc_idx_beg])
+
+      sample_parts.append(
+        self.indexed_dataset.get(
+          int(self.doc_idx[doc_idx_beg]), offset=int(doc_idx_beg_offset), length=int(doc_idx_end_offset - doc_idx_beg_offset + 1)))
+    else:
+      for i in range(doc_idx_beg, doc_idx_end + 1):
+        doc_ids.append(self.doc_idx[i])
+
+        offset = 0 if i > doc_idx_beg else doc_idx_beg_offset
+        length = None if i < doc_idx_end else int(doc_idx_end_offset + 1)
+        sample_parts.append(self.indexed_dataset.get(int(self.doc_idx[i]), offset=int(offset), length=length))
+
+    # concat all parts
+    text = Tensor.cat(*sample_parts)
+
+    return text
+
+  @functools.cached_property
+  def tokens_per_epoch(self) -> int:
+    return sum(self.indexed_dataset.sizes.tolist())
+
+  @functools.cached_property
+  def num_epochs(self) -> int:
+    # we need enough epochs to cover the requested amount of tokens
+    num_epochs = 1
+    num_tokens = self.tokens_per_epoch
+    while num_tokens < self.samples * self.seqlen:
+      num_epochs += 1
+      num_tokens += self.tokens_per_epoch
+    return num_epochs
+
+  # https://github.com/NVIDIA/Megatron-LM/blob/94bd476bd840c2fd4c3ebfc7448c2af220f4832b/megatron/core/datasets/gpt_dataset.py#L558
+  def _build_doc_idx(self):
+    doc_idx = np.mgrid[:self.num_epochs, :self.indexed_dataset.count][1]
+    doc_idx = doc_idx.reshape(-1)
+    doc_idx = doc_idx.astype(np.int32)
+    if self.shuffle: self.rng.shuffle(doc_idx)
+    return doc_idx
+
+  def _build_sample_idx(self):
+    sample_idx = np.empty((self.samples + 1, 2), dtype=np.int32)
+
+    sample_idx_idx, doc_idx_idx, doc_offset = 0, 0, 0
+    sample_idx[sample_idx_idx, 0], sample_idx[sample_idx_idx, 1] = doc_idx_idx, doc_offset
+    sample_idx_idx += 1
+
+    for _ in tqdm(range(1, self.samples + 1)):
+      remaining_seqlen = self.seqlen + 1
+      while remaining_seqlen > 0:
+        doc_idx = int(self.doc_idx[doc_idx_idx])
+        doc_len = self.indexed_dataset.sizes[doc_idx].item() - doc_offset
+        remaining_seqlen -= doc_len
+        if remaining_seqlen <= 0:
+          doc_offset += remaining_seqlen + doc_len - 1
+          remaining_seqlen = 0
+        else:
+          if doc_idx_idx == len(self.doc_idx) - 1:
+            assert sample_idx_idx == self.samples
+            doc_idx = int(self.doc_idx[doc_idx_idx])
+            doc_offset = self.indexed_dataset.sizes[doc_idx].item() - 1
+            break
+          doc_idx_idx += 1
+          doc_offset = 0
+
+      sample_idx[sample_idx_idx, 0], sample_idx[sample_idx_idx, 1] = doc_idx_idx, doc_offset
+      sample_idx_idx += 1
+
+    return sample_idx
+
+  def _build_shuffle_idx(self):
+    shuffle_idx = np.arange(self.samples, dtype=np.int32)
+    if self.shuffle: self.rng.shuffle(shuffle_idx)
+    return shuffle_idx
+
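The sample-index walk is easiest to see on toy numbers. Each sample needs seqlen+1 tokens (inputs plus the shifted-by-one targets), and consecutive samples share one boundary token, so the walk lands on each sample's last token rather than one past it. A pure-Python sketch mirroring _build_sample_idx above (toy document sizes, one epoch, no shuffle):

# record the (document index, token offset) where each (seqlen+1)-token sample begins
sizes, seqlen, n_samples = [5, 3, 4], 4, 2
doc_idx = list(range(len(sizes)))        # one epoch, unshuffled
sample_idx = [(0, 0)]
doc, off = 0, 0
for _ in range(n_samples):
  remaining = seqlen + 1
  while remaining > 0:
    doc_len = sizes[doc_idx[doc]] - off
    remaining -= doc_len
    if remaining <= 0:
      off += remaining + doc_len - 1     # land on this sample's last token
      remaining = 0
    else:
      doc, off = doc + 1, 0              # sample continues into the next document
  sample_idx.append((doc, off))
print(sample_idx)  # [(0, 0), (0, 4), (2, 0)]: sample 0 is doc0[0:5], sample 1 is doc0[4:5] + doc1[0:3] + doc2[0:1]

_get then reads from sample_idx[i] up to and including sample_idx[i+1] (note the +1 on length), concatenating parts whenever a sample straddles documents.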
+class BlendedGPTDataset:
+  def __init__(self, paths:list[Path], weights:list[float], samples:int, seqlen:int, seed:int, shuffle:bool):
+    self.seed = seed
+
+    # normalize weights
+    total_weight = sum(weights)
+    self.weights = [w / total_weight for w in weights]
+
+    self.samples = samples
+    surplus = 0.005
+    samples_per_blend = [math.ceil(math.ceil(self.samples * w) * (1 + surplus)) for w in self.weights]
+
+    self.datasets = [GPTDataset(path, samples_per_blend[i], seqlen, seed + i, shuffle) for i,path in enumerate(paths)]
+
+  def get(self, idx:int):
+    tokens = self.datasets[0][idx]
+    return tokens
+
+def batch_load_llama3(bs:int, samples:int, seqlen:int, base_dir:Path, seed:int=0, val:bool=True):
+  if val:
+    dataset = BlendedGPTDataset([
+      base_dir / "validation" / "c4-validation-91205-samples.en_text_document",
+    ], [
+      1.0
+    ], samples, seqlen, seed, False)
+  else:
+    dataset = BlendedGPTDataset([
+      base_dir / "c4-train.en_6_text_document",
+      base_dir / "c4-train.en_7_text_document",
+    ], [
+      1.0, 1.0
+    ], samples, seqlen, seed, True)
+
+  for b in range(math.ceil(samples / bs)):
+    batch = []
+    for i in range(bs):
+      tokens = dataset.get(b * bs + i)
+      batch.append(tokens)
+    yield Tensor.stack(batch, dim=0)
+
 if __name__ == "__main__":
   def load_unet3d(val):
     assert not val, "validation set is not supported due to different sizes on inputs"
@@ -538,6 +737,18 @@ if __name__ == "__main__":
     for x in batch_load_retinanet(dataset, val, base_dir):
       pbar.update(x[0].shape[0])
 
+  def load_llama3(val):
+    bs = 24
+    samples = 5760 if val else 1_200_000
+    seqlen = 512
+
+    max_, min_ = 0, math.inf
+    for tokens in tqdm(batch_load_llama3(bs, samples, seqlen, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=5760, val=bool(val)), total=samples//bs):
+      max_ = max(max_, tokens.shape[1])
+      min_ = min(min_, tokens.shape[1])
+    print(f"max seq length: {max_}")
+    print(f"min seq length: {min_}")
+
   load_fn_name = f"load_{getenv('MODEL', 'resnet')}"
   if load_fn_name in globals():
     globals()[load_fn_name](getenv("VAL", 1))
diff --git a/examples/mlperf/model_eval.py b/examples/mlperf/model_eval.py
index fa3ca9d7fe..091f9456ec 100644
--- a/examples/mlperf/model_eval.py
+++ b/examples/mlperf/model_eval.py
@@ -1,4 +1,4 @@
-import time
+import time, math
 start = time.perf_counter()
 from pathlib import Path
 import numpy as np
@@ -241,6 +241,34 @@ def eval_mrcnn():
   evaluate_predictions_on_coco(bbox_output, iou_type='bbox')
   evaluate_predictions_on_coco(mask_output, iou_type='segm')
 
+def eval_llama3():
+  from extra.models.llama import Transformer
+  from examples.llama3 import MODEL_PARAMS
+  from tinygrad.helpers import tqdm
+
+  bs = 4
+  sequence_length = 512
+
+  model = Transformer(**(MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]|{"vocab_size": 32000}), max_context=sequence_length, jit=False, disable_kv_cache=True)
+
+  @TinyJit
+  def eval_step(model, tokens):
+    logits:Tensor = model(tokens[:, :-1], start_pos=0, temperature=math.nan)
+    loss = logits.sparse_categorical_crossentropy(tokens[:, 1:])
+    return loss.flatten()
+
+  from examples.mlperf.dataloader import batch_load_llama3
+  iter = batch_load_llama3(bs, 5760, sequence_length, Path(getenv("BASEDIR", "/raid/datasets/c4/")), val=True)
+
+  losses = []
+  for tokens in tqdm(iter, total=5760//bs):
+    GlobalCounters.reset()
+    losses += eval_step(model, tokens).tolist()
+    tqdm.write(f"loss: {np.mean(losses)}")
+
+  log_perplexity = Tensor(losses).mean()
+  print(f"Log Perplexity: {log_perplexity.item()}")
+
 if __name__ == "__main__":
   # inference only
   Tensor.training = False
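eval_step (and train_step below) use the standard next-token objective: the model sees tokens[:, :-1] and is scored against tokens[:, 1:], so position t predicts token t+1, and the mean of those per-token negative log-likelihoods is exactly the log perplexity printed at the end. A standalone sketch with random logits standing in for the model (toy shapes, not from the patch):

from tinygrad import Tensor

tokens = Tensor([[1, 4, 2, 2, 3], [0, 5, 5, 1, 2]])  # toy batch: bs=2, seqlen=5
logits = Tensor.randn(2, 4, 8)                       # stands in for model(tokens[:, :-1]): (bs, seqlen-1, vocab)
loss = logits.sparse_categorical_crossentropy(tokens[:, 1:])  # mean NLL of each next token
print(loss.item())  # log perplexity of the batch; perplexity itself is math.exp(loss.item())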
"8B")]["args"]|{"vocab_size": 32000}), max_context=sequence_length, jit=False, disable_kv_cache=True) + model = Transformer(**(MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]|{"vocab_size": 32000}), max_context=SEQLEN, jit=False, disable_kv_cache=True) optim = AdamW(get_parameters(model), lr=0.0, b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay) @@ -1316,10 +1318,10 @@ def train_llama3(): @TinyJit @Tensor.train() - def train_step(model, x, y): + def train_step(model, tokens): optim.zero_grad() - logits:Tensor = model(x, start_pos=0, temperature=math.nan) - loss = logits.cross_entropy(y) + logits:Tensor = model(tokens[:, :-1], start_pos=0, temperature=math.nan) + loss = logits.sparse_categorical_crossentropy(tokens[:, 1:]) loss.backward() # L2 norm grad clip @@ -1341,18 +1343,29 @@ def train_llama3(): return loss, lr # overfitting this example should give cross_entropy log(BS) - fake_input = Tensor([list(range(getenv("SEQLEN", 10)))], dtype="int16").expand(BS, -1) - fake_label = Tensor(list(range(BS)), dtype="int16") + from examples.mlperf.dataloader import batch_load_llama3 + iter = batch_load_llama3(GBS, SAMPLES, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=False) - for _ in range(100): + i = 0 + for tokens in tqdm(iter, total=SAMPLES//BS): GlobalCounters.reset() - loss, lr = train_step(model, fake_input, fake_label) + loss, lr = train_step(model, tokens) # BS=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=8B WARMUP_STEPS=2 DECAY_STEPS=300 PYTHONPATH=. AMD=1 MODEL=llama3 python3 examples/mlperf/model_train.py # uses 43% ~= 83GB # 8B bf16 = 16GB. model + grad + optim m and v = 64GB # TODO: this OOM # BS=1 SEQLEN=4000 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=8B WARMUP_STEPS=2 DECAY_STEPS=300 PYTHONPATH=. AMD=1 MODEL=llama3 python3 examples/mlperf/model_train.py - print(loss.item(), lr.item(), f"{GlobalCounters.global_mem//10**9=}") + # above as tqdm.write f-string + tqdm.write(f"{loss.item():.4f} loss, {lr.item():.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used") + with open("loss.txt", "a") as f: + f.write(f"{i} {loss.item():.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n") + + if i % 200 == 0 or i == 10: + tqdm.write("saving checkpoint") + if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir) + fn = f"{ckpt_dir}/{i}.safe" + safe_save(get_state_dict(model), fn) + i += 1 if __name__ == "__main__": multiprocessing.set_start_method('spawn') diff --git a/extra/models/llama.py b/extra/models/llama.py index 16e1aea0bd..ecee88f64a 100644 --- a/extra/models/llama.py +++ b/extra/models/llama.py @@ -185,10 +185,10 @@ class Transformer: mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1) if seqlen > 1 else None for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask) - logits = self.output(self.norm(h)).float()[:, -1, :] + logits = self.output(self.norm(h)).float() if math.isnan(temperature): return logits - return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p) + return sample(logits[:, -1, :].flatten(), temperature, top_k, top_p, alpha_f, alpha_p) def __call__(self, tokens:Tensor, start_pos:int, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0): # TODO: better way to handle the first call v.s. the rest?