Global norm + small changes (#4749)

* norm

* no empty

* default loss scaler in float
Author: Elias Wahl
Date:   2024-05-28 00:35:27 +02:00 (committed by GitHub)
Parent: c7beb36b73
Commit: c4b0acf095


@@ -364,7 +364,13 @@ def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Te
   optimizer.zero_grad()
   (loss * loss_scaler).backward()
-  for p in optimizer.params: p.grad /= loss_scaler
+  global_norm = Tensor([0.0], dtype=dtypes.float32, device=optimizer[0].device).realize()
+  for p in optimizer.params:
+    p.grad = p.grad / loss_scaler
+    global_norm += p.grad.float().square().sum()
+  global_norm = global_norm.sqrt()
+  for p in optimizer.params: p.grad = (p.grad / Tensor.where(global_norm > 1.0, global_norm, 1.0)).cast(p.grad.dtype)
   optimizer.step()
   scheduler.step()
   return loss.realize()
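
Note on the hunk above: the old code only undid the loss scaling, while the new code also clips by the global gradient norm, i.e. it rescales every gradient by max(global_norm, 1.0), where global_norm is the L2 norm over all gradients taken together. A minimal NumPy sketch of the same arithmetic (the function and variable names here are illustrative, not part of the diff):

import numpy as np

def clip_by_global_norm(grads, loss_scaler):
  # undo the loss scaling first, accumulating the squared norm in float32
  grads = [g / loss_scaler for g in grads]
  global_norm = np.sqrt(sum(np.square(g.astype(np.float32)).sum() for g in grads))
  # divide by max(global_norm, 1.0): gradients shrink only when the norm exceeds 1
  divisor = global_norm if global_norm > 1.0 else 1.0
  return [(g / divisor).astype(g.dtype) for g in grads]

# two fake fp16 gradients
g1 = np.full((3,), 4.0, dtype=np.float16)
g2 = np.full((4,), 3.0, dtype=np.float16)
clipped = clip_by_global_norm([g1, g2], loss_scaler=2.0**9)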
@@ -417,9 +423,9 @@ def train_bert():
   save_ckpt_dir = config["SAVE_CKPT_DIR"] = getenv("SAVE_CKPT_DIR", "./ckpts")
   init_ckpt = config["INIT_CKPT_DIR"] = getenv("INIT_CKPT_DIR", BASEDIR)
-  loss_scaler = config["loss_scaler"] = getenv("LOSS_SCALER", 2**9 if dtypes.default_float == dtypes.float16 else 1.0)
-  decay = config["decay"] = getenv("DECAY", 0.01)
-  poly_power = config["poly_power"] = getenv("POLY_POWER", 1.0)
+  loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 2.0**9 if dtypes.default_float == dtypes.float16 else 1.0)
+  decay = config["DECAY"] = getenv("DECAY", 0.01)
+  poly_power = config["POLY_POWER"] = getenv("POLY_POWER", 1.0)
   target, achieved = getenv("TARGET", 0.72), False
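
Two things change in the config hunk: the report keys become uppercase to match the other entries (SAVE_CKPT_DIR, INIT_CKPT_DIR, ...), and the fp16 default for the loss scaler becomes 2.0**9 instead of 2**9, so loss_scaler is a float in both branches of the conditional, matching the loss_scaler:float annotation on train_step_bert and the "default loss scaler in float" note in the commit message. The distinction matters because tinygrad's getenv coerces the environment value to the type of its default; a rough sketch of that behaviour (getenv_like is a stand-in, not the real helper):

import os

def getenv_like(key, default):
  # follows the type of the default, roughly what tinygrad's getenv does
  return type(default)(os.environ.get(key, default))

loss_scaler_int   = getenv_like("LOSS_SCALER", 2**9)    # 512   (int) with the old default
loss_scaler_float = getenv_like("LOSS_SCALER", 2.0**9)  # 512.0 (float) with the new default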
@@ -444,8 +450,8 @@ def train_bert():
   # ** Optimizer **
   parameters_no_wd = [v for k, v in get_state_dict(model).items() if "bias" in k or "LayerNorm" in k]
   parameters = [x for x in parameters if x not in set(parameters_no_wd)]
-  optimizer_wd = LAMB(parameters, lr=max_lr, eps=1e-6, wd=decay, adam=False)
-  optimizer_no_wd = LAMB(parameters_no_wd, lr=max_lr, eps=1e-6, wd=0.0, adam=False)
+  optimizer_wd = LAMB(parameters, lr=max_lr, eps=1e-6, weight_decay=decay, adam=False)
+  optimizer_no_wd = LAMB(parameters_no_wd, lr=max_lr, eps=1e-6, weight_decay=0.0, adam=False)
   optimizer_group = OptimizerGroup(optimizer_wd, optimizer_no_wd)
   # ** LR scheduler **
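
The optimizer hunk is only a keyword rename: LAMB's weight-decay argument is passed as weight_decay instead of wd, presumably tracking the current LAMB signature, for both the decayed and the non-decayed parameter groups. The split itself is unchanged; as a generic illustration of the pattern (the helper name is mine, not from the diff), biases and LayerNorm parameters are excluded from weight decay:

def split_for_weight_decay(state_dict):
  # parameters whose names contain "bias" or "LayerNorm" get no weight decay
  no_wd = [v for k, v in state_dict.items() if "bias" in k or "LayerNorm" in k]
  no_wd_ids = {id(v) for v in no_wd}
  wd = [v for v in state_dict.values() if id(v) not in no_wd_ids]
  return wd, no_wd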
@@ -471,7 +477,7 @@ def train_bert():
   BENCHMARK = getenv("BENCHMARK")
   eval_it = iter(batch_load_val_bert(EVAL_BS))
-  train_it = iter(tqdm(batch_load_train_bert(BS, start_step), total=train_steps, disable=BENCHMARK))
+  train_it = iter(tqdm(batch_load_train_bert(BS, start_step), initial=start_step, total=train_steps, disable=BENCHMARK))
   step_times = []
   # ** train loop **
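
Finally, the data-loader hunk passes initial=start_step to tqdm so a resumed run's progress bar starts at the checkpointed step instead of zero, keeping the displayed count and ETA consistent with total=train_steps. A small self-contained illustration of the same tqdm arguments (the step numbers are made up):

from tqdm import tqdm

start_step, train_steps = 40, 100  # pretend we resumed from step 40 of 100
for _ in tqdm(range(start_step, train_steps), initial=start_step, total=train_steps):
  pass  # a real run would do a training step here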