feat: flag for training on val (#11441)

This commit is contained in:
wozeparrot
2025-07-30 14:29:45 -07:00
committed by GitHub
parent 4ca430e5bf
commit 5fb975351a

View File

@@ -1296,6 +1296,7 @@ def train_llama3():
SEED = config["SEED"] = getenv("SEED", 5760)
SAMPLES = config["SAMPLES"] = getenv("SAMPLES", 1_200_000)
SEQLEN = config["SEQLEN"] = getenv("SEQLEN", 8192)
+TRAIN_ON_VAL = config["TRAIN_ON_VAL"] = getenv("TRAIN_ON_VAL", 0)
opt_adamw_beta_1 = 0.9
opt_adamw_beta_2 = 0.95
@@ -1342,9 +1343,8 @@ def train_llama3():
loss.realize(lr)
return loss, lr
# overfitting this example should give cross_entropy log(BS)
from examples.mlperf.dataloader import batch_load_llama3
-iter = batch_load_llama3(GBS, SAMPLES, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=False)
+iter = batch_load_llama3(GBS, SAMPLES, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=bool(TRAIN_ON_VAL))
i = 0
for tokens in tqdm(iter, total=SAMPLES//BS):