From da1cb6a9ecf7c4ccc12b0929267419eaf99fdbbb Mon Sep 17 00:00:00 2001
From: chenyu
Date: Wed, 24 Dec 2025 17:42:08 -0500
Subject: [PATCH] update llama dataloader (#13825)

separate creating dataset from iterating over the dataset to not create eval data for each eval
---
 examples/mlperf/dataloader.py  | 54 ++++++++++------------------------
 examples/mlperf/model_eval.py  |  9 ++----
 examples/mlperf/model_train.py | 25 +++++++---------
 3 files changed, 30 insertions(+), 58 deletions(-)

diff --git a/examples/mlperf/dataloader.py b/examples/mlperf/dataloader.py
index 67eae92ce7..9a2fbcec69 100644
--- a/examples/mlperf/dataloader.py
+++ b/examples/mlperf/dataloader.py
@@ -763,48 +763,26 @@ class BlendedGPTDataset:
     return dataset_idx, dataset_sample_idx
 
-def batch_load_llama3(bs:int, samples:int, seqlen:int, base_dir:Path, seed:int=0, val:bool=True):
+def get_llama3_dataset(samples:int, seqlen:int, base_dir:Path, seed:int=0, val:bool=True, small:bool=False) -> BlendedGPTDataset:
+  if small:
+    if val:
+      return BlendedGPTDataset(
+        [base_dir / "c4-validation-91205-samples.en_text_document"], [1.0], samples, seqlen, seed, shuffle=False)
+    return BlendedGPTDataset(
+      [base_dir / "c4-train.en_6_text_document"], [1.0], samples, seqlen, seed, shuffle=True)
   if val:
-    dataset = BlendedGPTDataset([
-      base_dir / "validation" / "c4-validationn-91205-samples.en_text_document",
-    ], [
-      1.0
-    ], samples, seqlen, seed, False)
-  else:
-    dataset = BlendedGPTDataset([
-      base_dir / "c4-train.en_6_text_document",
-      base_dir / "c4-train.en_7_text_document",
-    ], [
-      1.0, 1.0
-    ], samples, seqlen, seed, True)
+    return BlendedGPTDataset(
+      [base_dir / "validation" / "c4-validationn-91205-samples.en_text_document"], [1.0], samples, seqlen, seed, shuffle=False)
+  return BlendedGPTDataset(
+    [base_dir / "c4-train.en_6_text_document", base_dir / "c4-train.en_7_text_document"], [1.0, 1.0], samples, seqlen, seed, shuffle=True)
 
-  for b in range(math.ceil(samples / bs)):
-    batch = []
-    for i in range(bs):
-      tokens = dataset.get(b * bs + i)
-      batch.append(tokens)
+def iterate_llama3_dataset(dataset:BlendedGPTDataset, bs:int):
+  for b in range(math.ceil(dataset.samples / bs)):
+    batch = [dataset.get(b * bs + i) for i in range(bs)]
     yield Tensor.stack(batch, dim=0)
 
-def batch_load_llama3_small(bs:int, samples:int, seqlen:int, base_dir:Path, seed:int=0, val:bool=True):
-  if val:
-    dataset = BlendedGPTDataset([
-      base_dir / "c4-validation-91205-samples.en_text_document",
-    ], [
-      1.0
-    ], samples, seqlen, seed, False)
-  else:
-    dataset = BlendedGPTDataset([
-      base_dir / "c4-train.en_6_text_document",
-    ], [
-      1.0
-    ], samples, seqlen, seed, True)
-
-  for b in range(math.ceil(samples / bs)):
-    batch = []
-    for i in range(bs):
-      tokens = dataset.get(b * bs + i)
-      batch.append(tokens)
-    yield Tensor.stack(batch, dim=0)
+def batch_load_llama3(bs:int, samples:int, seqlen:int, base_dir:Path, seed:int=0, val:bool=True, small:bool=False):
+  return iterate_llama3_dataset(get_llama3_dataset(samples, seqlen, base_dir, seed, val, small), bs)
 
 if __name__ == "__main__":
   def load_unet3d(val):
diff --git a/examples/mlperf/model_eval.py b/examples/mlperf/model_eval.py
index 4a1c1f4e7c..66a0118259 100644
--- a/examples/mlperf/model_eval.py
+++ b/examples/mlperf/model_eval.py
@@ -234,12 +234,9 @@ def eval_llama3():
     loss = logits.sparse_categorical_crossentropy(tokens[:, 1:])
     return loss.flatten().float()
 
-  if SMALL:
-    from examples.mlperf.dataloader import batch_load_llama3_small
-    iter = batch_load_llama3_small(BS, 5760, SEQLEN, BASEDIR, val=True)
-  else:
-    from examples.mlperf.dataloader import batch_load_llama3
-    iter = batch_load_llama3(BS, 5760, SEQLEN, BASEDIR, val=True)
+  from examples.mlperf.dataloader import get_llama3_dataset, iterate_llama3_dataset
+  eval_dataset = get_llama3_dataset(5760, SEQLEN, BASEDIR, val=True, small=bool(SMALL))
+  iter = iterate_llama3_dataset(eval_dataset, BS)
 
   losses = []
   for tokens in tqdm(iter, total=5760//BS):
diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index 67169833b4..919742de5b 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -1424,23 +1424,20 @@ def train_llama3():
     if getenv("FAKEDATA", 0): return fake_data(BS, SAMPLES)
     else:
-      if SMALL:
-        from examples.mlperf.dataloader import batch_load_llama3_small
-        return batch_load_llama3_small(BS, SAMPLES, SEQLEN, BASEDIR, seed=SEED, val=bool(TRAIN_ON_VAL))
-      else:
-        from examples.mlperf.dataloader import batch_load_llama3
-        return batch_load_llama3(BS, SAMPLES, SEQLEN, BASEDIR, seed=SEED, val=bool(TRAIN_ON_VAL))
+      from examples.mlperf.dataloader import batch_load_llama3
+      return batch_load_llama3(BS, SAMPLES, SEQLEN, BASEDIR, seed=SEED, val=bool(TRAIN_ON_VAL), small=bool(SMALL))
+
+  if getenv("FAKEDATA", 0):
+    eval_dataset = None
+  else:
+    from examples.mlperf.dataloader import get_llama3_dataset
+    eval_dataset = get_llama3_dataset(5760, SEQLEN, BASEDIR, val=True, small=bool(SMALL))
 
   def get_eval_iter():
-    if getenv("FAKEDATA", 0):
+    if eval_dataset is None:
       return fake_data(EVAL_BS, 5760)
-    else:
-      if SMALL:
-        from examples.mlperf.dataloader import batch_load_llama3_small
-        return batch_load_llama3_small(EVAL_BS, 5760, SEQLEN, BASEDIR, val=True)
-      else:
-        from examples.mlperf.dataloader import batch_load_llama3
-        return batch_load_llama3(EVAL_BS, 5760, SEQLEN, BASEDIR, val=True)
+    from examples.mlperf.dataloader import iterate_llama3_dataset
+    return iterate_llama3_dataset(eval_dataset, EVAL_BS)
 
   iter = get_train_iter()
   i, sequences_seen = resume_ckpt, 0
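
Usage sketch of the split API introduced by this patch (illustrative only; BASEDIR, SEQLEN and EVAL_BS stand in for the constants model_train.py/model_eval.py already define, and the values below are placeholders):

  from pathlib import Path
  from examples.mlperf.dataloader import get_llama3_dataset, iterate_llama3_dataset

  BASEDIR, SEQLEN, EVAL_BS = Path("/path/to/c4"), 8192, 16  # placeholder values

  # build the eval dataset once; BlendedGPTDataset construction no longer repeats per eval
  eval_dataset = get_llama3_dataset(5760, SEQLEN, BASEDIR, val=True, small=False)

  for _ in range(3):  # e.g. three eval passes reuse the same dataset object
    for tokens in iterate_llama3_dataset(eval_dataset, EVAL_BS):  # fresh generator per pass
      pass  # tokens is the stacked Tensor batch of token ids yielded by the dataloader

Each call to iterate_llama3_dataset only re-walks the existing dataset, yielding ceil(dataset.samples / bs) batches, which is what get_eval_iter in model_train.py now relies on.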