mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
TRAIN=0 to only eval llama (#13804)
This commit is contained in:
@@ -1438,33 +1438,34 @@ def train_llama3():
|
|||||||
iter = get_train_iter()
|
iter = get_train_iter()
|
||||||
i, sequences_seen = resume_ckpt, 0
|
i, sequences_seen = resume_ckpt, 0
|
||||||
for tokens in tqdm(iter, total=SAMPLES//GBS):
|
for tokens in tqdm(iter, total=SAMPLES//GBS):
|
||||||
t = time.perf_counter()
|
|
||||||
GlobalCounters.reset()
|
GlobalCounters.reset()
|
||||||
loss, lr = train_step(model, tokens)
|
if getenv("TRAIN", 1):
|
||||||
loss = loss.float().item()
|
t = time.perf_counter()
|
||||||
lr = lr.item()
|
loss, lr = train_step(model, tokens)
|
||||||
|
loss = loss.float().item()
|
||||||
|
lr = lr.item()
|
||||||
|
|
||||||
i += 1
|
i += 1
|
||||||
sequences_seen += tokens.shape[0]
|
sequences_seen += tokens.shape[0]
|
||||||
|
|
||||||
sec = time.perf_counter()-t
|
sec = time.perf_counter()-t
|
||||||
tqdm.write(
|
tqdm.write(
|
||||||
f"{i:5} {sec:.2f} s run, {loss:.4f} loss, {lr:.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, "
|
f"{i:5} {sec:.2f} s run, {loss:.4f} loss, {lr:.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, "
|
||||||
f"{GlobalCounters.global_ops * 1e-9 / sec:9.2f} GFLOPS")
|
f"{GlobalCounters.global_ops * 1e-9 / sec:9.2f} GFLOPS")
|
||||||
|
|
||||||
if (fname:=getenv("LOSS_FILE", "")):
|
if (fname:=getenv("LOSS_FILE", "")):
|
||||||
with open(fname, "a") as f:
|
with open(fname, "a") as f:
|
||||||
f.write(f"{i} {loss:.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
|
f.write(f"{i} {loss:.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
|
||||||
|
|
||||||
if (ckpt_freq := getenv("CKPT")) and (i % ckpt_freq == 0 and (i != 1 or ckpt_freq == 1)):
|
if (ckpt_freq := getenv("CKPT")) and (i % ckpt_freq == 0 and (i != 1 or ckpt_freq == 1)):
|
||||||
tqdm.write("saving checkpoint")
|
tqdm.write("saving checkpoint")
|
||||||
if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir)
|
if not os.path.exists(ckpt_dir := "./ckpts"): os.mkdir(ckpt_dir)
|
||||||
fn = f"{ckpt_dir}/llama3_{i}.safe"
|
fn = f"{ckpt_dir}/llama3_{i}.safe"
|
||||||
safe_save(get_state_dict(model), fn)
|
safe_save(get_state_dict(model), fn)
|
||||||
|
|
||||||
tqdm.write("saving optim checkpoint")
|
tqdm.write("saving optim checkpoint")
|
||||||
fn = f"{ckpt_dir}/llama3_{i}_optim.safe"
|
fn = f"{ckpt_dir}/llama3_{i}_optim.safe"
|
||||||
safe_save(get_state_dict(scheduler), fn)
|
safe_save(get_state_dict(scheduler), fn)
|
||||||
|
|
||||||
if sequences_seen % EVAL_FREQ == 0 and (i != 1 or EVAL_FREQ == 1):
|
if sequences_seen % EVAL_FREQ == 0 and (i != 1 or EVAL_FREQ == 1):
|
||||||
tqdm.write(f"evaluating after {sequences_seen} sequences")
|
tqdm.write(f"evaluating after {sequences_seen} sequences")
|
||||||
|
|||||||
Reference in New Issue
Block a user