mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 02:21:40 -05:00
touchups from multibuffer branch (#2958)
This commit is contained in:
@@ -49,7 +49,7 @@ def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8): return LAM
|
||||
class LAMB(Optimizer):
|
||||
def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, wd=0.0, adam=False):
|
||||
super().__init__(params, lr)
|
||||
self.b1, self.b2, self.eps, self.wd, self.adam, self.t = b1, b2, eps, wd, adam, Tensor([0], requires_grad=False).realize()
|
||||
self.b1, self.b2, self.eps, self.wd, self.adam, self.t = b1, b2, eps, wd, adam, Tensor([0], device=self.device, requires_grad=False).realize()
|
||||
self.m = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
|
||||
self.v = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user