Add weight decay to SGD (#883)

* feat: add weight decay to sgd

* fix: fix tests
This commit is contained in:
wozeparrot
2023-06-01 16:13:18 -04:00
committed by GitHub
parent 0e9e0fd718
commit bfea5215e9
2 changed files with 13 additions and 4 deletions

View File

@@ -21,16 +21,16 @@ class Optimizer:
p.realize()
class SGD(Optimizer):
def __init__(self, params: List[Tensor], lr=0.001, momentum=0, nesterov=False):
def __init__(self, params: List[Tensor], lr=0.001, momentum=0, weight_decay=0.0, nesterov=False):
super().__init__(params)
self.lr, self.momentum, self.nesterov = lr, momentum, nesterov
self.lr, self.momentum, self.wd, self.nesterov = lr, momentum, weight_decay, nesterov
self.b = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []
# https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
def step(self) -> None:
for i, t in enumerate(self.params):
assert t.grad is not None
g = t.grad.realize()
g = t.grad.realize() + self.wd * t.detach()
if self.momentum:
self.b[i].assign(self.momentum * self.b[i] + g).realize() # NOTE: self.b[i] is zero on the first run, no if required
g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]