mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
feat: fake data (#11447)
This commit is contained in:
@@ -1346,8 +1346,14 @@ def train_llama3():
     loss.realize(lr)
     return loss, lr
 
-  from examples.mlperf.dataloader import batch_load_llama3
-  iter = batch_load_llama3(GBS, SAMPLES, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=bool(TRAIN_ON_VAL))
+  if getenv("FAKEDATA", 0):
+    def fake_data():
+      for _ in range(SAMPLES // GBS):
+        yield Tensor.randint(GBS, SEQLEN + 1, low=0, high=32000, dtype=dtypes.int32, device=Device.DEFAULT)
+    iter = fake_data()
+  else:
+    from examples.mlperf.dataloader import batch_load_llama3
+    iter = batch_load_llama3(GBS, SAMPLES, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=bool(TRAIN_ON_VAL))
 
   i = 0
   for tokens in tqdm(iter, total=SAMPLES//BS):
Reference in New Issue
Block a user