Revert "feat: faster index building (#11462)" (#11478)

This reverts commit 3a4deb08d2.
Author: chenyu
Date: 2025-08-02 09:50:48 -07:00
Committed by: GitHub
Parent: ef7e01cadf
Commit: f7965f85aa
2 changed files with 15 additions and 23 deletions
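
In short: this revert restores the Tensor-backed index reads in BinIdxDataset (per-lookup .item() instead of a one-shot .numpy() materialization), brings back the np.mgrid construction of doc_idx and the plain int32 sample_idx buffer, drops the timing instrumentation added in #11462, and returns the llama3 SAMPLES default to 1_200_000 (from 1_200_000 * 1152) and the loader test's seqlen to 512 (from 8192).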


@@ -1,4 +1,6 @@
-import os, random, pickle, queue, struct, math, functools, hashlib, time
+import functools
+import hashlib
+import os, random, pickle, queue, struct, math
 from typing import List
 from pathlib import Path
 from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count
@@ -530,21 +532,21 @@ class BinIdxDataset:
     start = self.idx.tell()
     end = start + self.count * dtypes.int32.itemsize
-    self.sizes = self.idx_t[start:end].bitcast(dtypes.int32).numpy()
+    self.sizes = self.idx_t[start:end].bitcast(dtypes.int32)
     start = end
     end = start + self.count * dtypes.int64.itemsize
-    self.pointers = self.idx_t[start:end].bitcast(dtypes.int64).numpy()
+    self.pointers = self.idx_t[start:end].bitcast(dtypes.int64)
     start = end
     end = start + doc_count * dtypes.int64.itemsize
-    self.doc_idx = self.idx_t[start:end].bitcast(dtypes.int64).numpy()
+    self.doc_idx = self.idx_t[start:end].bitcast(dtypes.int64)
     # bin file
     self.bin_t = Tensor(base_path.with_name(f"{base_path.name}.bin"))
   def _index(self, idx) -> tuple[int, int]:
-    return self.pointers[idx], self.sizes[idx]
+    return self.pointers[idx].item(), self.sizes[idx].item()
   def get(self, idx, offset:int=0, length:int|None=None):
     ptr, size = self._index(idx)
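
The hunk above is the core of the revert: #11462 materialized the whole index into numpy once, while the restored code keeps sizes/pointers as Tensors and pays a sync on every _index call via .item(). A minimal standalone sketch of the two lookup styles, assuming tinygrad is installed (toy data and hypothetical names, not this repo's .idx layout):

import numpy as np
from tinygrad import Tensor

# stand-in for self.pointers: a small int64 index
pointers_t = Tensor(np.arange(0, 8000, 8, dtype=np.int64))

# restored style: the index stays a Tensor; each lookup realizes a tiny kernel
ptr = pointers_t[42].item()

# style being reverted: one bulk copy to host memory, then plain numpy indexing
pointers_np = pointers_t.numpy()
assert int(pointers_np[42]) == ptr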
@@ -626,20 +628,14 @@ class GPTDataset:
   # https://github.com/NVIDIA/Megatron-LM/blob/94bd476bd840c2fd4c3ebfc7448c2af220f4832b/megatron/core/datasets/gpt_dataset.py#L558
   def _build_doc_idx(self):
     print(f"building doc_idx for {self.num_epochs=}, {self.indexed_dataset.count=}")
-    st = time.perf_counter()
-    # doc_idx = np.mgrid[:self.num_epochs, :self.indexed_dataset.count][1]
-    doc_idx = np.arange(self.indexed_dataset.count).reshape(1, -1).repeat(self.num_epochs, axis=0).flatten()
+    doc_idx = np.mgrid[:self.num_epochs, :self.indexed_dataset.count][1]
+    doc_idx = doc_idx.reshape(-1)
+    doc_idx = doc_idx.astype(np.int32)
-    at = time.perf_counter()
     if self.shuffle: self.rng.shuffle(doc_idx)
-    print(f"doc_idx built in {at - st:.3f}s, shuffled in {time.perf_counter() - at:.3f}s")
     return doc_idx
   def _build_sample_idx(self):
     print(f"building sample_idx for {self.samples=}, {self.seqlen=}, {self.doc_idx.shape[0]=}")
-    sample_idx_max = max(self.doc_idx.shape[0], self.indexed_dataset.sizes.max())
-    sample_idx = np.empty((self.samples + 1, 2), dtype=np.int64 if sample_idx_max > dtypes.int32.max else np.int32)
+    sample_idx = np.empty((self.samples + 1, 2), dtype=np.int32)
     sample_idx_idx, doc_idx_idx, doc_offset = 0, 0, 0
     sample_idx[sample_idx_idx, 0], sample_idx[sample_idx_idx, 1] = doc_idx_idx, doc_offset
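
Both doc_idx constructions yield the same array, num_epochs back-to-back copies of 0..count-1; #11462 swapped np.mgrid for arange/repeat purely for build speed. A quick standalone check with toy values:

import numpy as np

num_epochs, count = 3, 5  # stand-ins for self.num_epochs and indexed_dataset.count
restored = np.mgrid[:num_epochs, :count][1].reshape(-1).astype(np.int32)
reverted = np.arange(count).reshape(1, -1).repeat(num_epochs, axis=0).flatten().astype(np.int32)
assert np.array_equal(restored, reverted)
print(restored)  # [0 1 2 3 4 0 1 2 3 4 0 1 2 3 4]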
@@ -649,7 +645,7 @@ class GPTDataset:
       remaining_seqlen = self.seqlen + 1
       while remaining_seqlen > 0:
         doc_idx = int(self.doc_idx[doc_idx_idx])
-        doc_len = int(self.indexed_dataset.sizes[doc_idx]) - doc_offset
+        doc_len = self.indexed_dataset.sizes[doc_idx].item() - doc_offset
         remaining_seqlen -= doc_len
         if remaining_seqlen <= 0:
           doc_offset += remaining_seqlen + doc_len - 1
@@ -658,7 +654,7 @@ class GPTDataset:
           if doc_idx_idx == len(self.doc_idx) - 1:
             assert sample_idx_idx == self.samples
             doc_idx = int(self.doc_idx[doc_idx_idx])
-            doc_offset = int(self.indexed_dataset.sizes[doc_idx]) - 1
+            doc_offset = self.indexed_dataset.sizes[doc_idx].item() - 1
             break
           doc_idx_idx += 1
           doc_offset = 0
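
For context, the loop these two hunks touch is the Megatron-style sample_idx walk: each training sample packs seqlen + 1 tokens (inputs plus one-shifted targets) from consecutive documents, recording a (doc position, token offset) pair at every sample boundary. A self-contained sketch of that walk with made-up document sizes (the end-of-data branch from the second hunk is omitted):

import numpy as np

sizes = np.array([7, 3, 9, 4, 6])   # hypothetical document lengths in tokens
doc_order = np.arange(len(sizes))   # stands in for self.doc_idx
seqlen, samples = 5, 3

sample_idx = np.empty((samples + 1, 2), dtype=np.int32)
di, off = 0, 0                      # current doc position and token offset
sample_idx[0] = di, off
for s in range(1, samples + 1):
  remaining = seqlen + 1            # +1 because targets are inputs shifted by one
  while remaining > 0:
    doc_len = sizes[doc_order[di]] - off
    remaining -= doc_len
    if remaining <= 0:
      off += remaining + doc_len - 1  # sample ends inside this doc
    else:
      di += 1; off = 0                # sample spills into the next doc
  sample_idx[s] = di, off
print(sample_idx)                   # rows: [[0 0], [0 5], [2 0], [2 5]]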
@@ -669,12 +665,8 @@ class GPTDataset:
     return sample_idx
   def _build_shuffle_idx(self):
     print(f"building shuffle_idx for {self.samples=}")
-    st = time.perf_counter()
     shuffle_idx = np.arange(self.samples, dtype=np.int32)
-    at = time.perf_counter()
     if self.shuffle: self.rng.shuffle(shuffle_idx)
-    print(f"shuffle_idx built in {at - st:.3f}s, shuffled in {time.perf_counter() - at:.3f}s")
     return shuffle_idx
 class BlendedGPTDataset:
@@ -747,8 +739,8 @@ if __name__ == "__main__":
   def load_llama3(val):
     bs = 24
-    samples = 5760 if val else 1_200_000 * 1152
-    seqlen = 8192
+    samples = 5760 if val else 1_200_000
+    seqlen = 512
     max_, min_ = 0, math.inf
     for tokens in tqdm(batch_load_llama3(bs, samples, seqlen, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=5760, val=bool(val)), total=samples//bs):


@@ -1296,7 +1296,7 @@ def train_llama3():
   SEED = config["SEED"] = getenv("SEED", 5760)
   SEQLEN = config["SEQLEN"] = getenv("SEQLEN", 8192)
   TRAIN_ON_VAL = config["TRAIN_ON_VAL"] = getenv("TRAIN_ON_VAL", 0)
-  SAMPLES = config["SAMPLES"] = getenv("SAMPLES", 5_760 if TRAIN_ON_VAL else 1_200_000 * 1152)
+  SAMPLES = config["SAMPLES"] = getenv("SAMPLES", 5_760 if TRAIN_ON_VAL else 1_200_000)
   # LR=1e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 FUSE_ARANGE=1 JITBEAM=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=512 PYTHONPATH=. AMD=1 AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py
   # trains to 7