small
@@ -292,7 +292,7 @@ def train_resnet():
       total_fw_time = sum(eval_times) / len(eval_times)
       tqdm.write(f"eval loss: {total_loss:.2f}, eval time: {total_fw_time:.2f}, eval top 1 acc: {total_top_1:.3f}, eval top 5 acc: {total_top_5:.3f}")

-      weight_hists = {f"weights/{k}": wandb.Hist(v.numpy().flatten().tolist()) for k, v in get_state_dict(model)}
+      weight_hists = {f"weights/{k}": wandb.Histogram(v.numpy().flatten().tolist()) for k, v in get_state_dict(model).items() if v.requires_grad}
       wandb.log({"eval/loss": total_loss,
                  "eval/top_1_acc": total_top_1,
                  "eval/top_5_acc": total_top_5,
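The fix in this hunk is twofold: wandb's histogram type is wandb.Histogram (there is no wandb.Hist), and the old comprehension iterated get_state_dict(model) directly, which yields only the keys and so cannot be unpacked into k, v; the new line iterates .items() and keeps only trainable (requires_grad) tensors. A minimal sketch of the same logging pattern, using a hypothetical toy Linear model in place of the real ResNet and assuming wandb and tinygrad are installed:

import wandb
from tinygrad.nn import Linear
from tinygrad.nn.state import get_state_dict

wandb.init(project="scratch", mode="offline")  # hypothetical throwaway project, logged offline
model = Linear(8, 4)                           # stand-in for the real ResNet
# mimic what the optimizer does to the parameters it is handed
for p in get_state_dict(model).values(): p.requires_grad = True
weight_hists = {f"weights/{k}": wandb.Histogram(v.numpy().flatten().tolist())
                for k, v in get_state_dict(model).items() if v.requires_grad}
wandb.log(weight_hists)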
@@ -108,6 +108,8 @@ class LARS(Optimizer):
         self.b[i].assign(self.momentum * self.b[i] + g)  # NOTE: self.b[i] is zero on the first run, no if required
         g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
       t.assign(t.detach() - g)
-    self.realize(self.b)
-    if self.track_norms: return wnorm.sqrt().realize(), [x.realize() for x in wnorms], gnorm.sqrt().realize(), [x.realize() for x in gnorms]
+    wnorm = wnorm.sqrt()
+    gnorm = gnorm.sqrt()
+    self.realize(self.b + [wnorm, gnorm] + wnorms + gnorms)
+    if self.track_norms: return wnorm, wnorms, gnorm, gnorms
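The momentum buffer update and the optional Nesterov correction are unchanged context here; what changes is that the weight and gradient norms are square-rooted as lazy tensors and realized in one batch together with the momentum buffers, instead of calling .realize() on each value separately in the return statement. The momentum arithmetic itself is the standard update; below is a minimal NumPy sketch of those three lines (a hypothetical standalone function, not tinygrad's Optimizer API; in LARS the learning rate and layer-wise trust ratio are already folded into g before this point):

import numpy as np

def momentum_step(t, g, b, momentum=0.9, nesterov=False):
    b[:] = momentum * b + g                     # self.b[i].assign(self.momentum * self.b[i] + g)
    step = g + momentum * b if nesterov else b  # g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
    t -= step                                   # t.assign(t.detach() - g)
    return t

w, grad, buf = np.ones(4), np.full(4, 0.01), np.zeros(4)
w = momentum_step(w, grad, buf, nesterov=True)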