Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 23:18:04 -05:00
Resume fix + scheduler for non weight decay params (#4679)
* move ckpt dir
* fix resume. Add scheduler group
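The "move ckpt dir" part consolidates the checkpoint directory into a single SAVE_CKPT_DIR config value that both save sites reuse (see the last two hunks). A minimal sketch of that pattern, using only tinygrad's getenv helper:

import os
from tinygrad.helpers import getenv

# one setting for the checkpoint directory, read once and reused everywhere
save_ckpt_dir = getenv("SAVE_CKPT_DIR", "./ckpts")
if not os.path.exists(ckpt_dir := save_ckpt_dir): os.mkdir(ckpt_dir)
print(f"checkpoints will be written to {ckpt_dir}")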
@@ -207,14 +207,16 @@ def load_datasample(file_and_offset:Tuple[str, int]) -> List[dict]:
   return data[file_and_offset[1]]
 
 # Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 394
-def batch_load_train_bert(BS:int):
+def batch_load_train_bert(BS:int, start_step:int = 0):
   from extra.datasets.wikipedia import get_wiki_train_files
   files = shuffle_parts(get_wiki_train_files())
   dataset = []
-  for f in files:
+  for f in tqdm(files, desc="Building dataset"):
     lists = [(f, o) for o in range(int(Path(f).stem.split("_")[3].split(".")[0]))]
     dataset.extend(lists)
 
+  dataset = dataset[start_step:]
+
   active_set = deque(dataset[:1000])
   remaining_set = deque(dataset[1000:])
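For context on the hunk above: batch_load_train_bert flattens every wiki shard into a list of (file, offset) pairs, and the new start_step argument simply drops the entries that were already consumed before the active/remaining deques are built. An illustrative toy of that slicing (function name, file names, and counts here are made up):

from collections import deque

def build_dataset(files_with_counts, start_step=0):
  # flatten (file, sample_count) pairs into (file, offset) samples, as the hunk above does
  dataset = []
  for fname, n_samples in files_with_counts:
    dataset.extend((fname, o) for o in range(n_samples))
  dataset = dataset[start_step:]  # resume: skip entries already consumed
  return deque(dataset[:1000]), deque(dataset[1000:])

active_set, remaining_set = build_dataset([("part_0.pkl", 3), ("part_1.pkl", 2)], start_step=2)
print(list(active_set))  # [('part_0.pkl', 2), ('part_1.pkl', 0), ('part_1.pkl', 1)]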
@@ -414,6 +414,7 @@ def train_bert():
   eval_step_freq = config["EVAL_STEP_FREQ"] = getenv("EVAL_STEP_FREQ", int((math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000) / BS)) # Round down
   save_ckpt_freq = config["SAVE_CKPT_FREQ"] = getenv("SAVE_CKPT_FREQ", 1000)
   keep_ckpt_amount = config["KEEP_CKPT_AMOUNT"] = getenv("KEEP_CKPT_AMOUNT", 5)
+  save_ckpt_dir = config["SAVE_CKPT_DIR"] = getenv("SAVE_CKPT_DIR", "./ckpts")
   init_ckpt = config["INIT_CKPT_DIR"] = getenv("INIT_CKPT_DIR", BASEDIR)
 
   loss_scaler = config["loss_scaler"] = getenv("LOSS_SCALER", 2**9 if dtypes.default_float == dtypes.float16 else 1.0)
@@ -437,25 +438,27 @@ def train_bert():
   assert 10000 <= (EVAL_BS * max_eval_steps), "Evaluation batchsize * max_eval_steps must greater or equal 10000 to iterate over full eval dataset"
 
-  # ** Log hparams **
+  # ** Log run config **
   for key, value in config.items(): print(f'HParam: "{key}": {value}')
 
   # ** Optimizer **
-  skip_list = [v for k, v in get_state_dict(model).items() if "bias" in k or "LayerNorm" in k]
-  parameters = [x for x in parameters if x not in set(skip_list)]
-  optimizer = LAMB(parameters, lr=max_lr, eps=1e-6, wd=decay, adam=False)
-  optimizer_skip = LAMB(skip_list, lr=max_lr, eps=1e-6, wd=0.0, adam=False)
-  optimizer_group = OptimizerGroup(optimizer, optimizer_skip)
+  parameters_no_wd = [v for k, v in get_state_dict(model).items() if "bias" in k or "LayerNorm" in k]
+  parameters = [x for x in parameters if x not in set(parameters_no_wd)]
+  optimizer_wd = LAMB(parameters, lr=max_lr, eps=1e-6, wd=decay, adam=False)
+  optimizer_no_wd = LAMB(parameters_no_wd, lr=max_lr, eps=1e-6, wd=0.0, adam=False)
+  optimizer_group = OptimizerGroup(optimizer_wd, optimizer_no_wd)
 
   # ** LR scheduler **
-  scheduler = PolynomialDecayWithWarmup(optimizer, max_lr, 0, train_steps, warmup_steps, power=poly_power)
+  scheduler_wd = PolynomialDecayWithWarmup(optimizer_wd, max_lr, 0, train_steps, warmup_steps, power=poly_power)
+  scheduler_no_wd = PolynomialDecayWithWarmup(optimizer_no_wd, max_lr, 0, train_steps, warmup_steps, power=poly_power)
+  scheduler_group = LRSchedulerGroup(scheduler_wd, scheduler_no_wd)
   print(f"training with batch size {BS} for one epoch with {train_steps} steps")
 
   # ** resume from checkpointing **
   start_step = 0
   if ckpt:=getenv("RESUME", ""):
-    load_training_state(model, optimizer_group, scheduler, safe_load(ckpt))
-    start_step = scheduler.epoch_counter.numpy().item()
+    load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt))
+    start_step = int(scheduler_wd.epoch_counter.numpy().item())
     print(f"resuming from {ckpt} at step {start_step}")
 
   # ** init wandb **
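The hunk above is the core of the change: bias and LayerNorm parameters get their own LAMB instance without weight decay, and each optimizer now gets its own PolynomialDecayWithWarmup so the LRSchedulerGroup steps both learning rates together. A self-contained sketch of that pattern with a toy model standing in for BERT; import paths and hyperparameter values are assumptions, the API calls mirror the diff:

from tinygrad.tensor import Tensor
import tinygrad.nn as nn
from tinygrad.nn.state import get_state_dict, get_parameters
from tinygrad.nn.optim import LAMB, OptimizerGroup
from extra.lr_scheduler import PolynomialDecayWithWarmup, LRSchedulerGroup

class ToyModel:
  def __init__(self):
    self.dense = nn.Linear(16, 16)
    self.LayerNorm = nn.LayerNorm(16)
  def __call__(self, x:Tensor) -> Tensor: return self.LayerNorm(self.dense(x))

model = ToyModel()
max_lr, decay, train_steps, warmup_steps, poly_power = 1e-3, 0.01, 1000, 100, 1.0

# biases and LayerNorm parameters are excluded from weight decay
parameters_no_wd = [v for k, v in get_state_dict(model).items() if "bias" in k or "LayerNorm" in k]
parameters = [p for p in get_parameters(model) if p not in set(parameters_no_wd)]

optimizer_wd = LAMB(parameters, lr=max_lr, eps=1e-6, wd=decay, adam=False)
optimizer_no_wd = LAMB(parameters_no_wd, lr=max_lr, eps=1e-6, wd=0.0, adam=False)
optimizer_group = OptimizerGroup(optimizer_wd, optimizer_no_wd)

# one scheduler per optimizer, stepped together through the group
scheduler_wd = PolynomialDecayWithWarmup(optimizer_wd, max_lr, 0, train_steps, warmup_steps, power=poly_power)
scheduler_no_wd = PolynomialDecayWithWarmup(optimizer_no_wd, max_lr, 0, train_steps, warmup_steps, power=poly_power)
scheduler_group = LRSchedulerGroup(scheduler_wd, scheduler_no_wd)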
@@ -468,18 +471,18 @@ def train_bert():
   BENCHMARK = getenv("BENCHMARK")
 
   eval_it = iter(batch_load_val_bert(EVAL_BS))
-  train_it = iter(tqdm(batch_load_train_bert(BS), total=train_steps, disable=BENCHMARK))
+  train_it = iter(tqdm(batch_load_train_bert(BS, start_step), total=train_steps, disable=BENCHMARK))
 
   step_times = []
   # ** train loop **
   wc_start = time.perf_counter()
-  i, train_data = 0, get_data_bert(GPUS, train_it)
+  i, train_data = start_step, get_data_bert(GPUS, train_it)
   while train_data is not None and i < train_steps and not achieved:
     Tensor.training = True
     BEAM.value = TRAIN_BEAM
     st = time.perf_counter()
     GlobalCounters.reset()
-    loss = train_step_bert(model, optimizer_group, scheduler, loss_scaler,
+    loss = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
       train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
       train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"])
@@ -500,10 +503,10 @@ def train_bert():
 
     tqdm.write(
       f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
-      f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer.lr.numpy()[0]:.6f} LR, "
+      f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer_wd.lr.numpy()[0]:.6f} LR, "
      f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
    if WANDB:
-      wandb.log({"lr": optimizer.lr.numpy(), "train/loss": loss, "train/step_time": cl - st,
+      wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/step_time": cl - st,
        "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
        "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st)})
@@ -569,7 +572,7 @@ def train_bert():
     # save model if achieved target
     if not achieved and avg_lm_acc >= target:
       wc_end = time.perf_counter()
-      if not os.path.exists(ckpt_dir := getenv('CKPT_DIR', "./ckpts")): os.mkdir(ckpt_dir)
+      if not os.path.exists(ckpt_dir := save_ckpt_dir): os.mkdir(ckpt_dir)
       fn = f"{ckpt_dir}/bert-large.safe"
       safe_save(get_state_dict(model), fn)
       print(f" *** Model saved to {fn} ***")
@@ -584,13 +587,13 @@ def train_bert():
         break
 
     if getenv("CKPT") and i % save_ckpt_freq == 0:
-      if not os.path.exists(ckpt_dir := getenv('CKPT_DIR', "./ckpts")): os.mkdir(ckpt_dir)
+      if not os.path.exists(ckpt_dir := save_ckpt_dir): os.mkdir(ckpt_dir)
      if WANDB and wandb.run is not None:
        fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}.safe"
      else:
        fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}.safe"
      print(f"saving ckpt to {fn}")
-      safe_save(get_training_state(model, optimizer_group, scheduler), fn)
+      safe_save(get_training_state(model, optimizer_group, scheduler_group), fn)
      ckpt_files = [f for f in os.listdir(ckpt_dir) if os.path.isfile(os.path.join(ckpt_dir, f))]
      ckpt_files.sort(key=lambda x: os.path.getmtime(os.path.join(ckpt_dir, x)))
      while len(ckpt_files) > keep_ckpt_amount:
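The last hunk cuts off at the rotation loop. For completeness, a hedged sketch of what pruning down to keep_ckpt_amount checkpoints typically looks like; the prune_ckpts helper and the loop body are assumptions for illustration, not the repo's exact code:

import os

def prune_ckpts(ckpt_dir: str, keep_ckpt_amount: int) -> None:
  # list checkpoint files oldest-first, then delete from the front until only
  # keep_ckpt_amount remain (mirrors the sort-by-mtime setup shown above)
  ckpt_files = [f for f in os.listdir(ckpt_dir) if os.path.isfile(os.path.join(ckpt_dir, f))]
  ckpt_files.sort(key=lambda x: os.path.getmtime(os.path.join(ckpt_dir, x)))
  while len(ckpt_files) > keep_ckpt_amount:
    os.remove(os.path.join(ckpt_dir, ckpt_files.pop(0)))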