From c0ea538ba09abf1a8933e058bdbdcbab1283188a Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Fri, 10 Feb 2023 23:06:00 -0600
Subject: [PATCH] Revert "revert t as tensor, constant folding should be done better"

This reverts commit 1d800a94ad37ce383cd938db01d15a0c13e76b0d.
---
 tinygrad/nn/optim.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py
index 35f3fc804d..0c5eed2e37 100644
--- a/tinygrad/nn/optim.py
+++ b/tinygrad/nn/optim.py
@@ -55,7 +55,7 @@ class RMSprop(Optimizer):
 class Adam(Optimizer):
   def __init__(self, params : List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
     super().__init__(params)
-    self.lr, self.b1, self.b2, self.eps, self.t = lr, b1, b2, eps, 0
+    self.lr, self.b1, self.b2, self.eps, self.t = lr, b1, b2, eps, Tensor([0], requires_grad=False).realize()

     self.m = [Tensor.zeros(*t.shape, device=params[0].device, requires_grad=False) for t in self.params]
     self.v = [Tensor.zeros(*t.shape, device=params[0].device, requires_grad=False) for t in self.params]
@@ -68,7 +68,7 @@ class Adam(Optimizer):
       self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
       self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
       t.assign(t.detach() - a * self.m[i].div(self.v[i].sqrt() + self.eps))
-    self.realize(self.m + self.v)
+    self.realize([self.t] + self.m + self.v)

 def get_parameters(obj) -> List[Tensor]:
   parameters : List[Tensor] = []
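
Below is a hypothetical usage sketch, not part of the patch: it only exercises the patched Adam, whose step counter self.t now starts life as a realized Tensor and is realized alongside the moment buffers on every step(). The surrounding API (Tensor.uniform, Adam in tinygrad.nn.optim, Optimizer.zero_grad, the optimizer marking its params as requires_grad) is assumed from tinygrad as of early 2023; adjust names if the installed version differs.

# Hypothetical usage sketch (not part of the patch), assuming the early-2023
# tinygrad API. With the patch applied, opt.t is a realized Tensor and step()
# realizes [self.t] + self.m + self.v instead of rebuilding the step counter
# as a Python int constant on every call.
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import Adam

w = Tensor.uniform(4, 4)   # toy parameter; the optimizer is assumed to mark it requires_grad
opt = Adam([w], lr=1e-3)

for _ in range(3):
  opt.zero_grad()
  loss = (w * w).sum()     # dummy scalar loss
  loss.backward()
  opt.step()               # updates w and realizes the step counter together with m and v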