mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
nn/optim.py compiles now
This commit is contained in:
@@ -1,14 +1,15 @@
|
||||
# sorted in order of increasing complexity
|
||||
from typing import List
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
class Optimizer:
|
||||
def __init__(self, params):
|
||||
def __init__(self, params : List[Tensor]):
|
||||
# if it's None, but being put into an optimizer, set it to True
|
||||
for x in params:
|
||||
if x.requires_grad is None:
|
||||
x.requires_grad = True
|
||||
|
||||
self.params = [x for x in params if x.requires_grad]
|
||||
self.params : List[Tensor] = [x for x in params if x.requires_grad]
|
||||
|
||||
# TODO: this probably shouldn't change the gradients, just the ones used by the optimizer
|
||||
def clipnorm(self, amount=1):
|
||||
@@ -26,7 +27,7 @@ class Optimizer:
|
||||
p.realize()
|
||||
|
||||
class SGD(Optimizer):
|
||||
def __init__(self, params, lr=0.001):
|
||||
def __init__(self, params : List[Tensor], lr=0.001):
|
||||
super().__init__(params)
|
||||
self.lr = lr
|
||||
|
||||
@@ -36,7 +37,7 @@ class SGD(Optimizer):
|
||||
self.realize()
|
||||
|
||||
class RMSprop(Optimizer):
|
||||
def __init__(self, params, lr=0.001, decay=0.9, eps=1e-8):
|
||||
def __init__(self, params : List[Tensor], lr=0.001, decay=0.9, eps=1e-8):
|
||||
super().__init__(params)
|
||||
self.lr, self.decay, self.eps = lr, decay, eps
|
||||
|
||||
@@ -49,7 +50,7 @@ class RMSprop(Optimizer):
|
||||
self.realize(self.v)
|
||||
|
||||
class Adam(Optimizer):
|
||||
def __init__(self, params, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
|
||||
def __init__(self, params : List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
|
||||
super().__init__(params)
|
||||
self.lr, self.b1, self.b2, self.eps, self.t = lr, b1, b2, eps, Tensor([0], requires_grad=False).realize()
|
||||
|
||||
@@ -65,8 +66,8 @@ class Adam(Optimizer):
|
||||
t.assign(t.detach() - a * self.m[i].div(self.v[i].sqrt() + self.eps))
|
||||
self.realize([self.t] + self.m + self.v)
|
||||
|
||||
def get_parameters(obj):
|
||||
parameters = []
|
||||
def get_parameters(obj) -> List[Tensor]:
|
||||
parameters : List[Tensor] = []
|
||||
if isinstance(obj, Tensor):
|
||||
parameters.append(obj)
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
|
||||
Reference in New Issue
Block a user