momentum support in SGD

George Hotz
2023-02-11 10:22:37 -08:00
parent 0a2035e015
commit 9152bb5b4a
2 changed files with 13 additions and 6 deletions

@@ -28,15 +28,21 @@ class Optimizer:
       p.realize()
 
 class SGD(Optimizer):
-  def __init__(self, params : List[Tensor], lr=0.001):
+  def __init__(self, params : List[Tensor], lr=0.001, momentum=0, nesterov=False):
     super().__init__(params)
-    self.lr = lr
+    self.lr, self.momentum, self.nesterov = lr, momentum, nesterov
+    self.b = [Tensor.zeros(*t.shape, device=params[0].device, requires_grad=False) for t in self.params] if self.momentum else []
 
+  # https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
   def step(self) -> None:
-    for t in self.params:
+    for i, t in enumerate(self.params):
       assert t.grad is not None
-      t.assign(t.detach() - t.grad * self.lr)
-      self.realize()
+      g = t.grad
+      if self.momentum:
+        self.b[i].assign(self.momentum * self.b[i] + g)
+        g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
+      t.assign(t.detach() - g * self.lr)
+    self.realize(self.b)
 
 class RMSprop(Optimizer):
   def __init__(self, params : List[Tensor], lr=0.001, decay=0.9, eps=1e-8):
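For reference, the per-parameter update the new step() performs follows the PyTorch SGD convention linked in the comment (no dampening or weight decay terms). Below is a minimal scalar sketch of the same rule; the function name and variables are illustrative, not part of tinygrad:

def sgd_momentum_update(p, g, b, lr=0.001, momentum=0.9, nesterov=False):
  # b is the per-parameter momentum buffer (zero before the first step),
  # mirroring the Tensor.zeros buffers held in self.b above
  if momentum:
    b = momentum * b + g                     # self.b[i].assign(self.momentum * self.b[i] + g)
    g = g + momentum * b if nesterov else b  # Nesterov steps along the freshly updated buffer
  return p - lr * g, b                       # t.assign(t.detach() - g * self.lr)

# toy usage: three steps with a constant gradient show the buffer accumulating
p, b = 1.0, 0.0
for g in [0.5, 0.5, 0.5]:
  p, b = sgd_momentum_update(p, g, b)

Since the buffers in self.b are created with requires_grad=False and passed to self.realize(self.b) after the loop, the momentum state is materialized each step alongside the parameters rather than being traced through autograd.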