From 6252f7770ee8889eec933bebb9509bf3ea03b4f6 Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Wed, 30 Jul 2025 17:18:20 -0700 Subject: [PATCH] feat: fake data (#11447) --- examples/mlperf/model_train.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index d292c5720a..deaa4271f8 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -1346,8 +1346,14 @@ def train_llama3(): loss.realize(lr) return loss, lr - from examples.mlperf.dataloader import batch_load_llama3 - iter = batch_load_llama3(GBS, SAMPLES, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=bool(TRAIN_ON_VAL)) + if getenv("FAKEDATA", 0): + def fake_data(): + for _ in range(SAMPLES // GBS): + yield Tensor.randint(GBS, SEQLEN + 1, low=0, high=32000, dtype=dtypes.int32, device=Device.DEFAULT) + iter = fake_data() + else: + from examples.mlperf.dataloader import batch_load_llama3 + iter = batch_load_llama3(GBS, SAMPLES, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=bool(TRAIN_ON_VAL)) i = 0 for tokens in tqdm(iter, total=SAMPLES//BS):