mirror of https://github.com/tinygrad/tinygrad.git
reduce number of lines (#645)
@@ -6,8 +6,7 @@ class Optimizer:
   def __init__(self, params : List[Tensor]):
     # if it's None, but being put into an optimizer, set it to True
     for x in params:
-      if x.requires_grad is None:
-        x.requires_grad = True
+      if x.requires_grad is None: x.requires_grad = True

     self.params : List[Tensor] = [x for x in params if x.requires_grad]
     self.buffers : List[Tensor] = [x for x in params if not x.requires_grad]   # buffers are still realized
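For illustration, here is a minimal standalone sketch of the __init__ behaviour this hunk compresses: an unspecified requires_grad (None) is treated as trainable once the tensor is handed to an optimizer, and tensors with requires_grad=False become buffers. FakeTensor and split_params are hypothetical stand-ins, not tinygrad code.

from typing import List, Optional

class FakeTensor:
  # stand-in for tinygrad's Tensor, only to show the requires_grad handling
  def __init__(self, name: str, requires_grad: Optional[bool] = None):
    self.name, self.requires_grad = name, requires_grad

def split_params(params: List[FakeTensor]):
  # None means "unspecified": being handed to an optimizer implies trainable
  for x in params:
    if x.requires_grad is None: x.requires_grad = True
  trainable = [x for x in params if x.requires_grad]
  buffers = [x for x in params if not x.requires_grad]  # still tracked, just not updated
  return trainable, buffers

weights = FakeTensor("w")                                # unspecified -> treated as trainable
stats = FakeTensor("running_mean", requires_grad=False)  # explicit buffer
trainable, buffers = split_params([weights, stats])
print([t.name for t in trainable], [t.name for t in buffers])  # ['w'] ['running_mean']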
@@ -20,8 +19,7 @@ class Optimizer:
       param.grad.assign(param.grad.clip(-(amount**2), (amount**2)))

   def zero_grad(self):
-    for param in self.params:
-      param.grad = None
+    for param in self.params: param.grad = None

   def realize(self, extra=None):
     # TODO: corealize
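For context, a hedged usage sketch of where zero_grad sits in a training step, assuming tinygrad's usual Tensor/SGD interface of this era; it is not part of this commit.

from tinygrad.tensor import Tensor
from tinygrad.nn.optim import SGD

w = Tensor.randn(3, 3)    # requires_grad is None here; the optimizer sets it to True (see first hunk)
opt = SGD([w], lr=0.01)

for _ in range(3):
  opt.zero_grad()         # now a one-liner: for param in self.params: param.grad = None
  loss = (w * w).sum()    # toy scalar loss
  loss.backward()
  opt.step()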
@@ -83,9 +81,7 @@ def get_parameters(obj) -> List[Tensor]:
   if isinstance(obj, Tensor):
     parameters.append(obj)
   elif isinstance(obj, (list, tuple)):
-    for x in obj:
-      parameters.extend(get_parameters(x))
+    for x in obj: parameters.extend(get_parameters(x))
   elif hasattr(obj, '__dict__'):
-    for v in obj.__dict__.values():
-      parameters.extend(get_parameters(v))
+    for v in obj.__dict__.values(): parameters.extend(get_parameters(v))
   return parameters
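The recursion in get_parameters is easiest to see in isolation. Below is a standalone sketch that mirrors the traversal from this hunk, using a stand-in Tensor class and a hypothetical two-layer model so it runs without tinygrad.

from typing import List

class Tensor:  # stand-in: anything get_parameters should collect
  def __init__(self, name): self.name = name

def get_parameters(obj) -> List[Tensor]:
  parameters: List[Tensor] = []
  if isinstance(obj, Tensor):
    parameters.append(obj)
  elif isinstance(obj, (list, tuple)):
    for x in obj: parameters.extend(get_parameters(x))    # recurse into containers
  elif hasattr(obj, '__dict__'):
    for v in obj.__dict__.values(): parameters.extend(get_parameters(v))  # recurse into attributes
  return parameters

class Linear:
  def __init__(self): self.weight, self.bias = Tensor("w"), Tensor("b")

class Model:
  def __init__(self): self.layers = [Linear(), Linear()]

print([p.name for p in get_parameters(Model())])  # ['w', 'b', 'w', 'b']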