diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py index 35f3fc804d..0c5eed2e37 100644 --- a/tinygrad/nn/optim.py +++ b/tinygrad/nn/optim.py @@ -55,7 +55,7 @@ class RMSprop(Optimizer): class Adam(Optimizer): def __init__(self, params : List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8): super().__init__(params) - self.lr, self.b1, self.b2, self.eps, self.t = lr, b1, b2, eps, 0 + self.lr, self.b1, self.b2, self.eps, self.t = lr, b1, b2, eps, Tensor([0], requires_grad=False).realize() self.m = [Tensor.zeros(*t.shape, device=params[0].device, requires_grad=False) for t in self.params] self.v = [Tensor.zeros(*t.shape, device=params[0].device, requires_grad=False) for t in self.params] @@ -68,7 +68,7 @@ class Adam(Optimizer): self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad) self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad)) t.assign(t.detach() - a * self.m[i].div(self.v[i].sqrt() + self.eps)) - self.realize(self.m + self.v) + self.realize([self.t] + self.m + self.v) def get_parameters(obj) -> List[Tensor]: parameters : List[Tensor] = []