Mirror of https://github.com/tinygrad/tinygrad.git
add AdamW optimizer (#716)
* add AdamW optimizer
* one liner Adam optimizer
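AdamW is Adam with decoupled weight decay: the decay term is applied to the parameters directly instead of being folded into the gradient. A minimal NumPy sketch of the per-parameter update the patch below implements; the bias-corrected step size `a` is computed above the shown hunk, so its exact expression here is an assumption based on the standard Adam formula:

import numpy as np

def adamw_step(theta, g, m, v, t, lr=0.001, b1=0.9, b2=0.999, eps=1e-8, wd=0.01):
  # Moment estimates, matching the self.m / self.v updates in the diff.
  m = b1 * m + (1.0 - b1) * g
  v = b2 * v + (1.0 - b2) * (g * g)
  # Bias-corrected step size `a`; this line is not part of the shown hunk,
  # so the usual Adam expression is assumed here.
  a = lr * np.sqrt(1.0 - b2**t) / (1.0 - b1**t)
  # Decoupled weight decay: wd acts on the weights directly, not through the
  # gradient, which is what distinguishes AdamW from Adam with L2 regularization.
  theta = theta - a * m / (np.sqrt(v) + eps) - lr * wd * theta
  return theta, m, v

With wd=0.0 the decay term vanishes and this reduces to plain Adam, which is why the patch can re-express Adam as a one-line wrapper around AdamW.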
@@ -59,11 +59,11 @@ class RMSprop(Optimizer):
       t.assign(t.detach() - (g * self.lr).div(self.v[i].sqrt() + self.eps))
     self.realize(self.v)
 
-class Adam(Optimizer):
-  def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
+class AdamW(Optimizer):
+  def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, wd=0.01):
     super().__init__(params)
     # NOTE: self.t is a tensor so Adam can be jitted
-    self.lr, self.b1, self.b2, self.eps, self.t = lr, b1, b2, eps, Tensor([0], requires_grad=False).realize()
+    self.lr, self.b1, self.b2, self.eps, self.wd, self.t = lr, b1, b2, eps, wd, Tensor([0], requires_grad=False).realize()
 
     self.m = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
     self.v = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
@@ -76,9 +76,11 @@ class Adam(Optimizer):
       g = t.grad.realize()
       self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g).realize()
       self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).realize()
-      t.assign(t.detach() - a * self.m[i].div(self.v[i].sqrt() + self.eps))
+      t.assign(t.detach() - a * self.m[i].div(self.v[i].sqrt() + self.eps) - self.lr * self.wd * t.detach())
     self.realize([self.t] + self.m + self.v)
 
+def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8): return AdamW(params, lr, b1, b2, eps, 0.0)
+
 def get_state_dict(obj, prefix:str='') -> Dict[str, Tensor]:
   if isinstance(obj, Tensor): return {prefix.strip('.'):obj}
   if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix)
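A hypothetical usage sketch, not part of the patch: one optimizer step on a single trainable parameter. Only the AdamW/Adam constructors and step() come from the diff itself; the import path and the Tensor calls (uniform, randn, dot, mean, backward) follow the tinygrad API as I understand it at this change and are assumptions.

from tinygrad.tensor import Tensor
from tinygrad.nn.optim import AdamW, Adam   # assumed module path for the patched file

w = Tensor.uniform(2, 1)            # one trainable parameter
w.requires_grad = True
opt = AdamW([w], lr=1e-3, wd=0.01)  # Adam([w], lr=1e-3) is now AdamW(..., wd=0.0)

x, y = Tensor.randn(8, 2), Tensor.randn(8, 1)
err = x.dot(w) - y
loss = (err * err).mean()           # scalar mean-squared error
loss.backward()                     # populates w.grad
opt.step()                          # applies the decoupled-weight-decay update from the diff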