LR scheduler for Stable Diffusion mlperf training (#12201)

* add lr scheduler for stable diffusion training

* add lr scheduler test

* rerun ci

* rerun CI

* use np for testing

* move test to CI path

* remove unneeded copy
This commit is contained in:
hooved
2025-09-30 21:21:08 -04:00
committed by GitHub
parent 9ef319f349
commit 969a1b35ca
2 changed files with 46 additions and 3 deletions

View File

@@ -1,8 +1,9 @@
import math
from tinygrad import dtypes
from tinygrad import dtypes, Tensor
from tinygrad.nn.optim import Optimizer
from extra.lr_scheduler import LR_Scheduler
from typing import Callable
# https://github.com/mlcommons/training/blob/e237206991d10449d9675d95606459a3cb6c21ad/image_classification/tensorflow2/lars_util.py
class PolynomialDecayWithWarmup(LR_Scheduler):
@@ -36,4 +37,24 @@ class CosineAnnealingLRWithWarmup(LR_Scheduler):
def get_lr(self):
warmup_lr = ((self.epoch_counter+1) / self.warmup_steps) * self.base_lr
decay_lr = self.end_lr + 0.5 * (self.base_lr-self.end_lr) * (1 + (((self.epoch_counter+1-self.warmup_steps)/self.decay_steps) * math.pi).cos())
return (self.epoch_counter < self.warmup_steps).where(warmup_lr, decay_lr).cast(self.optimizer.lr.dtype)
return (self.epoch_counter < self.warmup_steps).where(warmup_lr, decay_lr).cast(self.optimizer.lr.dtype)
# Reference: https://github.com/mlcommons/training/blob/64b14a9abc74e08779a175abca7d291f8c957632/stable_diffusion/ldm/lr_scheduler.py, Lines 36-97
class LambdaLinearScheduler:
def __init__(self, warm_up_steps:int, f_min:float, f_max:float, f_start:float, cycle_lengths:int):
self.lr_warm_up_steps, self.f_min, self.f_max, self.f_start, self.cycle_lengths = warm_up_steps, f_min, f_max, f_start, cycle_lengths
def schedule(self, n:Tensor) -> Tensor:
warm_up = (n < self.lr_warm_up_steps)
f_warm_up = (self.f_max - self.f_start) / self.lr_warm_up_steps * n + self.f_start
return warm_up.where(f_warm_up, self.f_min + (self.f_max - self.f_min) * (self.cycle_lengths - n) / (self.cycle_lengths))
# based on torch.optim.lr_scheduler.LambdaLR
class LambdaLR(LR_Scheduler):
def __init__(self, optimizer:Optimizer, base_lr:Tensor, lr_lambda:Callable):
super().__init__(optimizer)
self.base_lr, self.lr_lambda = base_lr, lr_lambda
self.step()
def get_lr(self):
return self.base_lr * self.lr_lambda(self.epoch_counter - 1)

View File

@@ -11,7 +11,7 @@ from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, AdamW
from test.external.mlperf_resnet.lars_optimizer import LARSOptimizer
from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup, CosineAnnealingLRWithWarmup
from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup, CosineAnnealingLRWithWarmup, LambdaLR, LambdaLinearScheduler
from test.external.mlperf_resnet.lars_util import PolynomialDecayWithWarmup as PolynomialDecayWithWarmup_tf
np.random.seed(1337)
@@ -192,5 +192,27 @@ class TestCosineAnnealingLRWithWarmup(unittest.TestCase):
def test_lr_1(self): self._test_lr(3e-4, 8e-5, 10, 20)
def test_lr_llama3(self): self._test_lr(8e-5, 8e-7, 20, 100)
class TestLambdaLRLinearWarmup(unittest.TestCase):
def test_linear_lr_warmup(self):
BS, BASE_LR = 304, 2.5e-7
lr = BS * BASE_LR
# Use a dummy Tensor parameter for optimizer because the lr_scheduler only needs the optimizer's device and lr, the params aren't touched.
optimizer = AdamW([Tensor([1.])])
lambda_lr_callback = LambdaLinearScheduler(1000, 1.0, 1.0, 1e-06, 10000000000000).schedule
lr_scheduler = LambdaLR(optimizer, Tensor(lr, device=optimizer.device), lambda_lr_callback)
lrs = {}
# with above settings, optimizer.lr should warm up to lr over 1000 steps linearly
for i in range(1200):
lr_scheduler.step()
if i in {0, 499, 998, 999, 1000, 1199}:
lrs[i] = optimizer.lr.item()
np.testing.assert_allclose(lr, lrs[999], rtol=0, atol=1e-11)
np.testing.assert_equal(lrs[999], lrs[1000])
np.testing.assert_equal(lrs[999], lrs[1199])
np.testing.assert_allclose(lrs[999] / lrs[0], 1000, rtol=0, atol=1)
np.testing.assert_allclose(lrs[999] / lrs[499], 2, rtol=0, atol=1e-5)
if __name__ == '__main__':
unittest.main()