Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-08 22:48:25 -05:00)
remove tfa dependency: use keras.optimizers.Lamb and tf.raw_ops for LARS (#13555)
.github/workflows/test.yml (vendored): 2 lines changed

@@ -449,7 +449,7 @@ jobs:
       with:
         key: onnxoptl
         deps: testing
-        pydeps: "tensorflow==2.15.1 tensorflow_addons"
+        pydeps: "tensorflow==2.19"
         python-version: '3.11'
         opencl: 'true'
     - name: Test ONNX (CL)
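As a quick local check of the new pin (a sketch, not part of the commit; it only assumes tensorflow==2.19 is installed), the Keras-bundled Lamb optimizer should now import without tensorflow_addons:

```python
# Sanity check for the updated pydeps pin (assumes tensorflow==2.19 is installed).
import tensorflow as tf
print(tf.__version__)  # expect a 2.19.x release

# Lamb ships with the Keras 3 bundled in TF 2.19, so tensorflow_addons is not needed.
from tensorflow.keras.optimizers import Lamb
print(Lamb)
```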
test/external/external_test_optim.py (vendored): 6 lines changed

@@ -2,7 +2,7 @@
 import unittest, math
 import numpy as np
 import tensorflow as tf
-import tensorflow_addons as tfa
+from tensorflow.keras.optimizers import Lamb
 from tensorflow.python.ops import math_ops
 from extra.lr_scheduler import LRSchedulerGroup

@@ -88,6 +88,8 @@ def create_tiny_lars(params, lr, skip_list=False):
   if skip_list: return OptimizerGroup(LARS([params[0]], lr), SGD([params[1]], lr, classic=True, weight_decay=0., momentum=.9))
   return LARS(params, lr)
 def create_tf_lars(lr, skip_list=False): return LARSOptimizer(lr, skip_list=["W"] if skip_list else None)
+def create_tf_lamb(lr=0.001, b1=0.9, b2=0.999, eps=1e-7, weight_decay=0.0):
+  return Lamb(learning_rate=float(lr), beta_1=b1, beta_2=b2, epsilon=eps, weight_decay=weight_decay)

 def create_tiny_polylr(optim, initial_lr, end_lr, train_steps, warmup, power=2, skip_list=False):
   assert power == 2

@@ -112,7 +114,7 @@ class ExternalTestOptim(unittest.TestCase):
                          step_tf(tensorflow_optim, steps=steps, kwargs=opts, scheduler=tf_sched, schedopts=schedopts, do_optim=do_optim)):
       np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)

-  def _test_lamb(self, steps, opts, atol, rtol): self._test_optim(LAMB, tfa.optimizers.LAMB, steps, opts, atol, rtol)
+  def _test_lamb(self, steps, opts, atol, rtol): self._test_optim(LAMB, create_tf_lamb, steps, opts, atol, rtol)
   def _test_lars(self, steps, opts, atol, rtol): self._test_optim(create_tiny_lars, create_tf_lars, steps, opts, atol, rtol)
   def _test_lars_polylr(self, steps, opts, schedopts, atol, rtol, do_optim=True):
     self._test_optim(create_tiny_lars, create_tf_lars, steps, opts, atol, rtol,
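For reference, a minimal sketch of how the reference optimizer built by the new create_tf_lamb helper would be stepped outside the test harness. This is illustrative only and not part of the commit; it assumes tensorflow==2.19, where keras.optimizers.Lamb accepts the arguments used above:

```python
# Illustrative sketch: one step of the built-in Keras Lamb optimizer, which now
# plays the reference role that tfa.optimizers.LAMB used to play in the test.
import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Lamb

W = tf.Variable(np.ones((4, 4), dtype=np.float32), name="W")
opt = Lamb(learning_rate=0.01, beta_1=0.9, beta_2=0.999, epsilon=1e-7, weight_decay=0.0)

with tf.GradientTape() as tape:
  loss = tf.reduce_sum(W * W)           # toy quadratic loss
grads = tape.gradient(loss, [W])
opt.apply_gradients(zip(grads, [W]))    # one optimizer step
print(W.numpy().mean())                 # weights have moved away from 1.0
```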
test/external/mlperf_resnet/lars_optimizer.py (vendored): 44 lines changed

@@ -29,7 +29,6 @@ from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.training import training_ops
 from tensorflow.python.ops import state_ops

@@ -147,20 +146,7 @@ class LARSOptimizer(optimizer_v2.OptimizerV2):
     return scaled_lr, grad

   def _apply_dense(self, grad, var, apply_state=None):
-    var_device, var_dtype = var.device, var.dtype.base_dtype
-    coefficients = ((apply_state or {}).get((var_device, var_dtype))
-                    or self._fallback_apply_state(var_device, var_dtype))
-
-    scaled_lr, grad = self.compute_lr(grad, var, coefficients)
-    mom = self.get_slot(var, "momentum")
-    return training_ops.apply_momentum(
-        var,
-        mom,
-        math_ops.cast(1.0, var.dtype.base_dtype),
-        grad * scaled_lr,
-        self.momentum,
-        use_locking=False,
-        use_nesterov=self.use_nesterov)
+    return self._resource_apply_dense(grad, var, apply_state)

   def _resource_apply_dense(self, grad, var, apply_state=None):
     var_device, var_dtype = var.device, var.dtype.base_dtype

@@ -194,13 +180,13 @@ class LARSOptimizer(optimizer_v2.OptimizerV2):
                     or self._fallback_apply_state(var_device, var_dtype))

     mom = self.get_slot(var, "momentum")
-    return training_ops.sparse_apply_momentum(
-        var,
-        mom,
-        coefficients["learning_rate"],
-        grad.values,
-        grad.indices,
-        self.momentum,
+    return tf.raw_ops.SparseApplyMomentum(
+        var=var,
+        accum=mom,
+        lr=coefficients["learning_rate"],
+        grad=grad.values,
+        indices=grad.indices,
+        momentum=self.momentum,
         use_locking=False,
         use_nesterov=self.use_nesterov)

@@ -210,13 +196,13 @@ class LARSOptimizer(optimizer_v2.OptimizerV2):
                     or self._fallback_apply_state(var_device, var_dtype))

     mom = self.get_slot(var, "momentum")
-    return training_ops.resource_sparse_apply_keras_momentum(
-        var.handle,
-        mom.handle,
-        coefficients["learning_rate"],
-        grad,
-        indices,
-        self.momentum,
+    return tf.raw_ops.ResourceSparseApplyKerasMomentum(
+        var=var.handle,
+        accum=mom.handle,
+        lr=coefficients["learning_rate"],
+        grad=grad,
+        indices=indices,
+        momentum=self.momentum,
         use_locking=False,
         use_nesterov=self.use_nesterov)
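For context, a minimal standalone sketch (not part of the commit) of the raw op that now backs the resource-sparse path. It assumes tensorflow==2.19, where tf.raw_ops.ResourceSparseApplyKerasMomentum is exposed; shapes and values are illustrative. The op applies the Keras-style momentum update, roughly accum = momentum * accum - lr * grad followed by var += accum, but only to the selected rows:

```python
# Illustrative only: call the raw op directly on tiny variables to see the
# sparse Keras-momentum update the LARSOptimizer now relies on.
import numpy as np
import tensorflow as tf

var = tf.Variable(np.ones((3, 2), dtype=np.float32))     # parameters
accum = tf.Variable(np.zeros((3, 2), dtype=np.float32))  # momentum slot

tf.raw_ops.ResourceSparseApplyKerasMomentum(
    var=var.handle,
    accum=accum.handle,
    lr=tf.constant(0.1, dtype=tf.float32),
    grad=tf.constant([[1.0, 1.0]], dtype=tf.float32),  # gradient for one row
    indices=tf.constant([1], dtype=tf.int32),          # only row 1 is updated
    momentum=tf.constant(0.9, dtype=tf.float32),
    use_locking=False,
    use_nesterov=False)

print(var.numpy())  # with a zero accumulator, row 1 moved by -lr * grad; rows 0 and 2 unchanged
```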