log layer stats

David Hou
2024-02-13 13:06:55 -08:00
parent e51088624c
commit 9d98224585
2 changed files with 25 additions and 9 deletions

View File

@@ -98,9 +98,10 @@ def train_resnet():
# ** Optimizer **
if getenv("LARS", 1):
- optimizer = optim.LARS(parameters, base_lr / lr_scaler, momentum=.9, weight_decay=decay, track_gnorm=True)
+ optimizer = optim.LARS(parameters, base_lr / lr_scaler, momentum=.9, weight_decay=decay, track_norms=True)
else:
optimizer = optim.SGD(parameters, base_lr / lr_scaler, momentum=.9, weight_decay=decay)
+ parameter_keys = [k for k, v in get_state_dict(model).items() if v.requires_grad]
# ** LR scheduler **
# scheduler = MultiStepLR(optimizer, [m for m in lr_steps], gamma=lr_gamma, warmup=lr_warmup)
@@ -201,7 +202,7 @@ def train_resnet():
# the backward step should be realized by loss.numpy(), even though it doesn't depend on this.
# doing this uses 16.38gb vs 15.55gb? why? because the grads get realized in optimizer.step, and the backward buffers are freed?
fwet = time.perf_counter()
- gnorm = backward_step(*proc[1], proc[0][0])
+ wnorm, wnorms, gnorm, gnorms = backward_step(*proc[1], proc[0][0])
# proc = (proc[0], proc[2]) # drop inputs
et = time.perf_counter()
@@ -215,8 +216,10 @@ def train_resnet():
dte = time.perf_counter()
device_str = proc[0][2].device if isinstance(proc[0][2].device, str) else f"{proc[0][2].device[0]} * {len(proc[0][2].device)}"
- proc, top_1_acc, gnorm = proc[0][0].numpy(), proc[0][2].numpy().item() / BS, gnorm.numpy() # return cookie
+ proc, top_1_acc, wnorm, wnorms, gnorm, gnorms = proc[0][0].numpy(), proc[0][2].numpy().item() / BS, wnorm.numpy(), [x.numpy() for x in wnorms], gnorm.numpy(), [x.numpy() for x in gnorms] # return cookie
loss_cpu = proc / lr_scaler
+ wnorms = {f"wnorms/{k}": v for k, v in zip(parameter_keys, wnorms)}
+ gnorms = {f"gnorms/{k}": v for k, v in zip(parameter_keys, gnorms)}
cl = time.perf_counter()
new_st = time.perf_counter()
@@ -229,9 +232,12 @@ def train_resnet():
"train/cl_time": cl - dte,
"train/loss": loss_cpu,
"train/top_1_acc": top_1_acc,
"train/wnorm": wnorm,
"train/gnorm": gnorm,
"train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st),
"epoch": e + (i + 1) / steps_in_train_epoch,
+ **wnorms,
+ **gnorms,
})
st = new_st
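
The training-loop changes above pair each trainable parameter's name (parameter_keys) with the per-layer weight and gradient norms returned by the optimizer, namespace them under wnorms/ and gnorms/, and log them next to the existing scalars. A minimal standalone sketch of that logging pattern (not the repo's code): the layer names and norm values are placeholders, and the wandb run is opened in disabled mode just to exercise the API.

import wandb

parameter_keys = ["conv1.weight", "bn1.weight", "fc.weight"]   # hypothetical layer names
wnorms = [12.3, 4.5, 8.9]    # per-layer weight L2 norms (placeholder values)
gnorms = [0.21, 0.04, 0.53]  # per-layer gradient L2 norms (placeholder values)

# namespace the per-layer stats so they group cleanly in the dashboard
layer_stats = {f"wnorms/{k}": v for k, v in zip(parameter_keys, wnorms)}
layer_stats.update({f"gnorms/{k}": v for k, v in zip(parameter_keys, gnorms)})

run = wandb.init(mode="disabled")  # disabled mode: no server, just exercises the API
wandb.log({"train/loss": 1.23, **layer_stats})
run.finish()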
@@ -285,11 +291,14 @@ def train_resnet():
total_top_5 = sum(eval_top_5_acc) / len(eval_top_5_acc)
total_fw_time = sum(eval_times) / len(eval_times)
tqdm.write(f"eval loss: {total_loss:.2f}, eval time: {total_fw_time:.2f}, eval top 1 acc: {total_top_1:.3f}, eval top 5 acc: {total_top_5:.3f}")
+ weight_hists = {f"weights/{k}": wandb.Histogram(v.numpy().flatten().tolist()) for k, v in get_state_dict(model).items()}
wandb.log({"eval/loss": total_loss,
"eval/top_1_acc": total_top_1,
"eval/top_5_acc": total_top_5,
"eval/forward_time": total_fw_time,
"eval/forward_time": total_fw_time,
"epoch": e + 1,
+ **weight_hists,
})
if not achieved and total_top_1 >= target:
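
The eval-time change logs a histogram of every weight tensor under weights/<name>. A small sketch of the same pattern, using a placeholder state dict instead of the model's real get_state_dict output:

import numpy as np
import wandb

# placeholder tensors; the real code iterates get_state_dict(model).items()
state = {"conv1.weight": np.random.randn(8, 3, 3, 3), "fc.weight": np.random.randn(10, 16)}
weight_hists = {f"weights/{k}": wandb.Histogram(v.flatten().tolist()) for k, v in state.items()}

run = wandb.init(mode="disabled")  # disabled mode: no network traffic
wandb.log({"epoch": 1, **weight_hists})
run.finish()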

View File

@@ -73,13 +73,14 @@ class LAMB(Optimizer):
# https://github.com/mlcommons/training/blob/master/image_classification/tensorflow2/lars_optimizer.py
class LARS(Optimizer):
- def __init__(self, params: List[Tensor], lr, momentum=0.9, weight_decay=1e-4, eta=0.001, eps=0.0, nesterov=False, track_gnorm=False):
+ def __init__(self, params: List[Tensor], lr, momentum=0.9, weight_decay=1e-4, eta=0.001, eps=0.0, nesterov=False, track_gnorm=False, track_norms=False):
super().__init__(params, lr)
- self.momentum, self.weight_decay, self.eta, self.eps, self.nesterov, self.track_gnorm = momentum, weight_decay, eta, eps, nesterov, track_gnorm
+ self.momentum, self.weight_decay, self.eta, self.eps, self.nesterov, self.track_norms = momentum, weight_decay, eta, eps, nesterov, track_norms
self.b = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
def step(self):
- gnorm = 0
+ wnorm, gnorm = 0, 0
+ wnorms, gnorms = [], []
for i, t in enumerate(self.params):
assert t.grad is not None
# this is needed since the grads can form a "diamond"
@@ -87,8 +88,14 @@ class LARS(Optimizer):
t.grad.realize()
t_ = t.detach()
w_norm = (t_ * t_).sum().sqrt()
+ if self.track_norms:
+ wnorms.append(w_norm.to("HIP"))
+ wnorm = wnorm + (w_norm*w_norm).to("HIP")
g_norm = (t.grad * t.grad).sum().sqrt()
- if self.track_gnorm: gnorm = gnorm + g_norm.to("HIP")
+ if self.track_norms:
+ gnorms.append(g_norm.to("HIP"))
+ gnorm = gnorm + (g_norm*g_norm).to("HIP")
trust_ratio = (w_norm > 0).where(
(g_norm > 0).where(
self.eta * w_norm / (g_norm + self.weight_decay * w_norm + self.eps), 1.0
@@ -102,5 +109,5 @@ class LARS(Optimizer):
g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
t.assign(t.detach() - g)
self.realize(self.b)
- if self.track_gnorm: return gnorm.realize()
+ if self.track_norms: return wnorm.sqrt().realize(), [x.realize() for x in wnorms], gnorm.sqrt().realize(), [x.realize() for x in gnorms]
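
One note on the return values above: the new track_norms path accumulates squared per-tensor norms and takes a single sqrt at the end, so the returned wnorm/gnorm are global L2 norms over all parameters, whereas the removed track_gnorm path summed the per-tensor norms directly. A quick NumPy check of that identity, using placeholder gradient arrays:

import numpy as np

grads = [np.random.randn(64, 3, 3, 3), np.random.randn(128), np.random.randn(1000, 512)]

per_tensor = [np.sqrt((g * g).sum()) for g in grads]       # one g_norm per parameter
global_norm = np.sqrt(sum(n * n for n in per_tensor))      # what the track_norms path returns
reference = np.linalg.norm(np.concatenate([g.ravel() for g in grads]))

assert np.isclose(global_norm, reference)
print(global_norm, sum(per_tensor))  # a plain sum of per-tensor norms overestimates the global norm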