mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
TRAIN=0 to only eval llama (#13804)
This commit is contained in:
@@ -1438,33 +1438,34 @@ def train_llama3():
|
||||
iter = get_train_iter()
|
||||
i, sequences_seen = resume_ckpt, 0
|
||||
for tokens in tqdm(iter, total=SAMPLES//GBS):
|
||||
t = time.perf_counter()
|
||||
GlobalCounters.reset()
|
||||
loss, lr = train_step(model, tokens)
|
||||
loss = loss.float().item()
|
||||
lr = lr.item()
|
||||
if getenv("TRAIN", 1):
|
||||
t = time.perf_counter()
|
||||
loss, lr = train_step(model, tokens)
|
||||
loss = loss.float().item()
|
||||
lr = lr.item()
|
||||
|
||||
i += 1
|
||||
sequences_seen += tokens.shape[0]
|
||||
i += 1
|
||||
sequences_seen += tokens.shape[0]
|
||||
|
||||
sec = time.perf_counter()-t
|
||||
tqdm.write(
|
||||
f"{i:5} {sec:.2f} s run, {loss:.4f} loss, {lr:.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, "
|
||||
f"{GlobalCounters.global_ops * 1e-9 / sec:9.2f} GFLOPS")
|
||||
sec = time.perf_counter()-t
|
||||
tqdm.write(
|
||||
f"{i:5} {sec:.2f} s run, {loss:.4f} loss, {lr:.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, "
|
||||
f"{GlobalCounters.global_ops * 1e-9 / sec:9.2f} GFLOPS")
|
||||
|
||||
if (fname:=getenv("LOSS_FILE", "")):
|
||||
with open(fname, "a") as f:
|
||||
f.write(f"{i} {loss:.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
|
||||
if (fname:=getenv("LOSS_FILE", "")):
|
||||
with open(fname, "a") as f:
|
||||
f.write(f"{i} {loss:.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
|
||||
|
||||
if (ckpt_freq := getenv("CKPT")) and (i % ckpt_freq == 0 and (i != 1 or ckpt_freq == 1)):
|
||||
tqdm.write("saving checkpoint")
|
||||
if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir)
|
||||
fn = f"{ckpt_dir}/llama3_{i}.safe"
|
||||
safe_save(get_state_dict(model), fn)
|
||||
if (ckpt_freq := getenv("CKPT")) and (i % ckpt_freq == 0 and (i != 1 or ckpt_freq == 1)):
|
||||
tqdm.write("saving checkpoint")
|
||||
if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir)
|
||||
fn = f"{ckpt_dir}/llama3_{i}.safe"
|
||||
safe_save(get_state_dict(model), fn)
|
||||
|
||||
tqdm.write("saving optim checkpoint")
|
||||
fn = f"{ckpt_dir}/llama3_{i}_optim.safe"
|
||||
safe_save(get_state_dict(scheduler), fn)
|
||||
tqdm.write("saving optim checkpoint")
|
||||
fn = f"{ckpt_dir}/llama3_{i}_optim.safe"
|
||||
safe_save(get_state_dict(scheduler), fn)
|
||||
|
||||
if sequences_seen % EVAL_FREQ == 0 and (i != 1 or EVAL_FREQ == 1):
|
||||
tqdm.write(f"evaluating after {sequences_seen} sequences")
|
||||
|
||||
Reference in New Issue
Block a user