Mirror of https://github.com/tinygrad/tinygrad.git
Global norm + small changes (#4749)
* norm
* no empty
* default loss scaler in float
@@ -364,7 +364,13 @@ def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Te
   optimizer.zero_grad()
   (loss * loss_scaler).backward()

-  for p in optimizer.params: p.grad /= loss_scaler
+  global_norm = Tensor([0.0], dtype=dtypes.float32, device=optimizer[0].device).realize()
+  for p in optimizer.params:
+    p.grad = p.grad / loss_scaler
+    global_norm += p.grad.float().square().sum()
+  global_norm = global_norm.sqrt()
+  for p in optimizer.params: p.grad = (p.grad / Tensor.where(global_norm > 1.0, global_norm, 1.0)).cast(p.grad.dtype)
+
   optimizer.step()
   scheduler.step()
   return loss.realize()
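The removed line only unscaled the gradients; the new block also clips them by the global L2 norm: gradients are unscaled, their squared sums are accumulated in a float32 accumulator, and every gradient is divided by max(global_norm, 1.0) before being cast back to its own dtype. A NumPy sketch of the same arithmetic on made-up gradient arrays (an illustration, not the tinygrad code itself):

import numpy as np

loss_scaler = 2.0**9
# hypothetical float16 gradients that still carry the loss-scaling factor
grads = [np.random.randn(4, 4).astype(np.float16) * loss_scaler,
         np.random.randn(8).astype(np.float16) * loss_scaler]

# 1) undo the loss scaling
grads = [g / loss_scaler for g in grads]

# 2) global L2 norm over all parameters, accumulated in float32 for stability
global_norm = np.sqrt(sum(np.square(g.astype(np.float32)).sum() for g in grads))

# 3) shrink only when the norm exceeds 1.0, then cast back to the gradient dtype
divisor = global_norm if global_norm > 1.0 else 1.0
grads = [(g / divisor).astype(np.float16) for g in grads]

# after clipping, the global norm is at most 1.0
print(global_norm, np.sqrt(sum(np.square(g.astype(np.float32)).sum() for g in grads)))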
@@ -417,9 +423,9 @@ def train_bert():
   save_ckpt_dir = config["SAVE_CKPT_DIR"] = getenv("SAVE_CKPT_DIR", "./ckpts")
   init_ckpt = config["INIT_CKPT_DIR"] = getenv("INIT_CKPT_DIR", BASEDIR)

-  loss_scaler = config["loss_scaler"] = getenv("LOSS_SCALER", 2**9 if dtypes.default_float == dtypes.float16 else 1.0)
-  decay = config["decay"] = getenv("DECAY", 0.01)
-  poly_power = config["poly_power"] = getenv("POLY_POWER", 1.0)
+  loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 2.0**9 if dtypes.default_float == dtypes.float16 else 1.0)
+  decay = config["DECAY"] = getenv("DECAY", 0.01)
+  poly_power = config["POLY_POWER"] = getenv("POLY_POWER", 1.0)

   target, achieved = getenv("TARGET", 0.72), False
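Besides switching the config keys to upper case to match the surrounding entries, the default loss scaler changes from the int 2**9 to the float 2.0**9 (the "default loss scaler in float" item in the commit message). The default's type matters if, as assumed here, tinygrad's getenv casts an environment override to the type of the default; getenv_like below is a hypothetical stand-in that models that assumed behavior:

import os

def getenv_like(key, default):
  # stand-in for the assumed behavior of tinygrad.helpers.getenv:
  # the environment value is cast to the type of the default
  return type(default)(os.environ.get(key, default))

os.environ["LOSS_SCALER"] = "128.5"          # hypothetical override
print(getenv_like("LOSS_SCALER", 2.0**9))    # float default -> 128.5
try:
  getenv_like("LOSS_SCALER", 2**9)           # int default -> int("128.5")
except ValueError as e:
  print("int default cannot take a fractional scaler:", e)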
@@ -444,8 +450,8 @@ def train_bert():
   # ** Optimizer **
   parameters_no_wd = [v for k, v in get_state_dict(model).items() if "bias" in k or "LayerNorm" in k]
   parameters = [x for x in parameters if x not in set(parameters_no_wd)]
-  optimizer_wd = LAMB(parameters, lr=max_lr, eps=1e-6, wd=decay, adam=False)
-  optimizer_no_wd = LAMB(parameters_no_wd, lr=max_lr, eps=1e-6, wd=0.0, adam=False)
+  optimizer_wd = LAMB(parameters, lr=max_lr, eps=1e-6, weight_decay=decay, adam=False)
+  optimizer_no_wd = LAMB(parameters_no_wd, lr=max_lr, eps=1e-6, weight_decay=0.0, adam=False)
   optimizer_group = OptimizerGroup(optimizer_wd, optimizer_no_wd)

   # ** LR scheduler **
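The only change in this hunk is passing weight decay through the weight_decay keyword. A self-contained sketch of the same decay / no-decay parameter split on a toy model, assuming the tinygrad signatures LAMB(params, lr, eps, weight_decay, adam), OptimizerGroup(*optimizers), get_state_dict, and get_parameters behave as used below (the Toy model is made up for illustration):

from tinygrad import Tensor, nn
from tinygrad.nn.optim import LAMB, OptimizerGroup
from tinygrad.nn.state import get_state_dict, get_parameters

class Toy:
  def __init__(self):
    self.l1 = nn.Linear(8, 8)
    self.ln = nn.LayerNorm(8)
  def __call__(self, x): return self.ln(self.l1(x))

model = Toy()
# biases and LayerNorm parameters get no weight decay, mirroring the BERT script
no_wd = [v for k, v in get_state_dict(model).items() if "bias" in k or "ln" in k]
wd = [p for p in get_parameters(model) if all(p is not q for q in no_wd)]

opt = OptimizerGroup(LAMB(wd, lr=1e-3, eps=1e-6, weight_decay=0.01, adam=False),
                     LAMB(no_wd, lr=1e-3, eps=1e-6, weight_decay=0.0, adam=False))

Tensor.training = True
loss = model(Tensor.randn(4, 8)).square().mean()
opt.zero_grad()
loss.backward()
opt.step()
print(loss.item())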
@@ -471,7 +477,7 @@ def train_bert():
   BENCHMARK = getenv("BENCHMARK")

   eval_it = iter(batch_load_val_bert(EVAL_BS))
-  train_it = iter(tqdm(batch_load_train_bert(BS, start_step), total=train_steps, disable=BENCHMARK))
+  train_it = iter(tqdm(batch_load_train_bert(BS, start_step), initial=start_step, total=train_steps, disable=BENCHMARK))

   step_times = []
   # ** train loop **
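The data-loader hunk only adds initial=start_step to tqdm, so a run resumed from a checkpoint reports progress relative to the full train_steps instead of restarting the bar at zero. A standalone illustration with a plain range in place of the BERT batch loader (the step values are hypothetical):

from tqdm import tqdm

train_steps, start_step = 100, 40            # hypothetical resume point
remaining = range(start_step, train_steps)   # the loader already skips past start_step

# initial=start_step makes the bar start at 40/100 instead of 0/100,
# so the count, rate, and ETA reflect the work actually left
for step in tqdm(remaining, initial=start_step, total=train_steps):
  pass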