David Hou
2024-02-13 13:43:01 -08:00
parent 9d98224585
commit 988f4c1cf3
2 changed files with 5 additions and 3 deletions


@@ -292,7 +292,7 @@ def train_resnet():
       total_fw_time = sum(eval_times) / len(eval_times)
       tqdm.write(f"eval loss: {total_loss:.2f}, eval time: {total_fw_time:.2f}, eval top 1 acc: {total_top_1:.3f}, eval top 5 acc: {total_top_5:.3f}")
-      weight_hists = {f"weights/{k}": wandb.Hist(v.numpy().flatten().tolist()) for k, v in get_state_dict(model)}
+      weight_hists = {f"weights/{k}": wandb.Histogram(v.numpy().flatten().tolist()) for k, v in get_state_dict(model).items() if v.requires_grad}
       wandb.log({"eval/loss": total_loss,
                  "eval/top_1_acc": total_top_1,
                  "eval/top_5_acc": total_top_5,


@@ -108,6 +108,8 @@ class LARS(Optimizer):
       self.b[i].assign(self.momentum * self.b[i] + g)  # NOTE: self.b[i] is zero on the first run, no if required
       g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
       t.assign(t.detach() - g)
-    self.realize(self.b)
-    if self.track_norms: return wnorm.sqrt().realize(), [x.realize() for x in wnorms], gnorm.sqrt().realize(), [x.realize() for x in gnorms]
+    wnorm = wnorm.sqrt()
+    gnorm = gnorm.sqrt()
+    self.realize(self.b + [wnorm, gnorm] + wnorms + gnorms)
+    if self.track_norms: return wnorm, wnorms, gnorm, gnorms
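
The step now keeps the `.sqrt()` lazy and folds the norm tensors into the same realize call as the momentum buffers, so the optimizer update and the norm computation are scheduled in one batch instead of forcing each norm with its own `.realize()` in the return expression. Below is a hedged sketch of the batching idea, not the LARS class itself, assuming the tinygrad API where `Tensor.realize` accepts additional lazy tensors; the values are illustrative stand-ins:

```python
from tinygrad import Tensor

buffers = [Tensor.zeros(4), Tensor.zeros(4)]   # stand-ins for self.b
wnorm = Tensor([4.0]).sum().sqrt()             # still lazy at this point
gnorm = Tensor([9.0]).sum().sqrt()

# One batched realize over everything, mirroring
# self.realize(self.b + [wnorm, gnorm] + wnorms + gnorms):
# the scheduler sees all pending computation at once instead of
# dispatching a separate round per .realize() call.
Tensor.realize(*(buffers + [wnorm, gnorm]))
print(wnorm.item(), gnorm.item())              # 2.0 3.0
```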