feat: faster index building (#11462)

* feat: faster index building

* feat: correct training samples
wozeparrot authored on 2025-08-02 08:50:18 -07:00, committed by GitHub
parent 8cc2d64edb, commit 3a4deb08d2
2 changed files with 23 additions and 15 deletions

examples/mlperf/dataloader.py

@@ -1,6 +1,4 @@
-import functools
-import hashlib
-import os, random, pickle, queue, struct, math
+import os, random, pickle, queue, struct, math, functools, hashlib, time
 from typing import List
 from pathlib import Path
 from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count
@@ -532,21 +530,21 @@ class BinIdxDataset:
     start = self.idx.tell()
     end = start + self.count * dtypes.int32.itemsize
-    self.sizes = self.idx_t[start:end].bitcast(dtypes.int32)
+    self.sizes = self.idx_t[start:end].bitcast(dtypes.int32).numpy()
     start = end
     end = start + self.count * dtypes.int64.itemsize
-    self.pointers = self.idx_t[start:end].bitcast(dtypes.int64)
+    self.pointers = self.idx_t[start:end].bitcast(dtypes.int64).numpy()
     start = end
     end = start + doc_count * dtypes.int64.itemsize
-    self.doc_idx = self.idx_t[start:end].bitcast(dtypes.int64)
+    self.doc_idx = self.idx_t[start:end].bitcast(dtypes.int64).numpy()
     # bin file
     self.bin_t = Tensor(base_path.with_name(f"{base_path.name}.bin"))
   def _index(self, idx) -> tuple[int, int]:
-    return self.pointers[idx].item(), self.sizes[idx].item()
+    return self.pointers[idx], self.sizes[idx]
   def get(self, idx, offset:int=0, length:int|None=None):
     ptr, size = self._index(idx)
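Note: the `.numpy()` calls above realize the index tensors once at load time, so every subsequent lookup is a plain in-memory read instead of a tensor realization plus `.item()` sync per access. A minimal sketch of the resulting access pattern, with a hypothetical two-document buffer standing in for the real .idx data:

import numpy as np

# toy stand-in for the .idx layout: count int32 sizes followed by count int64
# pointers, mirroring what BinIdxDataset bitcasts out of its index tensor
count = 2
buf = np.array([3, 5], dtype=np.int32).tobytes() + np.array([0, 3], dtype=np.int64).tobytes()
sizes = np.frombuffer(buf[:count * 4], dtype=np.int32)    # decoded once up front...
pointers = np.frombuffer(buf[count * 4:], dtype=np.int64)
ptr, size = pointers[1], sizes[1]                         # ...so each lookup is a cheap array read
print(ptr, size)  # 3 5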
@@ -628,14 +626,20 @@ class GPTDataset:
   # https://github.com/NVIDIA/Megatron-LM/blob/94bd476bd840c2fd4c3ebfc7448c2af220f4832b/megatron/core/datasets/gpt_dataset.py#L558
   def _build_doc_idx(self):
-    doc_idx = np.mgrid[:self.num_epochs, :self.indexed_dataset.count][1]
-    doc_idx = doc_idx.reshape(-1)
+    print(f"building doc_idx for {self.num_epochs=}, {self.indexed_dataset.count=}")
+    st = time.perf_counter()
+    # doc_idx = np.mgrid[:self.num_epochs, :self.indexed_dataset.count][1]
+    doc_idx = np.arange(self.indexed_dataset.count).reshape(1, -1).repeat(self.num_epochs, axis=0).flatten()
     doc_idx = doc_idx.astype(np.int32)
+    at = time.perf_counter()
     if self.shuffle: self.rng.shuffle(doc_idx)
+    print(f"doc_idx built in {at - st:.3f}s, shuffled in {time.perf_counter() - at:.3f}s")
     return doc_idx
   def _build_sample_idx(self):
-    sample_idx = np.empty((self.samples + 1, 2), dtype=np.int32)
+    print(f"building sample_idx for {self.samples=}, {self.seqlen=}, {self.doc_idx.shape[0]=}")
+    sample_idx_max = max(self.doc_idx.shape[0], self.indexed_dataset.sizes.max())
+    sample_idx = np.empty((self.samples + 1, 2), dtype=np.int64 if sample_idx_max > dtypes.int32.max else np.int32)
     sample_idx_idx, doc_idx_idx, doc_offset = 0, 0, 0
     sample_idx[sample_idx_idx, 0], sample_idx[sample_idx_idx, 1] = doc_idx_idx, doc_offset
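Note: `np.mgrid[:E, :C]` materializes a full coordinate grid for every dimension, and the old code threw away all but one of them; `arange` + `repeat` builds only the grid that is kept. A quick equivalence-and-timing sketch with toy sizes:

import numpy as np, time

num_epochs, count = 4, 1_000_000  # toy sizes
t0 = time.perf_counter()
old = np.mgrid[:num_epochs, :count][1].reshape(-1)  # materializes two (4, 1M) grids
t1 = time.perf_counter()
new = np.arange(count).reshape(1, -1).repeat(num_epochs, axis=0).flatten()  # builds one
t2 = time.perf_counter()
assert (old == new).all()  # identical document ordering
print(f"mgrid: {t1 - t0:.3f}s, arange+repeat: {t2 - t1:.3f}s")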
@@ -645,7 +649,7 @@ class GPTDataset:
       remaining_seqlen = self.seqlen + 1
       while remaining_seqlen > 0:
         doc_idx = int(self.doc_idx[doc_idx_idx])
-        doc_len = self.indexed_dataset.sizes[doc_idx].item() - doc_offset
+        doc_len = int(self.indexed_dataset.sizes[doc_idx]) - doc_offset
         remaining_seqlen -= doc_len
         if remaining_seqlen <= 0:
           doc_offset += remaining_seqlen + doc_len - 1
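Note: the int64 fallback added to `_build_sample_idx` above guards against index overflow at the new sample count. A back-of-envelope check using only the literals visible in this diff:

import numpy as np

int32_max = np.iinfo(np.int32).max    # 2_147_483_647
samples = 1_200_000 * 1152            # 1_382_400_000 (the new default)
tokens_needed = samples * (8192 + 1)  # each sample consumes seqlen + 1 tokens
print(f"{samples:_} samples, {tokens_needed:_} tokens, int32 max {int32_max:_}")
# sample_idx rows store (doc_idx_idx, doc_offset); doc_idx_idx can walk past
# int32_max once num_epochs * count grows into the billions, hence the guard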
@@ -654,7 +658,7 @@ class GPTDataset:
         if doc_idx_idx == len(self.doc_idx) - 1:
           assert sample_idx_idx == self.samples
           doc_idx = int(self.doc_idx[doc_idx_idx])
-          doc_offset = self.indexed_dataset.sizes[doc_idx].item() - 1
+          doc_offset = int(self.indexed_dataset.sizes[doc_idx]) - 1
           break
         doc_idx_idx += 1
         doc_offset = 0
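Note: with `sizes` and `pointers` now NumPy arrays (see the bitcast hunk above), indexing yields a NumPy scalar and plain `int(...)` converts it in-process; the old `.item()` path realized a tensor element per document. A one-line illustration with made-up sizes:

import numpy as np

sizes = np.array([5, 7], dtype=np.int32)  # stand-in for indexed_dataset.sizes
doc_len = int(sizes[1])                   # numpy scalar -> Python int, no device round-trip
print(doc_len)  # 7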
@@ -665,8 +669,12 @@ class GPTDataset:
     return sample_idx
   def _build_shuffle_idx(self):
+    print(f"building shuffle_idx for {self.samples=}")
+    st = time.perf_counter()
     shuffle_idx = np.arange(self.samples, dtype=np.int32)
+    at = time.perf_counter()
     if self.shuffle: self.rng.shuffle(shuffle_idx)
+    print(f"shuffle_idx built in {at - st:.3f}s, shuffled in {time.perf_counter() - at:.3f}s")
     return shuffle_idx
 class BlendedGPTDataset:
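Note: the new timing prints separate building the identity permutation from shuffling it; the shuffle is expected to dominate. A tiny standalone repro of the pattern, with a hypothetical sample count:

import numpy as np, time

rng = np.random.default_rng(5760)                    # seed from this diff
st = time.perf_counter()
shuffle_idx = np.arange(10_000_000, dtype=np.int32)  # hypothetical sample count
at = time.perf_counter()
rng.shuffle(shuffle_idx)
print(f"built in {at - st:.3f}s, shuffled in {time.perf_counter() - at:.3f}s")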
@@ -739,8 +747,8 @@ if __name__ == "__main__":
   def load_llama3(val):
     bs = 24
-    samples = 5760 if val else 1_200_000
-    seqlen = 512
+    samples = 5760 if val else 1_200_000 * 1152
+    seqlen = 8192
     max_, min_ = 0, math.inf
     for tokens in tqdm(batch_load_llama3(bs, samples, seqlen, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=5760, val=bool(val)), total=samples//bs):
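Note: the __main__ smoke test now matches the full training configuration. A rough sense of the scale the index machinery must now cover, using the literals from this diff (reading 1152 as a global-batch multiplier is an assumption):

samples = 1_200_000 * 1152   # 1_382_400_000 training samples
seqlen = 8192
print(samples * seqlen)      # ~1.13e13 tokens addressed by the sample index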

examples/mlperf/model_train.py

@@ -1296,7 +1296,7 @@ def train_llama3():
   SEED = config["SEED"] = getenv("SEED", 5760)
   SEQLEN = config["SEQLEN"] = getenv("SEQLEN", 8192)
   TRAIN_ON_VAL = config["TRAIN_ON_VAL"] = getenv("TRAIN_ON_VAL", 0)
-  SAMPLES = config["SAMPLES"] = getenv("SAMPLES", 5_760 if TRAIN_ON_VAL else 1_200_000)
+  SAMPLES = config["SAMPLES"] = getenv("SAMPLES", 5_760 if TRAIN_ON_VAL else 1_200_000 * 1152)
   # LR=1e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 FUSE_ARANGE=1 JITBEAM=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=512 PYTHONPATH=. AMD=1 AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py
   # trains to 7