mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
feat: faster index building (#11462)
* feat: faster index building * feat: correct training samples
This commit is contained in:
@@ -1,6 +1,4 @@
|
||||
import functools
|
||||
import hashlib
|
||||
import os, random, pickle, queue, struct, math
|
||||
import os, random, pickle, queue, struct, math, functools, hashlib, time
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count
|
||||
@@ -532,21 +530,21 @@ class BinIdxDataset:
|
||||
|
||||
start = self.idx.tell()
|
||||
end = start + self.count * dtypes.int32.itemsize
|
||||
self.sizes = self.idx_t[start:end].bitcast(dtypes.int32)
|
||||
self.sizes = self.idx_t[start:end].bitcast(dtypes.int32).numpy()
|
||||
|
||||
start = end
|
||||
end = start + self.count * dtypes.int64.itemsize
|
||||
self.pointers = self.idx_t[start:end].bitcast(dtypes.int64)
|
||||
self.pointers = self.idx_t[start:end].bitcast(dtypes.int64).numpy()
|
||||
|
||||
start = end
|
||||
end = start + doc_count * dtypes.int64.itemsize
|
||||
self.doc_idx = self.idx_t[start:end].bitcast(dtypes.int64)
|
||||
self.doc_idx = self.idx_t[start:end].bitcast(dtypes.int64).numpy()
|
||||
|
||||
# bin file
|
||||
self.bin_t = Tensor(base_path.with_name(f"{base_path.name}.bin"))
|
||||
|
||||
def _index(self, idx) -> tuple[int, int]:
|
||||
return self.pointers[idx].item(), self.sizes[idx].item()
|
||||
return self.pointers[idx], self.sizes[idx]
|
||||
|
||||
def get(self, idx, offset:int=0, length:int|None=None):
|
||||
ptr, size = self._index(idx)
|
||||
@@ -628,14 +626,20 @@ class GPTDataset:
|
||||
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/94bd476bd840c2fd4c3ebfc7448c2af220f4832b/megatron/core/datasets/gpt_dataset.py#L558
|
||||
def _build_doc_idx(self):
|
||||
doc_idx = np.mgrid[:self.num_epochs, :self.indexed_dataset.count][1]
|
||||
doc_idx = doc_idx.reshape(-1)
|
||||
print(f"building doc_idx for {self.num_epochs=}, {self.indexed_dataset.count=}")
|
||||
st = time.perf_counter()
|
||||
# doc_idx = np.mgrid[:self.num_epochs, :self.indexed_dataset.count][1]
|
||||
doc_idx = np.arange(self.indexed_dataset.count).reshape(1, -1).repeat(self.num_epochs, axis=0).flatten()
|
||||
doc_idx = doc_idx.astype(np.int32)
|
||||
at = time.perf_counter()
|
||||
if self.shuffle: self.rng.shuffle(doc_idx)
|
||||
print(f"doc_idx built in {at - st:.3f}s, shuffled in {time.perf_counter() - at:.3f}s")
|
||||
return doc_idx
|
||||
|
||||
def _build_sample_idx(self):
|
||||
sample_idx = np.empty((self.samples + 1, 2), dtype=np.int32)
|
||||
print(f"building sample_idx for {self.samples=}, {self.seqlen=}, {self.doc_idx.shape[0]=}")
|
||||
sample_idx_max = max(self.doc_idx.shape[0], self.indexed_dataset.sizes.max())
|
||||
sample_idx = np.empty((self.samples + 1, 2), dtype=np.int64 if sample_idx_max > dtypes.int32.max else np.int32)
|
||||
|
||||
sample_idx_idx, doc_idx_idx, doc_offset = 0, 0, 0
|
||||
sample_idx[sample_idx_idx, 0], sample_idx[sample_idx_idx, 1] = doc_idx_idx, doc_offset
|
||||
@@ -645,7 +649,7 @@ class GPTDataset:
|
||||
remaining_seqlen = self.seqlen + 1
|
||||
while remaining_seqlen > 0:
|
||||
doc_idx = int(self.doc_idx[doc_idx_idx])
|
||||
doc_len = self.indexed_dataset.sizes[doc_idx].item() - doc_offset
|
||||
doc_len = int(self.indexed_dataset.sizes[doc_idx]) - doc_offset
|
||||
remaining_seqlen -= doc_len
|
||||
if remaining_seqlen <= 0:
|
||||
doc_offset += remaining_seqlen + doc_len - 1
|
||||
@@ -654,7 +658,7 @@ class GPTDataset:
|
||||
if doc_idx_idx == len(self.doc_idx) - 1:
|
||||
assert sample_idx_idx == self.samples
|
||||
doc_idx = int(self.doc_idx[doc_idx_idx])
|
||||
doc_offset = self.indexed_dataset.sizes[doc_idx].item() - 1
|
||||
doc_offset = int(self.indexed_dataset.sizes[doc_idx]) - 1
|
||||
break
|
||||
doc_idx_idx += 1
|
||||
doc_offset = 0
|
||||
@@ -665,8 +669,12 @@ class GPTDataset:
|
||||
return sample_idx
|
||||
|
||||
def _build_shuffle_idx(self):
|
||||
print(f"building shuffle_idx for {self.samples=}")
|
||||
st = time.perf_counter()
|
||||
shuffle_idx = np.arange(self.samples, dtype=np.int32)
|
||||
at = time.perf_counter()
|
||||
if self.shuffle: self.rng.shuffle(shuffle_idx)
|
||||
print(f"shuffle_idx built in {at - st:.3f}s, shuffled in {time.perf_counter() - at:.3f}s")
|
||||
return shuffle_idx
|
||||
|
||||
class BlendedGPTDataset:
|
||||
@@ -739,8 +747,8 @@ if __name__ == "__main__":
|
||||
|
||||
def load_llama3(val):
|
||||
bs = 24
|
||||
samples = 5760 if val else 1_200_000
|
||||
seqlen = 512
|
||||
samples = 5760 if val else 1_200_000 * 1152
|
||||
seqlen = 8192
|
||||
|
||||
max_, min_ = 0, math.inf
|
||||
for tokens in tqdm(batch_load_llama3(bs, samples, seqlen, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=5760, val=bool(val)), total=samples//bs):
|
||||
|
||||
@@ -1296,7 +1296,7 @@ def train_llama3():
|
||||
SEED = config["SEED"] = getenv("SEED", 5760)
|
||||
SEQLEN = config["SEQLEN"] = getenv("SEQLEN", 8192)
|
||||
TRAIN_ON_VAL = config["TRAIN_ON_VAL"] = getenv("TRAIN_ON_VAL", 0)
|
||||
SAMPLES = config["SAMPLES"] = getenv("SAMPLES", 5_760 if TRAIN_ON_VAL else 1_200_000)
|
||||
SAMPLES = config["SAMPLES"] = getenv("SAMPLES", 5_760 if TRAIN_ON_VAL else 1_200_000 * 1152)
|
||||
|
||||
# LR=1e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 FUSE_ARANGE=1 JITBEAM=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=512 PYTHONPATH=. AMD=1 AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py
|
||||
# trains to 7
|
||||
|
||||
Reference in New Issue
Block a user