diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index 83d1ab853c..3b394bcc92 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -292,7 +292,7 @@ def train_resnet():
       total_fw_time = sum(eval_times) / len(eval_times)
       tqdm.write(f"eval loss: {total_loss:.2f}, eval time: {total_fw_time:.2f}, eval top 1 acc: {total_top_1:.3f}, eval top 5 acc: {total_top_5:.3f}")
 
-      weight_hists = {f"weights/{k}": wandb.Hist(v.numpy().flatten().tolist()) for k, v in get_state_dict(model)}
+      weight_hists = {f"weights/{k}": wandb.Histogram(v.numpy().flatten().tolist()) for k, v in get_state_dict(model).items() if v.requires_grad}
       wandb.log({"eval/loss": total_loss,
                  "eval/top_1_acc": total_top_1,
                  "eval/top_5_acc": total_top_5,
diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py
index 3239c10dd0..7532eed339 100644
--- a/tinygrad/nn/optim.py
+++ b/tinygrad/nn/optim.py
@@ -108,6 +108,8 @@ class LARS(Optimizer):
         self.b[i].assign(self.momentum * self.b[i] + g)  # NOTE: self.b[i] is zero on the first run, no if required
         g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
       t.assign(t.detach() - g)
-    self.realize(self.b)
-    if self.track_norms: return wnorm.sqrt().realize(), [x.realize() for x in wnorms], gnorm.sqrt().realize(), [x.realize() for x in gnorms]
+    wnorm = wnorm.sqrt()
+    gnorm = gnorm.sqrt()
+    self.realize(self.b + [wnorm, gnorm] + wnorms + gnorms)
+    if self.track_norms: return wnorm, wnorms, gnorm, gnorms
 
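
A minimal sketch, not part of this patch, of how the track_norms return values from the updated LARS.step() might be consumed and logged alongside the eval metrics above. The names `optimizer` and `param_names` are illustrative assumptions about the surrounding training loop, not identifiers from the diff.

# Hedged usage sketch: logs the global and per-parameter norms that
# LARS.step() returns when track_norms is enabled. `optimizer` and
# `param_names` are assumed to exist in the caller's training loop.
import wandb

def log_lars_norms(optimizer, param_names):
  # step() now returns tensors already realized by the single batched
  # self.realize(...) call in the patch, so no extra .realize() is needed here
  wnorm, wnorms, gnorm, gnorms = optimizer.step()
  log = {"train/weight_norm": float(wnorm.numpy()),
         "train/grad_norm": float(gnorm.numpy())}
  # per-parameter norms, keyed by each trainable tensor's name
  for name, wn, gn in zip(param_names, wnorms, gnorms):
    log[f"norms/{name}/weight"] = float(wn.numpy())
    log[f"norms/{name}/grad"] = float(gn.numpy())
  wandb.log(log)

Batching the square roots and per-layer norms into one realize call, rather than realizing each tensor in the return statement, lets the scheduler fuse them with the momentum-buffer updates instead of launching a separate kernel per tensor.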