update llama dataloader (#13825)

separate creating the dataset from iterating over the dataset so that eval data is not recreated for each eval
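Roughly, the new split API is meant to be used like this (a minimal sketch, not code from the repo; the values for BASEDIR, SEQLEN, EVAL_BS and SMALL are illustrative placeholders, while the function names and the 5760 eval-sample count come from the diff below):

  from pathlib import Path
  from examples.mlperf.dataloader import get_llama3_dataset, iterate_llama3_dataset

  # illustrative stand-ins for the env-driven values used in the real scripts
  BASEDIR, SEQLEN, EVAL_BS, SMALL = Path("/path/to/c4"), 8192, 16, 0

  # build the eval dataset once, up front
  eval_dataset = get_llama3_dataset(5760, SEQLEN, BASEDIR, val=True, small=bool(SMALL))

  # each eval pass gets a fresh generator over the same dataset object,
  # so the BlendedGPTDataset is not rebuilt for every eval
  for _ in range(3):  # e.g. three evals over the course of training
    for tokens in iterate_llama3_dataset(eval_dataset, EVAL_BS):
      pass  # tokens is one batch, stacked with Tensor.stack along dim 0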
chenyu authored 2025-12-24 17:42:08 -05:00, committed by GitHub
parent a7fc0c288b
commit da1cb6a9ec
3 changed files with 30 additions and 58 deletions

View File

@@ -763,48 +763,26 @@ class BlendedGPTDataset:
     return dataset_idx, dataset_sample_idx
 
-def batch_load_llama3(bs:int, samples:int, seqlen:int, base_dir:Path, seed:int=0, val:bool=True):
+def get_llama3_dataset(samples:int, seqlen:int, base_dir:Path, seed:int=0, val:bool=True, small:bool=False) -> BlendedGPTDataset:
+  if small:
+    if val:
+      return BlendedGPTDataset(
+        [base_dir / "c4-validation-91205-samples.en_text_document"], [1.0], samples, seqlen, seed, shuffle=False)
+    return BlendedGPTDataset(
+      [base_dir / "c4-train.en_6_text_document"], [1.0], samples, seqlen, seed, shuffle=True)
   if val:
-    dataset = BlendedGPTDataset([
-      base_dir / "validation" / "c4-validationn-91205-samples.en_text_document",
-    ], [
-      1.0
-    ], samples, seqlen, seed, False)
-  else:
-    dataset = BlendedGPTDataset([
-      base_dir / "c4-train.en_6_text_document",
-      base_dir / "c4-train.en_7_text_document",
-    ], [
-      1.0, 1.0
-    ], samples, seqlen, seed, True)
-  for b in range(math.ceil(samples / bs)):
-    batch = []
-    for i in range(bs):
-      tokens = dataset.get(b * bs + i)
-      batch.append(tokens)
+    return BlendedGPTDataset(
+      [base_dir / "validation" / "c4-validationn-91205-samples.en_text_document"], [1.0], samples, seqlen, seed, shuffle=False)
+  return BlendedGPTDataset(
+    [base_dir / "c4-train.en_6_text_document", base_dir / "c4-train.en_7_text_document"], [1.0, 1.0], samples, seqlen, seed, shuffle=True)
+
+def iterate_llama3_dataset(dataset:BlendedGPTDataset, bs:int):
+  for b in range(math.ceil(dataset.samples / bs)):
+    batch = [dataset.get(b * bs + i) for i in range(bs)]
     yield Tensor.stack(batch, dim=0)
 
-def batch_load_llama3_small(bs:int, samples:int, seqlen:int, base_dir:Path, seed:int=0, val:bool=True):
-  if val:
-    dataset = BlendedGPTDataset([
-      base_dir / "c4-validation-91205-samples.en_text_document",
-    ], [
-      1.0
-    ], samples, seqlen, seed, False)
-  else:
-    dataset = BlendedGPTDataset([
-      base_dir / "c4-train.en_6_text_document",
-    ], [
-      1.0
-    ], samples, seqlen, seed, True)
-  for b in range(math.ceil(samples / bs)):
-    batch = []
-    for i in range(bs):
-      tokens = dataset.get(b * bs + i)
-      batch.append(tokens)
-    yield Tensor.stack(batch, dim=0)
+def batch_load_llama3(bs:int, samples:int, seqlen:int, base_dir:Path, seed:int=0, val:bool=True, small:bool=False):
+  return iterate_llama3_dataset(get_llama3_dataset(samples, seqlen, base_dir, seed, val, small), bs)
 
 if __name__ == "__main__":
   def load_unet3d(val):
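One detail worth noting about the new iterate_llama3_dataset: it yields math.ceil(dataset.samples / bs) batches per pass, which lines up with the total=5760//BS progress-bar count used in eval_llama3 below whenever the batch size divides 5760. A quick standalone check of that arithmetic (the batch size here is illustrative, not from the repo):

  import math

  samples, BS = 5760, 16  # 5760 eval samples (from the diff); BS=16 is illustrative
  assert math.ceil(samples / BS) == samples // BS == 360  # ceil and floor agree because 16 divides 5760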

View File

@@ -234,12 +234,9 @@ def eval_llama3():
     loss = logits.sparse_categorical_crossentropy(tokens[:, 1:])
     return loss.flatten().float()
 
-  if SMALL:
-    from examples.mlperf.dataloader import batch_load_llama3_small
-    iter = batch_load_llama3_small(BS, 5760, SEQLEN, BASEDIR, val=True)
-  else:
-    from examples.mlperf.dataloader import batch_load_llama3
-    iter = batch_load_llama3(BS, 5760, SEQLEN, BASEDIR, val=True)
+  from examples.mlperf.dataloader import get_llama3_dataset, iterate_llama3_dataset
+  eval_dataset = get_llama3_dataset(5760, SEQLEN, BASEDIR, val=True, small=bool(SMALL))
+  iter = iterate_llama3_dataset(eval_dataset, BS)
 
   losses = []
   for tokens in tqdm(iter, total=5760//BS):
View File

@@ -1424,23 +1424,20 @@ def train_llama3():
     if getenv("FAKEDATA", 0):
       return fake_data(BS, SAMPLES)
     else:
-      if SMALL:
-        from examples.mlperf.dataloader import batch_load_llama3_small
-        return batch_load_llama3_small(BS, SAMPLES, SEQLEN, BASEDIR, seed=SEED, val=bool(TRAIN_ON_VAL))
-      else:
-        from examples.mlperf.dataloader import batch_load_llama3
-        return batch_load_llama3(BS, SAMPLES, SEQLEN, BASEDIR, seed=SEED, val=bool(TRAIN_ON_VAL))
+      from examples.mlperf.dataloader import batch_load_llama3
+      return batch_load_llama3(BS, SAMPLES, SEQLEN, BASEDIR, seed=SEED, val=bool(TRAIN_ON_VAL), small=bool(SMALL))
 
+  if getenv("FAKEDATA", 0):
+    eval_dataset = None
+  else:
+    from examples.mlperf.dataloader import get_llama3_dataset
+    eval_dataset = get_llama3_dataset(5760, SEQLEN, BASEDIR, val=True, small=bool(SMALL))
   def get_eval_iter():
-    if getenv("FAKEDATA", 0):
+    if eval_dataset is None:
       return fake_data(EVAL_BS, 5760)
     else:
-      if SMALL:
-        from examples.mlperf.dataloader import batch_load_llama3_small
-        return batch_load_llama3_small(EVAL_BS, 5760, SEQLEN, BASEDIR, val=True)
-      else:
-        from examples.mlperf.dataloader import batch_load_llama3
-        return batch_load_llama3(EVAL_BS, 5760, SEQLEN, BASEDIR, val=True)
+      from examples.mlperf.dataloader import iterate_llama3_dataset
+      return iterate_llama3_dataset(eval_dataset, EVAL_BS)
 
   iter = get_train_iter()
   i, sequences_seen = resume_ckpt, 0