LR scheduler for Stable Diffusion mlperf training (#12201)
* add lr scheduler for stable diffusion training
* add lr scheduler test
* rerun ci
* rerun CI
* use np for testing
* move test to CI path
* remove unneeded copy
examples/mlperf/lr_schedulers.py
@@ -1,8 +1,9 @@
 import math
-from tinygrad import dtypes
+from tinygrad import dtypes, Tensor
 from tinygrad.nn.optim import Optimizer
 
 from extra.lr_scheduler import LR_Scheduler
+from typing import Callable
 
 # https://github.com/mlcommons/training/blob/e237206991d10449d9675d95606459a3cb6c21ad/image_classification/tensorflow2/lars_util.py
 class PolynomialDecayWithWarmup(LR_Scheduler):
@@ -36,4 +37,24 @@ class CosineAnnealingLRWithWarmup(LR_Scheduler):
   def get_lr(self):
     warmup_lr = ((self.epoch_counter+1) / self.warmup_steps) * self.base_lr
     decay_lr = self.end_lr + 0.5 * (self.base_lr-self.end_lr) * (1 + (((self.epoch_counter+1-self.warmup_steps)/self.decay_steps) * math.pi).cos())
-    return (self.epoch_counter < self.warmup_steps).where(warmup_lr, decay_lr).cast(self.optimizer.lr.dtype)
+    return (self.epoch_counter < self.warmup_steps).where(warmup_lr, decay_lr).cast(self.optimizer.lr.dtype)
+
+# Reference: https://github.com/mlcommons/training/blob/64b14a9abc74e08779a175abca7d291f8c957632/stable_diffusion/ldm/lr_scheduler.py, Lines 36-97
+class LambdaLinearScheduler:
+  def __init__(self, warm_up_steps:int, f_min:float, f_max:float, f_start:float, cycle_lengths:int):
+    self.lr_warm_up_steps, self.f_min, self.f_max, self.f_start, self.cycle_lengths = warm_up_steps, f_min, f_max, f_start, cycle_lengths
+
+  def schedule(self, n:Tensor) -> Tensor:
+    warm_up = (n < self.lr_warm_up_steps)
+    f_warm_up = (self.f_max - self.f_start) / self.lr_warm_up_steps * n + self.f_start
+    return warm_up.where(f_warm_up, self.f_min + (self.f_max - self.f_min) * (self.cycle_lengths - n) / (self.cycle_lengths))
+
+# based on torch.optim.lr_scheduler.LambdaLR
+class LambdaLR(LR_Scheduler):
+  def __init__(self, optimizer:Optimizer, base_lr:Tensor, lr_lambda:Callable):
+    super().__init__(optimizer)
+    self.base_lr, self.lr_lambda = base_lr, lr_lambda
+    self.step()
+
+  def get_lr(self):
+    return self.base_lr * self.lr_lambda(self.epoch_counter - 1)
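For illustration only (these lines are not part of the commit): a minimal sketch of how the two new classes fit together, mirroring the hyperparameters exercised by the test below. The actual Stable Diffusion training entry point may wire them up differently.

# Usage sketch: LambdaLinearScheduler provides the per-step multiplier, LambdaLR applies it to optimizer.lr.
from tinygrad import Tensor
from tinygrad.nn.optim import AdamW
from examples.mlperf.lr_schedulers import LambdaLR, LambdaLinearScheduler

params = [Tensor([1.0])]                 # stand-in parameters; the scheduler only touches optimizer.lr
optimizer = AdamW(params)
peak_lr = 304 * 2.5e-7                   # global batch size * per-sample base LR, as in the test below
schedule = LambdaLinearScheduler(warm_up_steps=1000, f_min=1.0, f_max=1.0,
                                 f_start=1e-6, cycle_lengths=10_000_000_000_000).schedule
lr_scheduler = LambdaLR(optimizer, Tensor(peak_lr, device=optimizer.device), schedule)

for step in range(2000):
  # forward / backward / optimizer.step() would go here
  lr_scheduler.step()                    # linearly warms optimizer.lr up to peak_lr over the first 1000 steps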
test/external/external_test_optim.py
@@ -11,7 +11,7 @@ from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, AdamW
 
 from test.external.mlperf_resnet.lars_optimizer import LARSOptimizer
 
-from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup, CosineAnnealingLRWithWarmup
+from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup, CosineAnnealingLRWithWarmup, LambdaLR, LambdaLinearScheduler
 from test.external.mlperf_resnet.lars_util import PolynomialDecayWithWarmup as PolynomialDecayWithWarmup_tf
 
 np.random.seed(1337)
@@ -192,5 +192,27 @@ class TestCosineAnnealingLRWithWarmup(unittest.TestCase):
   def test_lr_1(self): self._test_lr(3e-4, 8e-5, 10, 20)
   def test_lr_llama3(self): self._test_lr(8e-5, 8e-7, 20, 100)
 
+class TestLambdaLRLinearWarmup(unittest.TestCase):
+  def test_linear_lr_warmup(self):
+    BS, BASE_LR = 304, 2.5e-7
+    lr = BS * BASE_LR
+    # Use a dummy Tensor parameter for optimizer because the lr_scheduler only needs the optimizer's device and lr, the params aren't touched.
+    optimizer = AdamW([Tensor([1.])])
+    lambda_lr_callback = LambdaLinearScheduler(1000, 1.0, 1.0, 1e-06, 10000000000000).schedule
+    lr_scheduler = LambdaLR(optimizer, Tensor(lr, device=optimizer.device), lambda_lr_callback)
+    lrs = {}
+
+    # with above settings, optimizer.lr should warm up to lr over 1000 steps linearly
+    for i in range(1200):
+      lr_scheduler.step()
+      if i in {0, 499, 998, 999, 1000, 1199}:
+        lrs[i] = optimizer.lr.item()
+
+    np.testing.assert_allclose(lr, lrs[999], rtol=0, atol=1e-11)
+    np.testing.assert_equal(lrs[999], lrs[1000])
+    np.testing.assert_equal(lrs[999], lrs[1199])
+    np.testing.assert_allclose(lrs[999] / lrs[0], 1000, rtol=0, atol=1)
+    np.testing.assert_allclose(lrs[999] / lrs[499], 2, rtol=0, atol=1e-5)
+
 if __name__ == '__main__':
   unittest.main()
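A rough check of the assertions above, not part of the commit: LambdaLR.step() runs once in its constructor, so loop iteration i samples the schedule at n = i + 1, and the recorded factors work out roughly as follows in plain Python.

# Back-of-the-envelope version of the schedule used in the test (no tinygrad needed).
warm_up_steps, f_start, f_max, f_min = 1000, 1e-6, 1.0, 1.0

def factor(n):
  # linear warmup for n < warm_up_steps; with f_min == f_max and a huge cycle_lengths,
  # the post-warmup branch is effectively constant at 1.0
  return (f_max - f_start) / warm_up_steps * n + f_start if n < warm_up_steps else f_min

print(factor(1))                   # ~0.001   -> lrs[0]
print(factor(500))                 # ~0.5     -> lrs[499]
print(factor(1000))                # 1.0      -> lrs[999], i.e. the full peak lr
print(factor(1000) / factor(1))    # ~999, within atol=1 of the asserted 1000
print(factor(1000) / factor(500))  # ~2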