Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 07:28:15 -05:00)
log layer stats
@@ -98,9 +98,10 @@ def train_resnet():
   # ** Optimizer **
   if getenv("LARS", 1):
-    optimizer = optim.LARS(parameters, base_lr / lr_scaler, momentum=.9, weight_decay=decay, track_gnorm=True)
+    optimizer = optim.LARS(parameters, base_lr / lr_scaler, momentum=.9, weight_decay=decay, track_norms=True)
   else:
     optimizer = optim.SGD(parameters, base_lr / lr_scaler, momentum=.9, weight_decay=decay)
+  parameter_keys = [k for k, v in get_state_dict(model).items() if v.requires_grad]

   # ** LR scheduler **
   # scheduler = MultiStepLR(optimizer, [m for m in lr_steps], gamma=lr_gamma, warmup=lr_warmup)
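The per-layer dictionaries built later in this diff zip wnorms/gnorms against parameter_keys by position, so the logging relies on get_state_dict(model) enumerating the trainable tensors in the same order as the parameter list handed to the optimizer. A minimal sanity-check sketch (not part of the commit; it assumes `parameters` comes from tinygrad's get_parameters(model)):

from tinygrad.nn.state import get_parameters, get_state_dict

def check_norm_key_alignment(model):
  # trainable tensors as named by the state dict (the source of parameter_keys)
  named = [(k, v) for k, v in get_state_dict(model).items() if v.requires_grad]
  # trainable tensors in the order the optimizer is assumed to see them
  params = [p for p in get_parameters(model) if p.requires_grad]
  assert len(named) == len(params), "trainable tensor count mismatch"
  for (k, v), p in zip(named, params):
    assert v is p, f"{k} is out of order; per-layer wnorms/gnorms would be mislabeled"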
@@ -201,7 +202,7 @@ def train_resnet():
   # the backward step should be realized by loss.numpy(), even though it doesn't depend on this.
   # doing this uses 16.38gb vs 15.55gb? why? because the grads get realized in optimizer.step, and the backward buffers are freed?
   fwet = time.perf_counter()
-  gnorm = backward_step(*proc[1], proc[0][0])
+  wnorm, wnorms, gnorm, gnorms = backward_step(*proc[1], proc[0][0])
   # proc = (proc[0], proc[2])  # drop inputs

   et = time.perf_counter()
@@ -215,8 +216,10 @@ def train_resnet():
   dte = time.perf_counter()

   device_str = proc[0][2].device if isinstance(proc[0][2].device, str) else f"{proc[0][2].device[0]} * {len(proc[0][2].device)}"
-  proc, top_1_acc, gnorm = proc[0][0].numpy(), proc[0][2].numpy().item() / BS, gnorm.numpy()  # return cookie
+  proc, top_1_acc, wnorm, wnorms, gnorm, gnorms = proc[0][0].numpy(), proc[0][2].numpy().item() / BS, wnorm.numpy(), [x.numpy() for x in wnorms], gnorm.numpy(), [x.numpy() for x in gnorms]  # return cookie
   loss_cpu = proc / lr_scaler
+  wnorms = {f"wnorms/{k}": v for k, v in zip(parameter_keys, wnorms)}
+  gnorms = {f"gnorms/{k}": v for k, v in zip(parameter_keys, gnorms)}
   cl = time.perf_counter()
   new_st = time.perf_counter()
@@ -229,9 +232,12 @@ def train_resnet():
     "train/cl_time": cl - dte,
     "train/loss": loss_cpu,
     "train/top_1_acc": top_1_acc,
+    "train/wnorm": wnorm,
     "train/gnorm": gnorm,
     "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st),
     "epoch": e + (i + 1) / steps_in_train_epoch,
+    **wnorms,
+    **gnorms,
   })

   st = new_st
@@ -285,11 +291,14 @@ def train_resnet():
   total_top_5 = sum(eval_top_5_acc) / len(eval_top_5_acc)
   total_fw_time = sum(eval_times) / len(eval_times)
   tqdm.write(f"eval loss: {total_loss:.2f}, eval time: {total_fw_time:.2f}, eval top 1 acc: {total_top_1:.3f}, eval top 5 acc: {total_top_5:.3f}")

+  weight_hists = {f"weights/{k}": wandb.Histogram(v.numpy().flatten().tolist()) for k, v in get_state_dict(model).items()}
   wandb.log({"eval/loss": total_loss,
     "eval/top_1_acc": total_top_1,
     "eval/top_5_acc": total_top_5,
     "eval/forward_time": total_fw_time,
     "epoch": e + 1,
+    **weight_hists,
   })

   if not achieved and total_top_1 >= target:
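The weight_hists comprehension above converts every raw weight value to a Python list before handing it to wandb, which is a lot of data for ResNet-sized tensors. A hedged alternative sketch (not from the commit), assuming wandb.Histogram's np_histogram parameter, pre-bins each tensor on the host and logs only counts and bin edges:

import numpy as np
import wandb
from tinygrad.nn.state import get_state_dict

def weight_histograms(model, bins=64):
  # ship (counts, bin_edges) per tensor instead of every raw weight value
  return {f"weights/{k}": wandb.Histogram(np_histogram=np.histogram(v.numpy().flatten(), bins=bins))
          for k, v in get_state_dict(model).items()}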
@@ -73,13 +73,14 @@ class LAMB(Optimizer):

 # https://github.com/mlcommons/training/blob/master/image_classification/tensorflow2/lars_optimizer.py
 class LARS(Optimizer):
-  def __init__(self, params: List[Tensor], lr, momentum=0.9, weight_decay=1e-4, eta=0.001, eps=0.0, nesterov=False, track_gnorm=False):
+  def __init__(self, params: List[Tensor], lr, momentum=0.9, weight_decay=1e-4, eta=0.001, eps=0.0, nesterov=False, track_gnorm=False, track_norms=False):
     super().__init__(params, lr)
-    self.momentum, self.weight_decay, self.eta, self.eps, self.nesterov, self.track_gnorm = momentum, weight_decay, eta, eps, nesterov, track_gnorm
+    self.momentum, self.weight_decay, self.eta, self.eps, self.nesterov, self.track_norms = momentum, weight_decay, eta, eps, nesterov, track_norms
     self.b = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]

   def step(self):
-    gnorm = 0
+    wnorm, gnorm = 0, 0
+    wnorms, gnorms = [], []
     for i, t in enumerate(self.params):
       assert t.grad is not None
       # this is needed since the grads can form a "diamond"
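wnorm and gnorm start out as plain Python zeros and become Tensors only when the first squared norm is added, which is why step() can call .sqrt().realize() on them later; accumulating squared norms also means the returned globals are the L2 norm over all parameters taken together. A small standalone illustration of that pattern (not the optimizer code itself):

from tinygrad.tensor import Tensor

acc = 0  # plain int; Tensor.__radd__ promotes it on the first addition
for t in [Tensor([3.0, 4.0]), Tensor([1.0, 2.0])]:
  n = (t * t).sum().sqrt()   # per-tensor L2 norm
  acc = acc + n * n          # accumulate squared norms, as step() does for wnorm/gnorm
print(acc.sqrt().numpy())    # global norm over both tensors: sqrt(25 + 5) ~= 5.477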
@@ -87,8 +88,14 @@ class LARS(Optimizer):
       t.grad.realize()
       t_ = t.detach()
       w_norm = (t_ * t_).sum().sqrt()
+      if self.track_norms:
+        wnorms.append(w_norm.to("HIP"))
+        wnorm = wnorm + (w_norm*w_norm).to("HIP")
       g_norm = (t.grad * t.grad).sum().sqrt()
-      if self.track_gnorm: gnorm = gnorm + g_norm.to("HIP")
+      if self.track_norms:
+        gnorms.append(g_norm.to("HIP"))
+        gnorm = gnorm + (g_norm*g_norm).to("HIP")

       trust_ratio = (w_norm > 0).where(
         (g_norm > 0).where(
           self.eta * w_norm / (g_norm + self.weight_decay * w_norm + self.eps), 1.0
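The trust_ratio expression above is the LARS layer-wise scaling from the mlcommons reference linked in the class comment: eta * ||w|| / (||g|| + weight_decay * ||w|| + eps) when both norms are positive, falling back to 1.0 otherwise (the .to("HIP") calls appear to gather each norm onto a single device before accumulating). Because track_norms=True logs exactly ||w|| and ||g|| per layer, the effective per-layer scaling can be recomputed offline from the logged wnorms/ and gnorms/ series; a host-side sketch (not part of the commit):

def lars_trust_ratio(w_norm: float, g_norm: float, eta=0.001, weight_decay=1e-4, eps=0.0):
  # mirrors the nested .where() above: fall back to 1.0 if either norm is zero
  if w_norm > 0 and g_norm > 0:
    return eta * w_norm / (g_norm + weight_decay * w_norm + eps)
  return 1.0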
@@ -102,5 +109,5 @@ class LARS(Optimizer):
       g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
       t.assign(t.detach() - g)
     self.realize(self.b)
-    if self.track_gnorm: return gnorm.realize()
+    if self.track_norms: return wnorm.sqrt().realize(), [x.realize() for x in wnorms], gnorm.sqrt().realize(), [x.realize() for x in gnorms]
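A minimal usage sketch of the new return value (not the training script; the tiny Linear model and loss are placeholders, and the hard-coded .to("HIP") in step() means this particular LARS expects a HIP device): with track_norms=True, step() returns the global and per-tensor norms as tensors, which the caller pairs with state-dict keys for logging.

from tinygrad.tensor import Tensor
from tinygrad import nn
from tinygrad.nn.state import get_parameters, get_state_dict

model = nn.Linear(8, 2)                                        # stand-in model, just for the sketch
x, y = Tensor.randn(4, 8), Tensor([0, 1, 0, 1])

opt = LARS(get_parameters(model), lr=0.1, track_norms=True)    # LARS from the hunks above
# build names after the optimizer, assuming the base Optimizer marks adopted parameters
# as requires_grad=True (so this filter matches parameter_keys in the training script)
names = [k for k, v in get_state_dict(model).items() if v.requires_grad]

with Tensor.train():
  opt.zero_grad()
  loss = model(x).sparse_categorical_crossentropy(y)
  loss.backward()
  wnorm, wnorms, gnorm, gnorms = opt.step()

logs = {"train/wnorm": wnorm.numpy(), "train/gnorm": gnorm.numpy(),
        **{f"wnorms/{k}": w.numpy() for k, w in zip(names, wnorms)},
        **{f"gnorms/{k}": g.numpy() for k, g in zip(names, gnorms)}}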