mirror of https://github.com/tinygrad/tinygrad.git
lars optimizer + tests

* fix skip list!
* use id to compare in skip list
* go back to using set
* Tensor(bool)
* Tensor(bool) is and
* don't lint external/mlperf_resnet
* whitespace
* add external_test_optim to opencl tests
* give mlperf task a name
* mlperf under onnx
* remove track_gnorm
* contiguous instead of realize
* assert momentum and weight decay positive

Co-authored-by: chenyu <chenyu@fastmail.com>
39 lines · 1.4 KiB · Python
from typing import List

from tinygrad import Tensor
from tinygrad.nn.optim import Optimizer

# https://github.com/mlcommons/training/blob/master/image_classification/tensorflow2/lars_optimizer.py
class LARS(Optimizer):
  def __init__(self, params: List[Tensor], lr, momentum=0.9, weight_decay=1e-4, eta=0.001, eps=0.0, skip_list=None, nesterov=False):
    super().__init__(params, lr)
    assert momentum >= 0.0 and weight_decay >= 0.0, "momentum and weight_decay must be non-negative"
    self.momentum, self.weight_decay, self.eta, self.eps, self.nesterov = momentum, weight_decay, eta, eps, nesterov
    # one momentum buffer per parameter, kept out of the autograd graph
    self.b = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
    # params in the skip list are updated with the plain lr (no trust ratio, no weight decay)
    self.skip_list = set(skip_list or [])

  def step(self):
    for i, t in enumerate(self.params):
      assert t.grad is not None
      g = t.grad.contiguous()
      w = t.detach()

      if t not in self.skip_list:
        # layer-wise trust ratio: eta * ||w|| / (||g|| + weight_decay * ||w|| + eps),
        # falling back to 1.0 when either norm is zero
        g_norm = (g * g).sum().sqrt()
        w_norm = (w * w).sum().sqrt()
        trust_ratio = ((w_norm > 0) * (g_norm > 0)).where(
          self.eta * w_norm / (g_norm + self.weight_decay * w_norm + self.eps), 1.0)
        scaled_lr = self.lr * trust_ratio
        # L2 weight decay, folded into the gradient before the momentum update
        g = g + self.weight_decay * t.detach()
      else:
        scaled_lr = self.lr

      g = g * scaled_lr
      if self.momentum:
        # classic (or nesterov) momentum on the lr-scaled gradient
        self.b[i].assign(self.momentum * self.b[i] + g)
        g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
      t.assign(t.detach() - g)
    self.realize(self.b)
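
A minimal usage sketch (an assumption, not something the file ships): LARS plugs into a training step like any other tinygrad Optimizer. The Linear model, random data, squared-error loss, and the lars_optimizer module path below are hypothetical stand-ins; skip_list would typically hold bias and batchnorm parameters, which are updated with the plain lr.

from tinygrad import Tensor
from tinygrad.nn import Linear
from lars_optimizer import LARS  # hypothetical import path for the file above

model = Linear(16, 4)
# bias in skip_list: it bypasses the trust ratio and weight decay
opt = LARS([model.weight, model.bias], lr=0.1, skip_list=[model.bias])

Tensor.training = True
x, y = Tensor.randn(8, 16), Tensor.randn(8, 4)
loss = (model(x) - y).square().mean()
opt.zero_grad()
loss.backward()
opt.step()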