remove tfa dependency: use keras.optimizers.Lamb and tf.raw_ops for LARS (#13555)

Douglas Nyberg authored 2025-12-03 17:48:27 -05:00, committed by GitHub
parent a4c4e48385, commit f5abd38132
3 changed files with 20 additions and 32 deletions
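In short: tfa.optimizers.LAMB becomes the Lamb optimizer that Keras now ships, and the private training_ops momentum kernels become their public tf.raw_ops counterparts. A minimal sketch of the LAMB side (assuming TF 2.19, whose bundled Keras provides Lamb, as the diff below imports it):

    import tensorflow as tf

    # before, via the now-archived tensorflow_addons package:
    #   import tensorflow_addons as tfa
    #   opt = tfa.optimizers.LAMB(learning_rate=1e-3)
    # after, straight from the Keras bundled with TF:
    opt = tf.keras.optimizers.Lamb(learning_rate=1e-3, beta_1=0.9, beta_2=0.999, epsilon=1e-7)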


@@ -449,7 +449,7 @@ jobs:
       with:
         key: onnxoptl
         deps: testing
-        pydeps: "tensorflow==2.15.1 tensorflow_addons"
+        pydeps: "tensorflow==2.19"
         python-version: '3.11'
         opencl: 'true'
     - name: Test ONNX (CL)

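Not part of the diff, but the new pin can be sanity-checked in the CI environment with a couple of lines:

    import tensorflow as tf
    assert tf.__version__.startswith("2.19")
    from tensorflow.keras.optimizers import Lamb  # no tensorflow_addons required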

@@ -2,7 +2,7 @@
 import unittest, math
 import numpy as np
 import tensorflow as tf
-import tensorflow_addons as tfa
+from tensorflow.keras.optimizers import Lamb
 from tensorflow.python.ops import math_ops
 from extra.lr_scheduler import LRSchedulerGroup
@@ -88,6 +88,8 @@ def create_tiny_lars(params, lr, skip_list=False):
   if skip_list: return OptimizerGroup(LARS([params[0]], lr), SGD([params[1]], lr, classic=True, weight_decay=0., momentum=.9))
   return LARS(params, lr)
 def create_tf_lars(lr, skip_list=False): return LARSOptimizer(lr, skip_list=["W"] if skip_list else None)
+def create_tf_lamb(lr=0.001, b1=0.9, b2=0.999, eps=1e-7, weight_decay=0.0):
+  return Lamb(learning_rate=float(lr), beta_1=b1, beta_2=b2, epsilon=eps, weight_decay=weight_decay)
 def create_tiny_polylr(optim, initial_lr, end_lr, train_steps, warmup, power=2, skip_list=False):
   assert power == 2
@@ -112,7 +114,7 @@ class ExternalTestOptim(unittest.TestCase):
         step_tf(tensorflow_optim, steps=steps, kwargs=opts, scheduler=tf_sched, schedopts=schedopts, do_optim=do_optim)):
       np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)
-  def _test_lamb(self, steps, opts, atol, rtol): self._test_optim(LAMB, tfa.optimizers.LAMB, steps, opts, atol, rtol)
+  def _test_lamb(self, steps, opts, atol, rtol): self._test_optim(LAMB, create_tf_lamb, steps, opts, atol, rtol)
   def _test_lars(self, steps, opts, atol, rtol): self._test_optim(create_tiny_lars, create_tf_lars, steps, opts, atol, rtol)
   def _test_lars_polylr(self, steps, opts, schedopts, atol, rtol, do_optim=True):
     self._test_optim(create_tiny_lars, create_tf_lars, steps, opts, atol, rtol,

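For illustration, the parity pattern this test drives can be reproduced standalone (hypothetical shapes, step count, and hyperparameters, not taken from the test file):

    import numpy as np
    import tensorflow as tf

    w = tf.Variable(np.full((4,), 2.0, dtype=np.float32), name="W")
    opt = tf.keras.optimizers.Lamb(learning_rate=0.01, beta_1=0.9, beta_2=0.999, epsilon=1e-7)
    for _ in range(10):
      with tf.GradientTape() as tape:
        loss = tf.reduce_sum(w * w)  # toy quadratic loss
      opt.apply_gradients([(tape.gradient(loss, w), w)])
    # a tinygrad LAMB run over the same steps should then match within atol/rtol:
    # np.testing.assert_allclose(w.numpy(), tiny_result, atol=1e-5, rtol=1e-5)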

@@ -29,7 +29,6 @@ from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.training import training_ops
 from tensorflow.python.ops import state_ops
@@ -147,20 +146,7 @@ class LARSOptimizer(optimizer_v2.OptimizerV2):
     return scaled_lr, grad
 
   def _apply_dense(self, grad, var, apply_state=None):
-    var_device, var_dtype = var.device, var.dtype.base_dtype
-    coefficients = ((apply_state or {}).get((var_device, var_dtype))
-                    or self._fallback_apply_state(var_device, var_dtype))
-    scaled_lr, grad = self.compute_lr(grad, var, coefficients)
-    mom = self.get_slot(var, "momentum")
-    return training_ops.apply_momentum(
-        var,
-        mom,
-        math_ops.cast(1.0, var.dtype.base_dtype),
-        grad * scaled_lr,
-        self.momentum,
-        use_locking=False,
-        use_nesterov=self.use_nesterov)
+    return self._resource_apply_dense(grad, var, apply_state)
 
   def _resource_apply_dense(self, grad, var, apply_state=None):
     var_device, var_dtype = var.device, var.dtype.base_dtype
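_apply_dense now simply delegates to the resource path. For comparison, the public raw op that replaces a training_ops-style dense momentum update can be called directly (a minimal sketch, not from this file; ResourceApplyKerasMomentum computes accum = momentum * accum - lr * grad, then var += accum):

    import tensorflow as tf

    var = tf.Variable([1.0, 2.0])
    accum = tf.Variable([0.0, 0.0])  # the optimizer's momentum slot
    tf.raw_ops.ResourceApplyKerasMomentum(
        var=var.handle, accum=accum.handle, lr=0.1, grad=tf.constant([0.5, 0.5]),
        momentum=0.9, use_locking=False, use_nesterov=False)
    print(var.numpy())  # [0.95 1.95]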
@@ -194,13 +180,13 @@ class LARSOptimizer(optimizer_v2.OptimizerV2):
     coefficients = ((apply_state or {}).get((var_device, var_dtype))
                     or self._fallback_apply_state(var_device, var_dtype))
     mom = self.get_slot(var, "momentum")
-    return training_ops.sparse_apply_momentum(
-        var,
-        mom,
-        coefficients["learning_rate"],
-        grad.values,
-        grad.indices,
-        self.momentum,
+    return tf.raw_ops.SparseApplyMomentum(
+        var=var,
+        accum=mom,
+        lr=coefficients["learning_rate"],
+        grad=grad.values,
+        indices=grad.indices,
+        momentum=self.momentum,
         use_locking=False,
         use_nesterov=self.use_nesterov)
@@ -210,13 +196,13 @@ class LARSOptimizer(optimizer_v2.OptimizerV2):
     coefficients = ((apply_state or {}).get((var_device, var_dtype))
                     or self._fallback_apply_state(var_device, var_dtype))
     mom = self.get_slot(var, "momentum")
-    return training_ops.resource_sparse_apply_keras_momentum(
-        var.handle,
-        mom.handle,
-        coefficients["learning_rate"],
-        grad,
-        indices,
-        self.momentum,
+    return tf.raw_ops.ResourceSparseApplyKerasMomentum(
+        var=var.handle,
+        accum=mom.handle,
+        lr=coefficients["learning_rate"],
+        grad=grad,
+        indices=indices,
+        momentum=self.momentum,
         use_locking=False,
         use_nesterov=self.use_nesterov)
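The sparse resource kernel can be exercised the same way (a minimal sketch with hypothetical values); only the rows named by indices are touched:

    import tensorflow as tf

    var = tf.Variable([[1.0, 1.0], [2.0, 2.0]])
    accum = tf.Variable([[0.0, 0.0], [0.0, 0.0]])
    tf.raw_ops.ResourceSparseApplyKerasMomentum(
        var=var.handle, accum=accum.handle, lr=0.1,
        grad=tf.constant([[0.5, 0.5]]),  # gradient for row 1 only
        indices=tf.constant([1]),
        momentum=0.9, use_locking=False, use_nesterov=False)
    print(var.numpy())  # row 0 untouched, row 1 -> [1.95, 1.95]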