tinygrad/examples/mlperf/optimizers.py
David Hou 0afaf70d57 lars optimizer + tests (#3631)
* lars optimizer + tests

* fix skip list!

* use id to compare in skip list

* go back to using set

* Tensor(bool) * Tensor(bool) is and

* don't lint external/mlperf_resnet

* whitespace

* add external_test_optim to opencl tests

* give mlperf task a name

* mlperf under onnx

* remove track_gnorm

* contiguous instead of realize

* assert momentum and weight decay positive

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
2024-03-06 18:11:01 -05:00
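
The bullet "Tensor(bool) * Tensor(bool) is and" refers to the trust-ratio guard in the code below: multiplying two boolean tensors elementwise behaves as a logical AND, which `.where` then uses as its condition. A minimal standalone sketch of that behavior (values are made up for illustration, not taken from the file):

```python
from tinygrad import Tensor

a = Tensor([1.0, 0.0]) > 0   # [True, False]
b = Tensor([2.0, 3.0]) > 0   # [True, True]
mask = a * b                 # elementwise product of bools == logical AND -> [True, False]
out = mask.where(Tensor([10.0, 10.0]), 1.0)
print(out.numpy())           # [10.  1.]
```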

39 lines
1.4 KiB
Python

from typing import List
from tinygrad import Tensor
from tinygrad.nn.optim import Optimizer
# https://github.com/mlcommons/training/blob/master/image_classification/tensorflow2/lars_optimizer.py
class LARS(Optimizer):
  def __init__(self, params: List[Tensor], lr, momentum=0.9, weight_decay=1e-4, eta=0.001, eps=0.0, skip_list=None, nesterov=False):
    super().__init__(params, lr)
    assert momentum >= 0.0 and weight_decay >= 0.0
    self.momentum, self.weight_decay, self.eta, self.eps, self.nesterov = momentum, weight_decay, eta, eps, nesterov
    self.b = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]  # momentum buffers
    # store ids: Tensor.__eq__ is elementwise, so membership is checked by identity instead
    self.skip_list = set(id(x) for x in (skip_list or []))

  def step(self):
    for i, t in enumerate(self.params):
      assert t.grad is not None
      g = t.grad.contiguous()
      w = t.detach()
      if id(t) not in self.skip_list:
        # layer-wise trust ratio: eta * ||w|| / (||g|| + weight_decay * ||w|| + eps),
        # falling back to 1.0 when either norm is zero (bool * bool acts as logical and)
        g_norm = (g * g).sum().sqrt()
        w_norm = (w * w).sum().sqrt()
        trust_ratio = ((w_norm > 0) * (g_norm > 0)).where(
          self.eta * w_norm / (g_norm + self.weight_decay * w_norm + self.eps), 1.0)
        scaled_lr = self.lr * trust_ratio
        g = g + self.weight_decay * w
      else:
        # skipped params (e.g. biases, batchnorm) get the plain lr and no weight decay
        scaled_lr = self.lr
      g = g * scaled_lr
      if self.momentum:
        self.b[i].assign(self.momentum * self.b[i] + g)
        g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
      t.assign(t.detach() - g)
    self.realize(self.b)
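
A minimal usage sketch, not part of the file: it assumes the repo-root import path `examples.mlperf.optimizers` and uses a made-up rule (one-dimensional tensors, i.e. biases) to build `skip_list`; the MLPerf reference setup typically keeps biases and batchnorm parameters out of the LARS scaling.

```python
from tinygrad import Tensor
from tinygrad.nn import Linear
from tinygrad.nn.state import get_parameters
from examples.mlperf.optimizers import LARS  # assumed import path, run from the repo root

model = Linear(16, 4)
params = get_parameters(model)
skip = [p for p in params if len(p.shape) == 1]  # e.g. biases: plain lr, no trust ratio / weight decay
opt = LARS(params, lr=0.1, momentum=0.9, weight_decay=1e-4, skip_list=skip)

x, y = Tensor.rand(8, 16), Tensor.rand(8, 4)
with Tensor.train():
  loss = ((model(x) - y) ** 2).mean()
  opt.zero_grad()
  loss.backward()
  opt.step()
```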