update llama dataloader (#13825)
separate creating the dataset from iterating over it, so eval data is not re-created for every eval
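A minimal usage sketch of the resulting split API (the functions are defined in the first hunk below). The BASEDIR path and the SEQLEN/BS values here are hypothetical placeholders; only the function names and signatures come from this commit:

from pathlib import Path
from examples.mlperf.dataloader import get_llama3_dataset, iterate_llama3_dataset, batch_load_llama3

BASEDIR, SEQLEN, BS = Path("/raid/datasets/c4"), 8192, 16  # hypothetical values for illustration

# the expensive step: build the eval dataset once and keep it around
eval_dataset = get_llama3_dataset(5760, SEQLEN, BASEDIR, val=True)
# the cheap step: a fresh batch generator can be made from the same dataset for every eval
eval_iter = iterate_llama3_dataset(eval_dataset, BS)

# the old one-shot behaviour is still available through the batch_load_llama3 wrapper
one_shot_iter = batch_load_llama3(BS, 5760, SEQLEN, BASEDIR, val=True)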
@@ -763,48 +763,26 @@ class BlendedGPTDataset:
     return dataset_idx, dataset_sample_idx
 
-def batch_load_llama3(bs:int, samples:int, seqlen:int, base_dir:Path, seed:int=0, val:bool=True):
+def get_llama3_dataset(samples:int, seqlen:int, base_dir:Path, seed:int=0, val:bool=True, small:bool=False) -> BlendedGPTDataset:
+  if small:
+    if val:
+      return BlendedGPTDataset(
+        [base_dir / "c4-validation-91205-samples.en_text_document"], [1.0], samples, seqlen, seed, shuffle=False)
+    return BlendedGPTDataset(
+      [base_dir / "c4-train.en_6_text_document"], [1.0], samples, seqlen, seed, shuffle=True)
   if val:
-    dataset = BlendedGPTDataset([
-      base_dir / "validation" / "c4-validationn-91205-samples.en_text_document",
-    ], [
-      1.0
-    ], samples, seqlen, seed, False)
-  else:
-    dataset = BlendedGPTDataset([
-      base_dir / "c4-train.en_6_text_document",
-      base_dir / "c4-train.en_7_text_document",
-    ], [
-      1.0, 1.0
-    ], samples, seqlen, seed, True)
+    return BlendedGPTDataset(
+      [base_dir / "validation" / "c4-validationn-91205-samples.en_text_document"], [1.0], samples, seqlen, seed, shuffle=False)
+  return BlendedGPTDataset(
+    [base_dir / "c4-train.en_6_text_document", base_dir / "c4-train.en_7_text_document"], [1.0, 1.0], samples, seqlen, seed, shuffle=True)
 
-  for b in range(math.ceil(samples / bs)):
-    batch = []
-    for i in range(bs):
-      tokens = dataset.get(b * bs + i)
-      batch.append(tokens)
+def iterate_llama3_dataset(dataset:BlendedGPTDataset, bs:int):
+  for b in range(math.ceil(dataset.samples / bs)):
+    batch = [dataset.get(b * bs + i) for i in range(bs)]
     yield Tensor.stack(batch, dim=0)
 
-def batch_load_llama3_small(bs:int, samples:int, seqlen:int, base_dir:Path, seed:int=0, val:bool=True):
-  if val:
-    dataset = BlendedGPTDataset([
-      base_dir / "c4-validation-91205-samples.en_text_document",
-    ], [
-      1.0
-    ], samples, seqlen, seed, False)
-  else:
-    dataset = BlendedGPTDataset([
-      base_dir / "c4-train.en_6_text_document",
-    ], [
-      1.0
-    ], samples, seqlen, seed, True)
-
-  for b in range(math.ceil(samples / bs)):
-    batch = []
-    for i in range(bs):
-      tokens = dataset.get(b * bs + i)
-      batch.append(tokens)
-    yield Tensor.stack(batch, dim=0)
+def batch_load_llama3(bs:int, samples:int, seqlen:int, base_dir:Path, seed:int=0, val:bool=True, small:bool=False):
+  return iterate_llama3_dataset(get_llama3_dataset(samples, seqlen, base_dir, seed, val, small), bs)
 
 if __name__ == "__main__":
   def load_unet3d(val):
@@ -234,12 +234,9 @@ def eval_llama3():
     loss = logits.sparse_categorical_crossentropy(tokens[:, 1:])
     return loss.flatten().float()
 
-  if SMALL:
-    from examples.mlperf.dataloader import batch_load_llama3_small
-    iter = batch_load_llama3_small(BS, 5760, SEQLEN, BASEDIR, val=True)
-  else:
-    from examples.mlperf.dataloader import batch_load_llama3
-    iter = batch_load_llama3(BS, 5760, SEQLEN, BASEDIR, val=True)
+  from examples.mlperf.dataloader import get_llama3_dataset, iterate_llama3_dataset
+  eval_dataset = get_llama3_dataset(5760, SEQLEN, BASEDIR, val=True, small=bool(SMALL))
+  iter = iterate_llama3_dataset(eval_dataset, BS)
 
   losses = []
   for tokens in tqdm(iter, total=5760//BS):
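A quick check of the batch arithmetic: iterate_llama3_dataset yields math.ceil(dataset.samples / bs) batches, which matches the tqdm total of 5760//BS above whenever the batch size divides 5760 evenly (the batch size below is a hypothetical example, not a value from this diff):

import math

samples = 5760   # eval sample count used by eval_llama3 above
bs = 16          # hypothetical batch size for illustration
assert math.ceil(samples / bs) == samples // bs == 360  # ceil and floor agree because 16 divides 5760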
@@ -1424,23 +1424,20 @@ def train_llama3():
     if getenv("FAKEDATA", 0):
       return fake_data(BS, SAMPLES)
     else:
-      if SMALL:
-        from examples.mlperf.dataloader import batch_load_llama3_small
-        return batch_load_llama3_small(BS, SAMPLES, SEQLEN, BASEDIR, seed=SEED, val=bool(TRAIN_ON_VAL))
-      else:
-        from examples.mlperf.dataloader import batch_load_llama3
-        return batch_load_llama3(BS, SAMPLES, SEQLEN, BASEDIR, seed=SEED, val=bool(TRAIN_ON_VAL))
+      from examples.mlperf.dataloader import batch_load_llama3
+      return batch_load_llama3(BS, SAMPLES, SEQLEN, BASEDIR, seed=SEED, val=bool(TRAIN_ON_VAL), small=bool(SMALL))
 
+  if getenv("FAKEDATA", 0):
+    eval_dataset = None
+  else:
+    from examples.mlperf.dataloader import get_llama3_dataset
+    eval_dataset = get_llama3_dataset(5760, SEQLEN, BASEDIR, val=True, small=bool(SMALL))
+
   def get_eval_iter():
-    if getenv("FAKEDATA", 0):
+    if eval_dataset is None:
       return fake_data(EVAL_BS, 5760)
     else:
-      if SMALL:
-        from examples.mlperf.dataloader import batch_load_llama3_small
-        return batch_load_llama3_small(EVAL_BS, 5760, SEQLEN, BASEDIR, val=True)
-      else:
-        from examples.mlperf.dataloader import batch_load_llama3
-        return batch_load_llama3(EVAL_BS, 5760, SEQLEN, BASEDIR, val=True)
+      from examples.mlperf.dataloader import iterate_llama3_dataset
+      return iterate_llama3_dataset(eval_dataset, EVAL_BS)
 
   iter = get_train_iter()
   i, sequences_seen = resume_ckpt, 0
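The train_llama3 change follows the same "construct once, hand out fresh iterators" shape; a self-contained toy sketch of that pattern (every name here is an illustrative stand-in, not code from the diff):

import math

class ToyDataset:
  # stand-in for BlendedGPTDataset: pretend __init__ is the expensive part
  def __init__(self, samples):
    self.samples = samples
    self.data = list(range(samples))
  def get(self, i):
    return self.data[i]

def iterate(dataset, bs):
  # mirrors the role of iterate_llama3_dataset: a cheap generator over an already-built dataset
  for b in range(math.ceil(dataset.samples / bs)):
    yield [dataset.get(b * bs + i) for i in range(bs)]

eval_dataset = ToyDataset(5760)     # built once, outside the training loop

def get_eval_iter():
  return iterate(eval_dataset, 16)  # calling this for every eval no longer rebuilds the dataset

for _ in range(2):                  # two "evals" reuse the same dataset object
  assert sum(1 for _ in get_eval_iter()) == 360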