bert grad clipping start with const 0 (#11008)

saved the init kernels
chenyu
2025-06-27 18:02:23 -04:00
committed by GitHub
parent a6485d00c8
commit f2548afeb5


@@ -933,7 +933,7 @@ def train_step_bert(model, optimizer, scheduler, loss_scaler:float, GPUS, grad_a
 # TODO: OOM without this realize with large grad_acc
 Tensor.realize(*[p.grad for p in optimizer.params])
-global_norm = Tensor([0.0], dtype=dtypes.float32, device=optimizer[0].device)
+global_norm = Tensor(0.0, dtype=dtypes.float32, device=optimizer[0].device)
 for p in optimizer.params:
   p.grad = p.grad / loss_scaler
   global_norm += p.grad.float().square().sum()
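
For context, the only change is the starting value of the global-norm accumulator: Tensor([0.0], ...) builds a one-element buffer that needs its own kernel to fill, while Tensor(0.0, ...) is a scalar const that folds into the following reduction, which is presumably what "saved the init kernels" in the commit message refers to. Below is a minimal standalone sketch of the same global-norm clipping pattern, assuming the standard tinygrad Tensor API; it is not the mlperf training script itself, and clip_grads_by_global_norm / max_norm are illustrative names.

# Sketch: global-norm gradient clipping with the accumulator starting
# from a scalar const Tensor(0.0) rather than a one-element buffer.
from tinygrad import Tensor, dtypes

def clip_grads_by_global_norm(grads: list[Tensor], max_norm: float = 1.0) -> list[Tensor]:
  # scalar const 0 as the starting point of the reduction (the point of this commit)
  global_norm = Tensor(0.0, dtype=dtypes.float32)
  for g in grads:
    global_norm = global_norm + g.float().square().sum()
  global_norm = global_norm.sqrt()
  # scale every grad by max_norm / max(global_norm, max_norm), i.e. no-op if already small enough
  scale = max_norm / global_norm.maximum(max_norm)
  return [g * scale for g in grads]

if __name__ == "__main__":
  grads = [Tensor.randn(4, 4), Tensor.randn(8)]
  clipped = clip_grads_by_global_norm(grads, max_norm=1.0)
  print([g.numpy().shape for g in clipped])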