diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index d292c5720a..deaa4271f8 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -1346,8 +1346,14 @@ def train_llama3(): loss.realize(lr) return loss, lr - from examples.mlperf.dataloader import batch_load_llama3 - iter = batch_load_llama3(GBS, SAMPLES, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=bool(TRAIN_ON_VAL)) + if getenv("FAKEDATA", 0): + def fake_data(): + for _ in range(SAMPLES // GBS): + yield Tensor.randint(GBS, SEQLEN + 1, low=0, high=32000, dtype=dtypes.int32, device=Device.DEFAULT) + iter = fake_data() + else: + from examples.mlperf.dataloader import batch_load_llama3 + iter = batch_load_llama3(GBS, SAMPLES, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=bool(TRAIN_ON_VAL)) i = 0 for tokens in tqdm(iter, total=SAMPLES//BS):