diff --git a/examples/mlperf/lr_schedulers.py b/examples/mlperf/lr_schedulers.py
index e339ae2d51..b8e9ffa8e5 100644
--- a/examples/mlperf/lr_schedulers.py
+++ b/examples/mlperf/lr_schedulers.py
@@ -1,8 +1,9 @@
 import math
-from tinygrad import dtypes
+from tinygrad import dtypes, Tensor
 from tinygrad.nn.optim import Optimizer
 
 from extra.lr_scheduler import LR_Scheduler
+from typing import Callable
 
 # https://github.com/mlcommons/training/blob/e237206991d10449d9675d95606459a3cb6c21ad/image_classification/tensorflow2/lars_util.py
 class PolynomialDecayWithWarmup(LR_Scheduler):
@@ -36,4 +37,24 @@ class CosineAnnealingLRWithWarmup(LR_Scheduler):
   def get_lr(self):
     warmup_lr = ((self.epoch_counter+1) / self.warmup_steps) * self.base_lr
     decay_lr = self.end_lr + 0.5 * (self.base_lr-self.end_lr) * (1 + (((self.epoch_counter+1-self.warmup_steps)/self.decay_steps) * math.pi).cos())
-    return (self.epoch_counter < self.warmup_steps).where(warmup_lr, decay_lr).cast(self.optimizer.lr.dtype)
\ No newline at end of file
+    return (self.epoch_counter < self.warmup_steps).where(warmup_lr, decay_lr).cast(self.optimizer.lr.dtype)
+
+# Reference: https://github.com/mlcommons/training/blob/64b14a9abc74e08779a175abca7d291f8c957632/stable_diffusion/ldm/lr_scheduler.py, Lines 36-97
+class LambdaLinearScheduler:
+  def __init__(self, warm_up_steps:int, f_min:float, f_max:float, f_start:float, cycle_lengths:int):
+    self.lr_warm_up_steps, self.f_min, self.f_max, self.f_start, self.cycle_lengths = warm_up_steps, f_min, f_max, f_start, cycle_lengths
+
+  def schedule(self, n:Tensor) -> Tensor:
+    warm_up = (n < self.lr_warm_up_steps)
+    f_warm_up = (self.f_max - self.f_start) / self.lr_warm_up_steps * n + self.f_start
+    return warm_up.where(f_warm_up, self.f_min + (self.f_max - self.f_min) * (self.cycle_lengths - n) / (self.cycle_lengths))
+
+# based on torch.optim.lr_scheduler.LambdaLR
+class LambdaLR(LR_Scheduler):
+  def __init__(self, optimizer:Optimizer, base_lr:Tensor, lr_lambda:Callable):
+    super().__init__(optimizer)
+    self.base_lr, self.lr_lambda = base_lr, lr_lambda
+    self.step()
+
+  def get_lr(self):
+    return self.base_lr * self.lr_lambda(self.epoch_counter - 1)
\ No newline at end of file
diff --git a/test/external/external_test_optim.py b/test/external/external_test_optim.py
index 622a0645e2..014601bae7 100644
--- a/test/external/external_test_optim.py
+++ b/test/external/external_test_optim.py
@@ -11,7 +11,7 @@
 from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, AdamW
 
 from test.external.mlperf_resnet.lars_optimizer import LARSOptimizer
-from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup, CosineAnnealingLRWithWarmup
+from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup, CosineAnnealingLRWithWarmup, LambdaLR, LambdaLinearScheduler
 from test.external.mlperf_resnet.lars_util import PolynomialDecayWithWarmup as PolynomialDecayWithWarmup_tf
 
 np.random.seed(1337)
@@ -192,5 +192,27 @@ class TestCosineAnnealingLRWithWarmup(unittest.TestCase):
   def test_lr_1(self): self._test_lr(3e-4, 8e-5, 10, 20)
   def test_lr_llama3(self): self._test_lr(8e-5, 8e-7, 20, 100)
 
+class TestLambdaLRLinearWarmup(unittest.TestCase):
+  def test_linear_lr_warmup(self):
+    BS, BASE_LR = 304, 2.5e-7
+    lr = BS * BASE_LR
+    # Use a dummy Tensor parameter for the optimizer: the lr_scheduler only needs the optimizer's device and lr; the params aren't touched.
+    optimizer = AdamW([Tensor([1.])])
+    lambda_lr_callback = LambdaLinearScheduler(1000, 1.0, 1.0, 1e-06, 10000000000000).schedule
+    lr_scheduler = LambdaLR(optimizer, Tensor(lr, device=optimizer.device), lambda_lr_callback)
+    lrs = {}
+
+    # with the above settings, optimizer.lr should warm up linearly to lr over 1000 steps
+    for i in range(1200):
+      lr_scheduler.step()
+      if i in {0, 499, 998, 999, 1000, 1199}:
+        lrs[i] = optimizer.lr.item()
+
+    np.testing.assert_allclose(lr, lrs[999], rtol=0, atol=1e-11)
+    np.testing.assert_equal(lrs[999], lrs[1000])
+    np.testing.assert_equal(lrs[999], lrs[1199])
+    np.testing.assert_allclose(lrs[999] / lrs[0], 1000, rtol=0, atol=1)
+    np.testing.assert_allclose(lrs[999] / lrs[499], 2, rtol=0, atol=1e-5)
+
 if __name__ == '__main__':
   unittest.main()
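Usage note (not part of the patch): a minimal sketch of how the two new classes compose, mirroring the hyperparameters from test_linear_lr_warmup. The dummy parameter and the bare loop are placeholders for a real model and training step, not code introduced by this diff.

# minimal sketch: wire LambdaLinearScheduler into LambdaLR, as the new test does
from tinygrad import Tensor
from tinygrad.nn.optim import AdamW
from examples.mlperf.lr_schedulers import LambdaLR, LambdaLinearScheduler

optimizer = AdamW([Tensor([1.])])   # placeholder parameter; a real model's params would go here
base_lr = 304 * 2.5e-7              # BS * BASE_LR, as in the test

# linear warmup from f_start=1e-6 to f_max=1.0 over 1000 steps, then flat since f_min == f_max
schedule = LambdaLinearScheduler(1000, 1.0, 1.0, 1e-06, 10000000000000).schedule
lr_scheduler = LambdaLR(optimizer, Tensor(base_lr, device=optimizer.device), schedule)

for step in range(2000):
  # forward/backward and optimizer.step() would go here in real training
  lr_scheduler.step()  # sets optimizer.lr to base_lr * schedule(n) for the current step counter n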