touchups from multibuffer branch (#2958)

This commit is contained in:
George Hotz
2024-01-01 11:33:41 -08:00
committed by GitHub
parent 45247385eb
commit e0ecab3797
2 changed files with 8 additions and 1 deletion

View File

@@ -49,7 +49,7 @@ def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8): return LAM
class LAMB(Optimizer):
def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, wd=0.0, adam=False):
super().__init__(params, lr)
self.b1, self.b2, self.eps, self.wd, self.adam, self.t = b1, b2, eps, wd, adam, Tensor([0], requires_grad=False).realize()
self.b1, self.b2, self.eps, self.wd, self.adam, self.t = b1, b2, eps, wd, adam, Tensor([0], device=self.device, requires_grad=False).realize()
self.m = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
self.v = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]