mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
add seed in bert data shuffle (#14054)
This commit is contained in:
@@ -213,12 +213,13 @@ class InterleavedDataset:
|
||||
self.queues[queue_index].queue.extend(load_file(file))
|
||||
|
||||
# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 394
|
||||
def batch_load_train_bert(BS:int):
|
||||
def batch_load_train_bert(BS:int, seed:int|None=None):
|
||||
from extra.datasets.wikipedia import get_wiki_train_files
|
||||
rng = random.Random(seed)
|
||||
fs = sorted(get_wiki_train_files())
|
||||
train_files = []
|
||||
while fs: # TF shuffle
|
||||
random.shuffle(fs)
|
||||
rng.shuffle(fs)
|
||||
train_files.append(fs.pop(0))
|
||||
|
||||
cycle_length = min(getenv("NUM_CPU_THREADS", min(os.cpu_count(), 8)), len(train_files))
|
||||
|
||||
@@ -1085,7 +1085,7 @@ def train_bert():
|
||||
if RUNMLPERF:
|
||||
# only load real data with RUNMLPERF
|
||||
eval_it = iter(batch_load_val_bert(EVAL_BS))
|
||||
train_it = iter(tqdm(batch_load_train_bert(BS), total=train_steps, disable=BENCHMARK))
|
||||
train_it = iter(tqdm(batch_load_train_bert(BS, seed=seed), total=train_steps, disable=BENCHMARK))
|
||||
for _ in range(start_step): next(train_it) # Fast forward
|
||||
else:
|
||||
# repeat fake data
|
||||
|
||||
Reference in New Issue
Block a user