feat: fake data (#11447)

This commit is contained in:
wozeparrot
2025-07-30 17:18:20 -07:00
committed by GitHub
parent e300451f3a
commit 6252f7770e

View File

@@ -1346,8 +1346,14 @@ def train_llama3():
loss.realize(lr)
return loss, lr
from examples.mlperf.dataloader import batch_load_llama3
iter = batch_load_llama3(GBS, SAMPLES, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=bool(TRAIN_ON_VAL))
if getenv("FAKEDATA", 0):
def fake_data():
for _ in range(SAMPLES // GBS):
yield Tensor.randint(GBS, SEQLEN + 1, low=0, high=32000, dtype=dtypes.int32, device=Device.DEFAULT)
iter = fake_data()
else:
from examples.mlperf.dataloader import batch_load_llama3
iter = batch_load_llama3(GBS, SAMPLES, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=bool(TRAIN_ON_VAL))
i = 0
for tokens in tqdm(iter, total=SAMPLES//BS):