add seed in bert data shuffle (#14054)

This commit is contained in:
b1tg
2026-01-07 23:02:05 +08:00
committed by GitHub
parent 25c82dd242
commit 241f0402b4
2 changed files with 4 additions and 3 deletions

View File

@@ -213,12 +213,13 @@ class InterleavedDataset:
self.queues[queue_index].queue.extend(load_file(file))
# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 394
def batch_load_train_bert(BS:int):
def batch_load_train_bert(BS:int, seed:int|None=None):
from extra.datasets.wikipedia import get_wiki_train_files
rng = random.Random(seed)
fs = sorted(get_wiki_train_files())
train_files = []
while fs: # TF shuffle
random.shuffle(fs)
rng.shuffle(fs)
train_files.append(fs.pop(0))
cycle_length = min(getenv("NUM_CPU_THREADS", min(os.cpu_count(), 8)), len(train_files))

View File

@@ -1085,7 +1085,7 @@ def train_bert():
if RUNMLPERF:
# only load real data with RUNMLPERF
eval_it = iter(batch_load_val_bert(EVAL_BS))
train_it = iter(tqdm(batch_load_train_bert(BS), total=train_steps, disable=BENCHMARK))
train_it = iter(tqdm(batch_load_train_bert(BS, seed=seed), total=train_steps, disable=BENCHMARK))
for _ in range(start_step): next(train_it) # Fast forward
else:
# repeat fake data