mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
feat: llama3 dataloader (#11340)
@@ -1,4 +1,6 @@
import os, random, pickle, queue
import functools
import hashlib
import os, random, pickle, queue, struct, math
from typing import List
from pathlib import Path
from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count
@@ -6,6 +8,7 @@ from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu
import numpy as np
from tinygrad import dtypes, Tensor
from tinygrad.helpers import getenv, prod, Context, round_up, tqdm, OSX
from tinygrad.nn.state import TensorIO

### ResNet

@@ -510,6 +513,202 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
    # happens with BENCHMARK set
    pass

# llama3

class BinIdxDataset:
  def __init__(self, base_path:Path):
    self.idx_t = Tensor(base_path.with_name(f"{base_path.name}.idx"))
    self.idx = TensorIO(self.idx_t)

    # parse idx file
    magic = self.idx.read(9)
    assert magic == b"MMIDIDX\x00\x00", "invalid index file format"
    version, = struct.unpack("<Q", self.idx.read(8))
    assert version == 1, "unsupported index version"
    dtype_code, = struct.unpack("<B", self.idx.read(1))
    self.dtype = {1:dtypes.uint8, 2:dtypes.int8, 3:dtypes.int16, 4:dtypes.int32, 5:dtypes.int64, 6:dtypes.float64, 7:dtypes.double, 8:dtypes.uint16}[dtype_code]
    self.count, = struct.unpack("<Q", self.idx.read(8))
    doc_count, = struct.unpack("<Q", self.idx.read(8))

    start = self.idx.tell()
    end = start + self.count * dtypes.int32.itemsize
    self.sizes = self.idx_t[start:end].bitcast(dtypes.int32)

    start = end
    end = start + self.count * dtypes.int64.itemsize
    self.pointers = self.idx_t[start:end].bitcast(dtypes.int64)

    start = end
    end = start + doc_count * dtypes.int64.itemsize
    self.doc_idx = self.idx_t[start:end].bitcast(dtypes.int64)

    # bin file
    self.bin_t = Tensor(base_path.with_name(f"{base_path.name}.bin"))

  def _index(self, idx) -> tuple[int, int]:
    return self.pointers[idx].item(), self.sizes[idx].item()

  def get(self, idx, offset:int=0, length:int|None=None):
    ptr, size = self._index(idx)
    if length is None: length = size - offset
    ptr += offset * self.dtype.itemsize
    return self.bin_t[ptr:ptr+length*self.dtype.itemsize].bitcast(self.dtype).to(None)

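For reference, the .idx header parsed by BinIdxDataset above follows the Megatron-LM indexed-dataset layout: a 9-byte magic b"MMIDIDX\x00\x00", a little-endian uint64 version, a uint8 dtype code, a uint64 sequence count and a uint64 document count, followed by an int32 sizes array, an int64 pointers array, and an int64 document index. A minimal standalone sketch that reads the same header with plain struct and numpy (no tinygrad), handy for sanity-checking a preprocessed shard; the shard name in the usage comment is hypothetical:

import struct
import numpy as np

def read_idx_header(path: str):
  # Sketch: parse a Megatron-style .idx file the same way BinIdxDataset does.
  with open(path, "rb") as f:
    assert f.read(9) == b"MMIDIDX\x00\x00", "invalid index file format"
    version, = struct.unpack("<Q", f.read(8))
    assert version == 1, "unsupported index version"
    dtype_code, = struct.unpack("<B", f.read(1))
    count, = struct.unpack("<Q", f.read(8))
    doc_count, = struct.unpack("<Q", f.read(8))
    sizes = np.frombuffer(f.read(count * 4), dtype=np.int32)        # tokens per sequence
    pointers = np.frombuffer(f.read(count * 8), dtype=np.int64)     # byte offset of each sequence in the .bin file
    doc_idx = np.frombuffer(f.read(doc_count * 8), dtype=np.int64)  # document boundaries
  return dtype_code, sizes, pointers, doc_idx

# usage (hypothetical shard name):
# dtype_code, sizes, pointers, doc_idx = read_idx_header("c4-train.en_6_text_document.idx")
# print(len(sizes), "sequences,", int(sizes.sum()), "tokens")
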
# https://docs.nvidia.com/megatron-core/developer-guide/latest/api-guide/datasets.html
class GPTDataset:
  def __init__(self, base_path:Path, samples:int, seqlen:int, seed:int, shuffle:bool):
    self.samples, self.seqlen = samples, seqlen
    self.shuffle = shuffle
    self.rng = np.random.RandomState(seed)

    self.indexed_dataset = BinIdxDataset(base_path)

    # check for cache
    cache_hash = hashlib.sha256(f"{samples}:{seqlen}:{seed}:{shuffle}".encode()).hexdigest()
    cache_path = base_path.with_name(f"{base_path.name}.{cache_hash}.index_cache")
    if cache_path.exists():
      with open(cache_path, "rb") as f:
        self.doc_idx, self.sample_idx, self.shuffle_idx = pickle.load(f)
    else:
      self.doc_idx = self._build_doc_idx()
      self.sample_idx = self._build_sample_idx()
      self.shuffle_idx = self._build_shuffle_idx()
      # save cache
      with open(cache_path, "wb") as f:
        pickle.dump((self.doc_idx, self.sample_idx, self.shuffle_idx), f)

  def __getitem__(self, idx):
    if idx is None:
      text = self._get(0)
    else:
      text = self._get(idx)

    return text

  def _get(self, idx):
    idx = self.shuffle_idx[idx]

    doc_idx_beg, doc_idx_beg_offset = self.sample_idx[idx]
    doc_idx_end, doc_idx_end_offset = self.sample_idx[idx + 1]

    doc_ids, sample_parts = [], []

    if doc_idx_beg == doc_idx_end:
      doc_ids.append(self.doc_idx[doc_idx_beg])

      sample_parts.append(
        self.indexed_dataset.get(
          int(self.doc_idx[doc_idx_beg]), offset=int(doc_idx_beg_offset), length=int(doc_idx_end_offset - doc_idx_beg_offset + 1)))
    else:
      for i in range(doc_idx_beg, doc_idx_end + 1):
        doc_ids.append(self.doc_idx[i])

        offset = 0 if i > doc_idx_beg else doc_idx_beg_offset
        length = None if i < doc_idx_end else int(doc_idx_end_offset + 1)
        sample_parts.append(self.indexed_dataset.get(int(self.doc_idx[i]), offset=int(offset), length=length))

    # concat all parts
    text = Tensor.cat(*sample_parts)

    return text

  @functools.cached_property
  def tokens_per_epoch(self) -> int:
    return sum(self.indexed_dataset.sizes.tolist())

  @functools.cached_property
  def num_epochs(self) -> int:
    # we need enough epochs to cover the requested amount of tokens
    num_epochs = 1
    num_tokens = self.tokens_per_epoch
    while num_tokens < self.samples * self.seqlen:
      num_epochs += 1
      num_tokens += self.tokens_per_epoch
    return num_epochs

  # https://github.com/NVIDIA/Megatron-LM/blob/94bd476bd840c2fd4c3ebfc7448c2af220f4832b/megatron/core/datasets/gpt_dataset.py#L558
  def _build_doc_idx(self):
    doc_idx = np.mgrid[:self.num_epochs, :self.indexed_dataset.count][1]
    doc_idx = doc_idx.reshape(-1)
    doc_idx = doc_idx.astype(np.int32)
    if self.shuffle: self.rng.shuffle(doc_idx)
    return doc_idx

  def _build_sample_idx(self):
    sample_idx = np.empty((self.samples + 1, 2), dtype=np.int32)

    sample_idx_idx, doc_idx_idx, doc_offset = 0, 0, 0
    sample_idx[sample_idx_idx, 0], sample_idx[sample_idx_idx, 1] = doc_idx_idx, doc_offset
    sample_idx_idx += 1

    for _ in tqdm(range(1, self.samples + 1)):
      remaining_seqlen = self.seqlen + 1
      while remaining_seqlen > 0:
        doc_idx = int(self.doc_idx[doc_idx_idx])
        doc_len = self.indexed_dataset.sizes[doc_idx].item() - doc_offset
        remaining_seqlen -= doc_len
        if remaining_seqlen <= 0:
          doc_offset += remaining_seqlen + doc_len - 1
          remaining_seqlen = 0
        else:
          if doc_idx_idx == len(self.doc_idx) - 1:
            assert sample_idx_idx == self.samples
            doc_idx = int(self.doc_idx[doc_idx_idx])
            doc_offset = self.indexed_dataset.sizes[doc_idx].item() - 1
            break
          doc_idx_idx += 1
          doc_offset = 0

      sample_idx[sample_idx_idx, 0], sample_idx[sample_idx_idx, 1] = doc_idx_idx, doc_offset
      sample_idx_idx += 1

    return sample_idx

  def _build_shuffle_idx(self):
    shuffle_idx = np.arange(self.samples, dtype=np.int32)
    if self.shuffle: self.rng.shuffle(shuffle_idx)
    return shuffle_idx

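To make _build_sample_idx concrete: each row of sample_idx is a (position in doc_idx, token offset) pair marking where a sample starts, and each sample consumes seqlen + 1 tokens so that tokens[:-1] and tokens[1:] can serve as inputs and shifted labels; consecutive samples therefore share one boundary token. A stripped-down re-trace of that bookkeeping on hypothetical document sizes, omitting the end-of-data branch:

# Hypothetical inputs: three documents of 5, 3 and 8 tokens, seqlen=4, one epoch, no shuffle.
sizes = [5, 3, 8]
doc_idx = [0, 1, 2]
seqlen = 4

pos, offset = 0, 0
boundaries = [(pos, offset)]
for _ in range(2):                        # build two samples
  remaining = seqlen + 1                  # each sample needs seqlen + 1 tokens
  while remaining > 0:
    doc_len = sizes[doc_idx[pos]] - offset
    remaining -= doc_len
    if remaining <= 0:
      offset += remaining + doc_len - 1   # index of the last token consumed
      remaining = 0
    else:
      pos, offset = pos + 1, 0            # document exhausted, move to the next one
  boundaries.append((pos, offset))

print(boundaries)  # [(0, 0), (0, 4), (2, 0)]: sample 0 = doc 0 tokens 0..4, sample 1 = doc 0 token 4, all of doc 1, doc 2 token 0
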
class BlendedGPTDataset:
  def __init__(self, paths:list[Path], weights:list[float], samples:int, seqlen:int, seed:int, shuffle:bool):
    self.seed = seed

    # normalize weights
    total_weight = sum(weights)
    self.weights = [w / total_weight for w in weights]

    self.samples = samples
    surplus = 0.005
    samples_per_blend = [math.ceil(math.ceil(self.samples * w) * (1 + surplus)) for w in self.weights]

    self.datasets = [GPTDataset(path, samples_per_blend[i], seqlen, seed + i, shuffle) for i,path in enumerate(paths)]

  def get(self, idx:int):
    tokens = self.datasets[0][idx]
    return tokens

def batch_load_llama3(bs:int, samples:int, seqlen:int, base_dir:Path, seed:int=0, val:bool=True):
  if val:
    dataset = BlendedGPTDataset([
      base_dir / "validation" / "c4-validationn-91205-samples.en_text_document",
    ], [
      1.0
    ], samples, seqlen, seed, False)
  else:
    dataset = BlendedGPTDataset([
      base_dir / "c4-train.en_6_text_document",
      base_dir / "c4-train.en_7_text_document",
    ], [
      1.0, 1.0
    ], samples, seqlen, seed, True)

  for b in range(math.ceil(samples / bs)):
    batch = []
    for i in range(bs):
      tokens = dataset.get(b * bs + i)
      batch.append(tokens)
    yield Tensor.stack(batch, dim=0)

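A hedged usage sketch for the generator above; the batch size, sample count, sequence length and base directory are placeholder values (not the benchmark config), and the C4 .bin/.idx shards are assumed to already exist under that directory:

from pathlib import Path
from examples.mlperf.dataloader import batch_load_llama3

bs, samples, seqlen = 4, 64, 512   # illustrative values only
for tokens in batch_load_llama3(bs, samples, seqlen, Path("/raid/datasets/c4/"), seed=0, val=True):
  # each row should pack seqlen + 1 tokens, so tokens[:, :-1] are inputs and tokens[:, 1:] the labels
  print(tokens.shape)
  break
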
if __name__ == "__main__":
|
||||
def load_unet3d(val):
|
||||
assert not val, "validation set is not supported due to different sizes on inputs"
|
||||
@@ -538,6 +737,18 @@ if __name__ == "__main__":
|
||||
for x in batch_load_retinanet(dataset, val, base_dir):
|
||||
pbar.update(x[0].shape[0])
|
||||
|
||||
def load_llama3(val):
|
||||
bs = 24
|
||||
samples = 5760 if val else 1_200_000
|
||||
seqlen = 512
|
||||
|
||||
max_, min_ = 0, math.inf
|
||||
for tokens in tqdm(batch_load_llama3(bs, samples, seqlen, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=5760, val=bool(val)), total=samples//bs):
|
||||
max_ = max(max_, tokens.shape[1])
|
||||
min_ = min(min_, tokens.shape[1])
|
||||
print(f"max seq length: {max_}")
|
||||
print(f"min seq length: {min_}")
|
||||
|
||||
load_fn_name = f"load_{getenv('MODEL', 'resnet')}"
|
||||
if load_fn_name in globals():
|
||||
globals()[load_fn_name](getenv("VAL", 1))
|
||||
|
||||
@@ -1,4 +1,4 @@
import time
import time, math
start = time.perf_counter()
from pathlib import Path
import numpy as np
@@ -241,6 +241,34 @@ def eval_mrcnn():
  evaluate_predictions_on_coco(bbox_output, iou_type='bbox')
  evaluate_predictions_on_coco(mask_output, iou_type='segm')

def eval_llama3():
  from extra.models.llama import Transformer
  from examples.llama3 import MODEL_PARAMS
  from tinygrad.helpers import tqdm

  bs = 4
  sequence_length = 512

  model = Transformer(**(MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]|{"vocab_size": 32000}), max_context=sequence_length, jit=False, disable_kv_cache=True)

  @TinyJit
  def eval_step(model, tokens):
    logits:Tensor = model(tokens[:, :-1], start_pos=0, temperature=math.nan)
    loss = logits.sparse_categorical_crossentropy(tokens[:, 1:])
    return loss.flatten()

  from examples.mlperf.dataloader import batch_load_llama3
  iter = batch_load_llama3(bs, 5760, sequence_length, Path(getenv("BASEDIR", "/raid/datasets/c4/")), True)

  losses = []
  for tokens in tqdm(iter, total=5760//bs):
    GlobalCounters.reset()
    losses += eval_step(model, tokens).tolist()
    tqdm.write(f"loss: {np.mean(losses)}")

  log_perplexity = Tensor(losses).mean()
  print(f"Log Perplexity: {log_perplexity.item()}")

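For context, the value printed above is the mean per-token cross-entropy, i.e. log-perplexity; perplexity itself is its exponential. A tiny illustration with made-up per-token losses:

import math
token_nll = [2.31, 2.05, 2.48]                 # hypothetical per-token negative log-likelihoods
log_ppl = sum(token_nll) / len(token_nll)
print(f"log perplexity: {log_ppl:.4f}, perplexity: {math.exp(log_ppl):.2f}")
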
if __name__ == "__main__":
|
||||
# inference only
|
||||
Tensor.training = False
|
||||
|
||||
@@ -1290,9 +1290,12 @@ def train_llama3():
  from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup

  config = {}
  BS = config["BS"] = getenv("BS", 4)
  BS = config["BS"] = getenv("BS", 16)
  grad_acc = config["GRADIENT_ACC_STEPS"] = getenv("GRADIENT_ACC_STEPS", 1)
  GBS = config["GLOBAL_BATCH_SIZE"] = BS * grad_acc
  SEED = config["SEED"] = getenv("SEED", 5760)
  SAMPLES = config["SAMPLES"] = getenv("SAMPLES", 1_200_000)
  SEQLEN = config["SEQLEN"] = getenv("SEQLEN", 8192)

  opt_adamw_beta_1 = 0.9
  opt_adamw_beta_2 = 0.95
@@ -1300,7 +1303,6 @@ def train_llama3():
  opt_adamw_weight_decay = 0.1

  opt_gradient_clip_norm = 1.0
  sequence_length = 8192
  opt_learning_rate_warmup_steps = getenv("WARMUP_STEPS", math.ceil(8000 * 1152 / GBS))
  opt_learning_rate_decay_steps = getenv("DECAY_STEPS", math.ceil(1_200_000 * 1152 / GBS) - opt_learning_rate_warmup_steps)
  opt_base_learning_rate = getenv("LR", 8e-5 * GBS / 1152) # NOTE: cannot change for benchmark
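For scale, the schedule above is tied to the global batch size, with 1152 acting as the reference batch size baked into the formulas (at GBS=1152 the base LR is exactly 8e-5). A quick arithmetic sketch of the defaults for GBS=16:

import math
GBS = 16                                                         # BS=16, GRADIENT_ACC_STEPS=1
warmup_steps = math.ceil(8000 * 1152 / GBS)                      # 576_000
decay_steps = math.ceil(1_200_000 * 1152 / GBS) - warmup_steps   # 86_400_000 - 576_000 = 85_824_000
base_lr = 8e-5 * GBS / 1152                                      # ~1.11e-6
print(warmup_steps, decay_steps, base_lr)
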
@@ -1308,7 +1310,7 @@ def train_llama3():

  # TODO: confirm weights are in bf16
  # vocab_size from the mixtral tokenizer
  model = Transformer(**(MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]|{"vocab_size": 32000}), max_context=sequence_length, jit=False, disable_kv_cache=True)
  model = Transformer(**(MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]|{"vocab_size": 32000}), max_context=SEQLEN, jit=False, disable_kv_cache=True)

  optim = AdamW(get_parameters(model), lr=0.0,
    b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay)
@@ -1316,10 +1318,10 @@ def train_llama3():

  @TinyJit
  @Tensor.train()
  def train_step(model, x, y):
  def train_step(model, tokens):
    optim.zero_grad()
    logits:Tensor = model(x, start_pos=0, temperature=math.nan)
    loss = logits.cross_entropy(y)
    logits:Tensor = model(tokens[:, :-1], start_pos=0, temperature=math.nan)
    loss = logits.sparse_categorical_crossentropy(tokens[:, 1:])
    loss.backward()

    # L2 norm grad clip
@@ -1341,18 +1343,29 @@ def train_llama3():
    return loss, lr

  # overfitting this example should give cross_entropy log(BS)
  fake_input = Tensor([list(range(getenv("SEQLEN", 10)))], dtype="int16").expand(BS, -1)
  fake_label = Tensor(list(range(BS)), dtype="int16")
  from examples.mlperf.dataloader import batch_load_llama3
  iter = batch_load_llama3(GBS, SAMPLES, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=False)

  for _ in range(100):
  i = 0
  for tokens in tqdm(iter, total=SAMPLES//BS):
    GlobalCounters.reset()
    loss, lr = train_step(model, fake_input, fake_label)
    loss, lr = train_step(model, tokens)
    # BS=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=8B WARMUP_STEPS=2 DECAY_STEPS=300 PYTHONPATH=. AMD=1 MODEL=llama3 python3 examples/mlperf/model_train.py
    # uses 43% ~= 83GB
    # 8B bf16 = 16GB. model + grad + optim m and v = 64GB
    # TODO: this OOM
    # BS=1 SEQLEN=4000 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=8B WARMUP_STEPS=2 DECAY_STEPS=300 PYTHONPATH=. AMD=1 MODEL=llama3 python3 examples/mlperf/model_train.py
    print(loss.item(), lr.item(), f"{GlobalCounters.global_mem//10**9=}")
    # above as tqdm.write f-string
    tqdm.write(f"{loss.item():.4f} loss, {lr.item():.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used")
    with open("loss.txt", "a") as f:
      f.write(f"{i} {loss.item():.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")

    if i % 200 == 0 or i == 10:
      tqdm.write("saving checkpoint")
      if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir)
      fn = f"{ckpt_dir}/{i}.safe"
      safe_save(get_state_dict(model), fn)
    i += 1

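To restore one of the ./ckpts/<step>.safe files written above, a sketch along these lines should work (the checkpoint path and step number are hypothetical); tinygrad's safe_load and load_state_dict are the counterparts of the safe_save and get_state_dict calls used in the loop:

from tinygrad.nn.state import safe_load, load_state_dict

# hypothetical: resume from the checkpoint written at step 200,
# with `model` constructed the same way as in train_llama3
state_dict = safe_load("./ckpts/200.safe")
load_state_dict(model, state_dict)
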
if __name__ == "__main__":
|
||||
multiprocessing.set_start_method('spawn')
|
||||
|
||||
@@ -185,10 +185,10 @@ class Transformer:

    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1) if seqlen > 1 else None
    for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
    logits = self.output(self.norm(h)).float()[:, -1, :]
    logits = self.output(self.norm(h)).float()
    if math.isnan(temperature): return logits

    return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p)
    return sample(logits[:, -1, :].flatten(), temperature, top_k, top_p, alpha_f, alpha_p)

  def __call__(self, tokens:Tensor, start_pos:int, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
    # TODO: better way to handle the first call v.s. the rest?
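The change above makes temperature=math.nan return logits for every position rather than only the last token, which is what lets train_step and eval_step take a cross-entropy over all shifted positions; the sampling path now slices [:, -1, :] itself. A hedged sketch of the two call modes, with the model and tokens assumed to be built as in eval_llama3:

import math
from tinygrad import Tensor

def nll_all_positions(model, tokens: Tensor) -> Tensor:
  # NaN temperature asks the Transformer for logits at every position,
  # so the loss can be computed against the labels shifted by one token.
  logits = model(tokens[:, :-1], start_pos=0, temperature=math.nan)
  return logits.sparse_categorical_crossentropy(tokens[:, 1:])

def greedy_next_token(model, tokens: Tensor):
  # A real temperature keeps the generation behaviour: sample from the last position only.
  return model(tokens, start_pos=0, temperature=0.0)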