From 3a4deb08d238a9ce175b943efca2549cdf8fc8bb Mon Sep 17 00:00:00 2001
From: wozeparrot
Date: Sat, 2 Aug 2025 08:50:18 -0700
Subject: [PATCH] feat: faster index building (#11462)

* feat: faster index building

* feat: correct training samples
---
 examples/mlperf/dataloader.py  | 36 +++++++++++++++++++++-------------
 examples/mlperf/model_train.py |  2 +-
 2 files changed, 23 insertions(+), 15 deletions(-)

diff --git a/examples/mlperf/dataloader.py b/examples/mlperf/dataloader.py
index 13fe3b9405..503ea153b9 100644
--- a/examples/mlperf/dataloader.py
+++ b/examples/mlperf/dataloader.py
@@ -1,6 +1,4 @@
-import functools
-import hashlib
-import os, random, pickle, queue, struct, math
+import os, random, pickle, queue, struct, math, functools, hashlib, time
 from typing import List
 from pathlib import Path
 from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count
@@ -532,21 +530,21 @@ class BinIdxDataset:
 
     start = self.idx.tell()
     end = start + self.count * dtypes.int32.itemsize
-    self.sizes = self.idx_t[start:end].bitcast(dtypes.int32)
+    self.sizes = self.idx_t[start:end].bitcast(dtypes.int32).numpy()
 
     start = end
     end = start + self.count * dtypes.int64.itemsize
-    self.pointers = self.idx_t[start:end].bitcast(dtypes.int64)
+    self.pointers = self.idx_t[start:end].bitcast(dtypes.int64).numpy()
 
     start = end
     end = start + doc_count * dtypes.int64.itemsize
-    self.doc_idx = self.idx_t[start:end].bitcast(dtypes.int64)
+    self.doc_idx = self.idx_t[start:end].bitcast(dtypes.int64).numpy()
 
     # bin file
     self.bin_t = Tensor(base_path.with_name(f"{base_path.name}.bin"))
 
   def _index(self, idx) -> tuple[int, int]:
-    return self.pointers[idx].item(), self.sizes[idx].item()
+    return self.pointers[idx], self.sizes[idx]
 
   def get(self, idx, offset:int=0, length:int|None=None):
     ptr, size = self._index(idx)
@@ -628,14 +626,20 @@ class GPTDataset:
 
   # https://github.com/NVIDIA/Megatron-LM/blob/94bd476bd840c2fd4c3ebfc7448c2af220f4832b/megatron/core/datasets/gpt_dataset.py#L558
   def _build_doc_idx(self):
-    doc_idx = np.mgrid[:self.num_epochs, :self.indexed_dataset.count][1]
-    doc_idx = doc_idx.reshape(-1)
+    print(f"building doc_idx for {self.num_epochs=}, {self.indexed_dataset.count=}")
+    st = time.perf_counter()
+    # doc_idx = np.mgrid[:self.num_epochs, :self.indexed_dataset.count][1]
+    doc_idx = np.arange(self.indexed_dataset.count).reshape(1, -1).repeat(self.num_epochs, axis=0).flatten()
     doc_idx = doc_idx.astype(np.int32)
+    at = time.perf_counter()
     if self.shuffle: self.rng.shuffle(doc_idx)
+    print(f"doc_idx built in {at - st:.3f}s, shuffled in {time.perf_counter() - at:.3f}s")
     return doc_idx
 
   def _build_sample_idx(self):
-    sample_idx = np.empty((self.samples + 1, 2), dtype=np.int32)
+    print(f"building sample_idx for {self.samples=}, {self.seqlen=}, {self.doc_idx.shape[0]=}")
+    sample_idx_max = max(self.doc_idx.shape[0], self.indexed_dataset.sizes.max())
+    sample_idx = np.empty((self.samples + 1, 2), dtype=np.int64 if sample_idx_max > dtypes.int32.max else np.int32)
 
     sample_idx_idx, doc_idx_idx, doc_offset = 0, 0, 0
     sample_idx[sample_idx_idx, 0], sample_idx[sample_idx_idx, 1] = doc_idx_idx, doc_offset
@@ -645,7 +649,7 @@ class GPTDataset:
       remaining_seqlen = self.seqlen + 1
       while remaining_seqlen > 0:
         doc_idx = int(self.doc_idx[doc_idx_idx])
-        doc_len = self.indexed_dataset.sizes[doc_idx].item() - doc_offset
+        doc_len = int(self.indexed_dataset.sizes[doc_idx]) - doc_offset
         remaining_seqlen -= doc_len
         if remaining_seqlen <= 0:
           doc_offset += remaining_seqlen + doc_len - 1
@@ -654,7 +658,7 @@ class GPTDataset:
           if doc_idx_idx == len(self.doc_idx) - 1:
             assert sample_idx_idx == self.samples
             doc_idx = int(self.doc_idx[doc_idx_idx])
-            doc_offset = self.indexed_dataset.sizes[doc_idx].item() - 1
+            doc_offset = int(self.indexed_dataset.sizes[doc_idx]) - 1
             break
           doc_idx_idx += 1
           doc_offset = 0
@@ -665,8 +669,12 @@ class GPTDataset:
     return sample_idx
 
   def _build_shuffle_idx(self):
+    print(f"building shuffle_idx for {self.samples=}")
+    st = time.perf_counter()
     shuffle_idx = np.arange(self.samples, dtype=np.int32)
+    at = time.perf_counter()
     if self.shuffle: self.rng.shuffle(shuffle_idx)
+    print(f"shuffle_idx built in {at - st:.3f}s, shuffled in {time.perf_counter() - at:.3f}s")
     return shuffle_idx
 
 class BlendedGPTDataset:
@@ -739,8 +747,8 @@ if __name__ == "__main__":
 
   def load_llama3(val):
     bs = 24
-    samples = 5760 if val else 1_200_000
-    seqlen = 512
+    samples = 5760 if val else 1_200_000 * 1152
+    seqlen = 8192
 
     max_, min_ = 0, math.inf
     for tokens in tqdm(batch_load_llama3(bs, samples, seqlen, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=5760, val=bool(val)), total=samples//bs):
diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index ae13e091a8..8505fb70fa 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -1296,7 +1296,7 @@ def train_llama3():
   SEED = config["SEED"] = getenv("SEED", 5760)
   SEQLEN = config["SEQLEN"] = getenv("SEQLEN", 8192)
   TRAIN_ON_VAL = config["TRAIN_ON_VAL"] = getenv("TRAIN_ON_VAL", 0)
-  SAMPLES = config["SAMPLES"] = getenv("SAMPLES", 5_760 if TRAIN_ON_VAL else 1_200_000)
+  SAMPLES = config["SAMPLES"] = getenv("SAMPLES", 5_760 if TRAIN_ON_VAL else 1_200_000 * 1152)
 
   # LR=1e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 FUSE_ARANGE=1 JITBEAM=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=512 PYTHONPATH=. AMD=1 AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py
 # trains to 7
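Note on the BinIdxDataset hunks: the idx arrays are decoded to numpy once at load, so _index does plain numpy reads instead of two Tensor .item() calls per lookup, each of which pays a device round trip. A minimal sketch of why that matters, using a hypothetical LazyArray stand-in (not tinygrad's actual Tensor) where each per-element read models a fixed sync cost:

    import time
    import numpy as np

    # hypothetical stand-in: each per-element read pays a fixed sync cost,
    # the way a lazy device tensor's .item() does
    class LazyArray:
      def __init__(self, data): self.data = data
      def item_at(self, i):
        time.sleep(1e-5)  # models kernel launch / copy-back latency
        return int(self.data[i])
      def numpy(self): return self.data  # one bulk materialization

    sizes = LazyArray(np.random.randint(1, 1000, size=1_000_000))

    st = time.perf_counter()
    a = [sizes.item_at(i) for i in range(10_000)]  # sync on every element
    t_item = time.perf_counter() - st

    np_sizes = sizes.numpy()  # pay the copy once
    st = time.perf_counter()
    b = [int(np_sizes[i]) for i in range(10_000)]  # native numpy reads
    t_numpy = time.perf_counter() - st

    assert a == b
    print(f".item()-style: {t_item:.3f}s, bulk numpy: {t_numpy:.3f}s")

The sample_idx construction loop indexes sizes and doc_idx once per document, so the amortized bulk copy is what makes "faster index building" possible at all.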
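The _build_doc_idx rewrite drops np.mgrid, which allocates two full (num_epochs, count) int64 coordinate grids only to discard one, in favor of a single arange repeated per epoch. A standalone check, with hypothetical small sizes, that both forms produce the same index sequence:

    import numpy as np

    num_epochs, count = 4, 10  # hypothetical; real runs use millions of documents

    # old: materializes both coordinate grids, keeps only the second
    old = np.mgrid[:num_epochs, :count][1].reshape(-1).astype(np.int32)

    # new: one arange, repeated once per epoch, then flattened
    new = np.arange(count).reshape(1, -1).repeat(num_epochs, axis=0).flatten().astype(np.int32)

    assert (old == new).all()  # 0..count-1, num_epochs times over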
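The _build_sample_idx change picks the buffer's dtype from the largest value it could hold: column 0 stores positions into doc_idx and column 1 stores offsets within a document, and with the corrected sample count (1_200_000 * 1152 = 1,382,400,000 sequences) the doc_idx positions can exceed the int32 range. A sketch of the same guard with hypothetical magnitudes (doc_idx_len and max_doc_size are made up here):

    import numpy as np

    samples = 1_200_000 * 1152    # corrected sample count from this patch
    doc_idx_len = 3_000_000_000   # hypothetical: num_epochs * document count
    max_doc_size = 60_000         # hypothetical: longest document, in tokens

    # the widest value sample_idx must store is either a position into
    # doc_idx (column 0) or an offset inside one document (column 1)
    sample_idx_max = max(doc_idx_len, max_doc_size)
    dtype = np.int64 if sample_idx_max > np.iinfo(np.int32).max else np.int32
    print(dtype)  # int64 here, since 3e9 > 2,147,483,647; int32 when everything fits
    # the real code then allocates np.empty((samples + 1, 2), dtype=dtype)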