Resume fix + scheduler for non weight decay params (#4679)

* move ckpt dir

* fix resume. Add scheduler group
Author: Elias Wahl, 2024-05-22 01:38:13 +02:00 (committed by GitHub)
parent 0f21aa0416
commit acc0039cfc
2 changed files with 24 additions and 19 deletions


@@ -207,14 +207,16 @@ def load_datasample(file_and_offset:Tuple[str, int]) -> List[dict]:
   return data[file_and_offset[1]]
 # Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 394
-def batch_load_train_bert(BS:int):
+def batch_load_train_bert(BS:int, start_step:int = 0):
   from extra.datasets.wikipedia import get_wiki_train_files
   files = shuffle_parts(get_wiki_train_files())
   dataset = []
-  for f in files:
+  for f in tqdm(files, desc="Building dataset"):
     lists = [(f, o) for o in range(int(Path(f).stem.split("_")[3].split(".")[0]))]
     dataset.extend(lists)
+  dataset = dataset[start_step:]
   active_set = deque(dataset[:1000])
   remaining_set = deque(dataset[1000:])
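The dataloader change above is what makes resume deterministic: the flattened (file, offset) sample list is rebuilt identically on every run, so slicing off the first `start_step` entries continues exactly where the checkpointed run stopped. A minimal sketch of the idea, assuming the shuffle is seeded the same way across runs; `resumable_samples` and `window` are illustrative names, not part of this diff:

```python
from collections import deque

def resumable_samples(samples: list, start_step: int = 0, window: int = 1000):
  # drop what the previous run already consumed, then serve from a small
  # active window refilled from the remainder, as in batch_load_train_bert
  samples = samples[start_step:]
  active, remaining = deque(samples[:window]), deque(samples[window:])
  while active:
    yield active.popleft()
    if remaining: active.append(remaining.popleft())

# resuming at step 3 of ten samples yields samples 3..9
assert list(resumable_samples(list(range(10)), start_step=3)) == list(range(3, 10))
```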


@@ -414,6 +414,7 @@ def train_bert():
   eval_step_freq = config["EVAL_STEP_FREQ"] = getenv("EVAL_STEP_FREQ", int((math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000) / BS)) # Round down
   save_ckpt_freq = config["SAVE_CKPT_FREQ"] = getenv("SAVE_CKPT_FREQ", 1000)
   keep_ckpt_amount = config["KEEP_CKPT_AMOUNT"] = getenv("KEEP_CKPT_AMOUNT", 5)
+  save_ckpt_dir = config["SAVE_CKPT_DIR"] = getenv("SAVE_CKPT_DIR", "./ckpts")
   init_ckpt = config["INIT_CKPT_DIR"] = getenv("INIT_CKPT_DIR", BASEDIR)
   loss_scaler = config["loss_scaler"] = getenv("LOSS_SCALER", 2**9 if dtypes.default_float == dtypes.float16 else 1.0)
@@ -437,25 +438,27 @@ def train_bert():
   assert 10000 <= (EVAL_BS * max_eval_steps), "Evaluation batch size * max_eval_steps must be greater or equal to 10000 to iterate over the full eval dataset"
-  # ** Log hparams **
+  # ** Log run config **
   for key, value in config.items(): print(f'HParam: "{key}": {value}')
   # ** Optimizer **
-  skip_list = [v for k, v in get_state_dict(model).items() if "bias" in k or "LayerNorm" in k]
-  parameters = [x for x in parameters if x not in set(skip_list)]
-  optimizer = LAMB(parameters, lr=max_lr, eps=1e-6, wd=decay, adam=False)
-  optimizer_skip = LAMB(skip_list, lr=max_lr, eps=1e-6, wd=0.0, adam=False)
-  optimizer_group = OptimizerGroup(optimizer, optimizer_skip)
+  parameters_no_wd = [v for k, v in get_state_dict(model).items() if "bias" in k or "LayerNorm" in k]
+  parameters = [x for x in parameters if x not in set(parameters_no_wd)]
+  optimizer_wd = LAMB(parameters, lr=max_lr, eps=1e-6, wd=decay, adam=False)
+  optimizer_no_wd = LAMB(parameters_no_wd, lr=max_lr, eps=1e-6, wd=0.0, adam=False)
+  optimizer_group = OptimizerGroup(optimizer_wd, optimizer_no_wd)
   # ** LR scheduler **
-  scheduler = PolynomialDecayWithWarmup(optimizer, max_lr, 0, train_steps, warmup_steps, power=poly_power)
+  scheduler_wd = PolynomialDecayWithWarmup(optimizer_wd, max_lr, 0, train_steps, warmup_steps, power=poly_power)
+  scheduler_no_wd = PolynomialDecayWithWarmup(optimizer_no_wd, max_lr, 0, train_steps, warmup_steps, power=poly_power)
+  scheduler_group = LRSchedulerGroup(scheduler_wd, scheduler_no_wd)
   print(f"training with batch size {BS} for one epoch with {train_steps} steps")
   # ** resume from checkpointing **
   start_step = 0
   if ckpt:=getenv("RESUME", ""):
-    load_training_state(model, optimizer_group, scheduler, safe_load(ckpt))
-    start_step = scheduler.epoch_counter.numpy().item()
+    load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt))
+    start_step = int(scheduler_wd.epoch_counter.numpy().item())
     print(f"resuming from {ckpt} at step {start_step}")
   # ** init wandb **
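The optimizer hunk implements the standard BERT practice of exempting bias and LayerNorm parameters from weight decay: two LAMB instances identical except for `wd`, wrapped so one `step()` drives both, with one scheduler per optimizer kept in lock-step. A framework-agnostic sketch of the pattern, assuming parameters compare by identity; all names below are illustrative, not tinygrad's API:

```python
def split_decay_params(state_dict: dict):
  # parameters whose names contain "bias" or "LayerNorm" get no weight decay
  no_wd = [v for k, v in state_dict.items() if "bias" in k or "LayerNorm" in k]
  no_wd_ids = {id(v) for v in no_wd}
  wd = [v for v in state_dict.values() if id(v) not in no_wd_ids]
  return wd, no_wd

class Group:
  # fans step() out to every member, like OptimizerGroup / LRSchedulerGroup
  def __init__(self, *members): self.members = members
  def step(self):
    for m in self.members: m.step()
```

Because both schedulers follow the same PolynomialDecayWithWarmup curve, either one can serve as the step counter on resume; the hunk reads it from scheduler_wd.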
@@ -468,18 +471,18 @@ def train_bert():
   BENCHMARK = getenv("BENCHMARK")
   eval_it = iter(batch_load_val_bert(EVAL_BS))
-  train_it = iter(tqdm(batch_load_train_bert(BS), total=train_steps, disable=BENCHMARK))
+  train_it = iter(tqdm(batch_load_train_bert(BS, start_step), total=train_steps, disable=BENCHMARK))
   step_times = []
   # ** train loop **
   wc_start = time.perf_counter()
-  i, train_data = 0, get_data_bert(GPUS, train_it)
+  i, train_data = start_step, get_data_bert(GPUS, train_it)
   while train_data is not None and i < train_steps and not achieved:
     Tensor.training = True
     BEAM.value = TRAIN_BEAM
     st = time.perf_counter()
     GlobalCounters.reset()
-    loss = train_step_bert(model, optimizer_group, scheduler, loss_scaler,
+    loss = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
       train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
       train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"])
@@ -500,10 +503,10 @@ def train_bert():
     tqdm.write(
       f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
-      f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer.lr.numpy()[0]:.6f} LR, "
+      f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer_wd.lr.numpy()[0]:.6f} LR, "
       f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
     if WANDB:
-      wandb.log({"lr": optimizer.lr.numpy(), "train/loss": loss, "train/step_time": cl - st,
+      wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/step_time": cl - st,
                  "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
                  "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st)})
@@ -569,7 +572,7 @@ def train_bert():
     # save model if achieved target
     if not achieved and avg_lm_acc >= target:
       wc_end = time.perf_counter()
-      if not os.path.exists(ckpt_dir := getenv('CKPT_DIR', "./ckpts")): os.mkdir(ckpt_dir)
+      if not os.path.exists(ckpt_dir := save_ckpt_dir): os.mkdir(ckpt_dir)
       fn = f"{ckpt_dir}/bert-large.safe"
       safe_save(get_state_dict(model), fn)
       print(f" *** Model saved to {fn} ***")
@@ -584,13 +587,13 @@ def train_bert():
         break
     if getenv("CKPT") and i % save_ckpt_freq == 0:
-      if not os.path.exists(ckpt_dir := getenv('CKPT_DIR', "./ckpts")): os.mkdir(ckpt_dir)
+      if not os.path.exists(ckpt_dir := save_ckpt_dir): os.mkdir(ckpt_dir)
       if WANDB and wandb.run is not None:
         fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}.safe"
       else:
         fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}.safe"
       print(f"saving ckpt to {fn}")
-      safe_save(get_training_state(model, optimizer_group, scheduler), fn)
+      safe_save(get_training_state(model, optimizer_group, scheduler_group), fn)
       ckpt_files = [f for f in os.listdir(ckpt_dir) if os.path.isfile(os.path.join(ckpt_dir, f))]
       ckpt_files.sort(key=lambda x: os.path.getmtime(os.path.join(ckpt_dir, x)))
       while len(ckpt_files) > keep_ckpt_amount:
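The hunk ends on the rotation loop that keeps only the newest `keep_ckpt_amount` checkpoints; its body is cut off above, so the following self-contained sketch completes it by assumption: sort by mtime, oldest first, and delete from the front until the count fits.

```python
import os

def rotate_ckpts(ckpt_dir: str, keep: int = 5):
  files = [f for f in os.listdir(ckpt_dir) if os.path.isfile(os.path.join(ckpt_dir, f))]
  files.sort(key=lambda f: os.path.getmtime(os.path.join(ckpt_dir, f)))  # oldest first
  while len(files) > keep:
    os.remove(os.path.join(ckpt_dir, files.pop(0)))  # drop the oldest checkpoint
```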