Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00.

Commit: feat: flag for training on val (#11441)
This commit is contained in:
@@ -1296,6 +1296,7 @@ def train_llama3():
   SEED = config["SEED"] = getenv("SEED", 5760)
   SAMPLES = config["SAMPLES"] = getenv("SAMPLES", 1_200_000)
   SEQLEN = config["SEQLEN"] = getenv("SEQLEN", 8192)
+  TRAIN_ON_VAL = config["TRAIN_ON_VAL"] = getenv("TRAIN_ON_VAL", 0)

   opt_adamw_beta_1 = 0.9
   opt_adamw_beta_2 = 0.95
@@ -1342,9 +1343,8 @@ def train_llama3():
     loss.realize(lr)
     return loss, lr

   # overfitting this example should give cross_entropy log(BS)
   from examples.mlperf.dataloader import batch_load_llama3
-  iter = batch_load_llama3(GBS, SAMPLES, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=False)
+  iter = batch_load_llama3(GBS, SAMPLES, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=bool(TRAIN_ON_VAL))

   i = 0
   for tokens in tqdm(iter, total=SAMPLES//BS):
Reference in New Issue
Block a user