Training loop for Stable Diffusion mlperf (#12315)

* add diff

* fix edit error

* match master

* point reference to specific commit

* simplify wandb logging

* remove lr test, dehardcode device

* increase stack size limit
This commit is contained in:
hooved
2025-10-03 02:45:38 -04:00
committed by GitHub
parent c7849ac593
commit 1e8945a28c
3 changed files with 189 additions and 1 deletions

View File

@@ -511,6 +511,33 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
# happens with BENCHMARK set
pass
# stable diffusion callbacks to match mlperf ref; declared here because they're pickled
def filter_dataset(sample:dict):
  """Return a new dict holding only the 'npy' and 'txt' entries of `sample` (original key order is kept)."""
  wanted = {'npy', 'txt'}
  return {key: sample[key] for key in sample if key in wanted}
def collate(batch:list[dict]):
  """Merge a list of per-sample dicts into one dict of lists.

  The result always has the keys "npy", "txt" and "__key__"; each value is the list of that field's values
  in sample order. A sample carrying any other key raises KeyError.
  """
  merged = {"npy": [], "txt": [], "__key__": []}
  for item in batch:
    for field in item:
      merged[field].append(item[field])
  return merged
def collate_fn(batch):
  """Identity collation: hand the batch back untouched (batches are already assembled by `collate`)."""
  return batch
# Reference (code): https://github.com/mlcommons/training/blob/2f4a93fb4888180755a8ef55f4b977ef8f60a89e/stable_diffusion/ldm/data/webdatasets.py, Line 55
# Reference (params): https://github.com/mlcommons/training/blob/ab4ae1ca718d7fe62c369710a316dff18768d04b/stable_diffusion/configs/train_01x08x08.yaml, Line 107
def batch_load_train_stable_diffusion(urls:str, BS:int):
  """Generator yielding Stable Diffusion training batches read from webdataset shards.

  Args:
    urls: shard url spec handed to webdataset.WebDataset.
    BS: number of samples per yielded batch (batched with partial=False).

  Yields:
    dict of lists as produced by `collate`: every "npy" entry is an np.ndarray of shape (1,8,64,64) and every
    "txt" entry is a str. These properties are asserted on each batch before it is yielded.
  """
  import webdataset

  # same stages as the mlperf reference: resampled shards -> shuffle -> decode -> keep npy/txt -> batch
  pipeline = (
    webdataset.WebDataset(urls=urls, resampled=True, cache_size=-1, cache_dir=None)
    .shuffle(size=1000)
    .decode()
    .map(filter_dataset)
    .batched(BS, partial=False, collation_fn=collate)
  )
  # batching already happened in the pipeline, so the loader is told not to batch again (batch_size=None)
  loader = webdataset.WebLoader(pipeline, batch_size=None, shuffle=False, num_workers=1, persistent_workers=True, collate_fn=collate_fn)

  for batch in loader:
    assert isinstance(batch, dict) and all(isinstance(key, str) for key in batch.keys()) and all(isinstance(val, list) for val in batch.values())
    assert all(isinstance(moments, np.ndarray) and moments.shape==(1,8,64,64) for moments in batch["npy"])
    assert all(isinstance(text, str) for text in batch["txt"])
    yield batch
# llama3
class BinIdxDataset:

View File

@@ -1493,6 +1493,144 @@ def train_llama3():
safe_save(get_state_dict(model), fn)
break
def train_stable_diffusion():
  """Run the MLPerf Stable Diffusion UNet training loop and periodically save UNet checkpoints.

  Configuration comes from environment variables read with getenv: GPUS, SEED, BS, LEARNING_RATE,
  CKPT_STEP_INTERVAL, CKPTDIR, DATADIR, UNET_CKPTDIR, TOTAL_CKPTS and WANDB.

  Every CKPT_STEP_INTERVAL steps the UNet weights are written to "{UNET_CKPTDIR}/{step}.safetensors".

  Returns:
    The list of checkpoint paths written, as soon as TOTAL_CKPTS checkpoints have been collected. With
    TOTAL_CKPTS=0 (the default) the function only finishes if the dataloader stops yielding, and then
    returns None.
  """
  from extra.models.unet import UNetModel
  from examples.mlperf.dataloader import batch_load_train_stable_diffusion
  from examples.mlperf.lr_schedulers import LambdaLR, LambdaLinearScheduler
  from examples.mlperf.initializers import init_stable_diffusion
  from examples.mlperf.helpers import get_training_state
  import numpy as np

  config = {}
  GPUS = config["GPUS"] = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
  seed = config["seed"] = getenv("SEED", 12345)
  # ** hyperparameters **
  BS = config["BS"] = getenv("BS", 1 * len(GPUS))
  BASE_LR = config["LEARNING_RATE"] = getenv("LEARNING_RATE", 2.5e-7)
  # https://github.com/mlcommons/training_policies/blob/cfa99da479b8d5931f7a3c67612d021dfb47510a/training_rules.adoc#benchmark_specific_rules
  # "Checkpoint must be collected every 512,000 images. CEIL(512000 / global_batch_size) if 512000 is not divisible by GBS."
  # NOTE: It's inferred that "steps" is the unit for the output of the CEIL formula, based on all other cases of CEIL in the rules
  CKPT_STEP_INTERVAL = config["CKPT_STEP_INTERVAL"] = getenv("CKPT_STEP_INTERVAL", math.ceil(512_000 / BS))
  CKPTDIR = config["CKPTDIR"] = Path(getenv("CKPTDIR", "./checkpoints"))
  DATADIR = config["DATADIR"] = Path(getenv("DATADIR", "./datasets"))
  UNET_CKPTDIR = config["UNET_CKPTDIR"] = Path(getenv("UNET_CKPTDIR", "./checkpoints"))
  TOTAL_CKPTS = config["TOTAL_CKPTS"] = getenv("TOTAL_CKPTS", 0)

  # make sure the output directory exists up front, so the wandb run-id file and the first checkpoint
  # (possibly many steps into the run) don't fail with FileNotFoundError
  UNET_CKPTDIR.mkdir(parents=True, exist_ok=True)

  print(f"training on {GPUS}")
  # the learning rate actually used is the base rate scaled by the global batch size
  lr = BS * BASE_LR
  print(f"BS={BS}, BASE_LR={BASE_LR}, lr={lr}")
  print(f"CKPT_STEP_INTERVAL = {CKPT_STEP_INTERVAL}")
  # touch every device once before training starts
  for x in GPUS: Device[x]
  if (WANDB := getenv("WANDB", "")):
    import wandb
    wandb.init(config=config, project="MLPerf-Stable-Diffusion")

  Tensor.manual_seed(seed) # seed for weight initialization
  model, unet, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod = init_stable_diffusion("v2-mlperf-train", CKPTDIR / "sd" / "512-base-ema.ckpt", GPUS)
  optimizer = AdamW(get_parameters(unet))
  lambda_lr_callback = LambdaLinearScheduler(1000, 1.0, 1.0, 1e-06, 10000000000000).schedule
  lr_scheduler = LambdaLR(optimizer, Tensor(lr, dtype=dtypes.float, device=optimizer.device), lambda_lr_callback)

  @TinyJit
  def train_step(mean:Tensor, logvar:Tensor, tokens:Tensor, unet:UNetModel, optimizer:LAMB, lr_scheduler:LambdaLR) -> Tensor:
    """One optimizer step on a batch; returns the (loss, learning rate) tensors moved to CPU."""
    optimizer.zero_grad()

    # random inputs are created on the first device, then sharded along the batch axis with the data
    timestep = Tensor.randint(BS, low=0, high=model.alphas_cumprod.shape[0], dtype=dtypes.int, device=GPUS[0])
    latent_randn = Tensor.randn(*mean.shape, device=GPUS[0])
    noise = Tensor.randn(*mean.shape, device=GPUS[0])
    for t in (mean, logvar, tokens, timestep, latent_randn, noise):
      t.shard_(GPUS, axis=0)

    # draw a latent from the (mean, logvar) moments and scale it by 0.18215
    # NOTE(review): 0.18215 is presumably the SD latent scale factor -- confirm against the reference config
    std = Tensor.exp(0.5 * logvar.clamp(-30.0, 20.0))
    latent = (mean + std * latent_randn) * 0.18215

    # per-sample coefficients for the sampled timesteps, reshaped to broadcast over the latent dims
    sqrt_alphas_cumprod_t = sqrt_alphas_cumprod[timestep].reshape(timestep.shape[0], 1, 1, 1)
    sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[timestep].reshape(timestep.shape[0], 1, 1, 1)
    latent_with_noise = sqrt_alphas_cumprod_t * latent + sqrt_one_minus_alphas_cumprod_t * noise
    # regression target for the UNet output
    v_true = sqrt_alphas_cumprod_t * noise - sqrt_one_minus_alphas_cumprod_t * latent

    context = model.cond_stage_model.embed_tokens(tokens)
    out = unet(latent_with_noise, timestep, context)
    loss = ((out - v_true) ** 2).mean()
    # drop the local references to the intermediates before backward (presumably to lower peak memory)
    del mean, logvar, std, latent, noise, sqrt_alphas_cumprod_t, sqrt_one_minus_alphas_cumprod_t
    del out, v_true, context, latent_randn, tokens, timestep
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    loss, out_lr = loss.detach().to("CPU"), optimizer.lr.to("CPU")
    Tensor.realize(loss, out_lr)
    return loss, out_lr

  # checkpointing takes ~9 minutes without this, and ~1 minute with this
  @TinyJit
  def ckpt_to_cpu():
    """Return the full training state (unet, optimizer, lr_scheduler) as realized CPU tensors."""
    ckpt = get_training_state(unet, optimizer, lr_scheduler)
    # move to CPU first so more GPU bufs aren't created (can trigger OOM)
    for k,v in ckpt.items(): ckpt[k] = v.detach().to("CPU")
    Tensor.realize(*[v for v in ckpt.values()])
    for k,v in ckpt.items(): ckpt[k] = v.cast(v.dtype.base).contiguous()
    Tensor.realize(*[v for v in ckpt.values()])
    return ckpt

  # training loop
  dl = batch_load_train_stable_diffusion(f'{DATADIR}/laion-400m/webdataset-moments-filtered/{{00000..00831}}.tar', BS)
  # for tests
  saved_checkpoints = []

  train_start_time = time.perf_counter()
  # t0 marks the start of the current step, t6 the end of the previous one (dl_time is the gap between them)
  t0 = t6 = time.perf_counter()
  for i, batch in enumerate(dl, start=1):
    loop_time = time.perf_counter() - t0
    t0 = time.perf_counter()
    dl_time = t0 - t6
    GlobalCounters.reset()

    # each "npy" entry stacks mean and logvar along axis 1; concatenate the batch, then split them apart
    mean, logvar = np.split(np.concatenate(batch["npy"], axis=0), 2, axis=1)
    mean, logvar = Tensor(mean, dtype=dtypes.float32, device="CPU"), Tensor(logvar, dtype=dtypes.float32, device="CPU")
    tokens = []
    for text in batch['txt']: tokens += model.cond_stage_model.tokenizer.encode(text, pad_with_zeros=True)
    tokens = Tensor(tokens, dtype=dtypes.int32, device="CPU").reshape(-1, 77)

    t1 = time.perf_counter()
    # step_lr is the Tensor reported by the step; kept distinct from the host-side float `lr` above
    loss, step_lr = train_step(mean, logvar, tokens, unet, optimizer, lr_scheduler)
    loss_item, lr_item = loss.item(), step_lr.item()
    t2 = time.perf_counter()

    if i == 3:
      for _ in range(3): ckpt_to_cpu() # do this at the beginning of run to prevent OOM surprises when checkpointing
      print("BEAM COMPLETE", flush=True) # allows wrapper script to detect BEAM search completion and retry if it failed

    total_train_time = time.perf_counter() - train_start_time
    if WANDB:
      wandb.log({"train/loss": loss_item, "train/lr": lr_item, "train/loop_time_prev": loop_time, "train/dl_time": dl_time, "train/step": i,
                 "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (t2-t1), "train/input_prep_time": t1-t0,
                 "train/train_step_time": t2-t1, "train/total_time": total_train_time})
      # leave a marker file so the checkpoints in this directory can be tied back to the wandb run
      if i == 1 and wandb.run is not None:
        with open(f"{UNET_CKPTDIR}/wandb_run_id_{wandb.run.id}", "w") as f:
          f.write(f"wandb.run.id = {wandb.run.id}")

    if i % CKPT_STEP_INTERVAL == 0:
      # https://github.com/mlcommons/training_policies/blob/cfa99da479b8d5931f7a3c67612d021dfb47510a/training_rules.adoc#benchmark_specific_rules
      # "evaluation is done offline, the time is not counted towards the submission time."
      fn = f"{UNET_CKPTDIR}/{i}.safetensors"
      print(f"saving unet checkpoint at {fn}")
      saved_checkpoints.append(fn)
      # keep only the unet weights and strip just the leading "model." (str.replace would also eat inner occurrences)
      safe_save({k.removeprefix("model."):v for k,v in ckpt_to_cpu().items() if k.startswith("model.")}, fn)
      if TOTAL_CKPTS and i == TOTAL_CKPTS * CKPT_STEP_INTERVAL:
        print(f"ending run after {i} steps ({TOTAL_CKPTS} checkpoints collected)")
        return saved_checkpoints

    t3 = time.perf_counter()
    print(f"""step {i}: {GlobalCounters.global_ops * 1e-9 / (t2-t1):9.2f} GFLOPS, mem_used: {GlobalCounters.mem_used / 1e9:.2f} GB,
loop_time_prev: {loop_time:.2f}, dl_time: {dl_time:.2f}, input_prep_time: {t1-t0:.2f}, train_step_time: {t2-t1:.2f},
t3-t2: {t3-t2:.4f}, loss:{loss_item:.5f}, lr:{lr_item:.3e}, total_train_time:{total_train_time:.2f}
""")
    t6 = time.perf_counter()
if __name__ == "__main__":
multiprocessing.set_start_method('spawn')
@@ -1501,7 +1639,7 @@ if __name__ == "__main__":
else: bench_log_manager = contextlib.nullcontext()
with Tensor.train():
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn,stable_diffusion").split(","):
nm = f"train_{m}"
if nm in globals():
print(f"training {m}")

View File

@@ -0,0 +1,23 @@
import unittest, os
from tempfile import TemporaryDirectory
from tinygrad import Tensor
from tinygrad.helpers import getenv
from examples.mlperf.model_train import train_stable_diffusion
class TestTrain(unittest.TestCase):
  """End-to-end check of the Stable Diffusion MLPerf training loop (needs real data, weights and GPUs)."""

  def test_train_to_ckpt(self):
    """Train for num_steps, save checkpoint, and stop training."""
    # local import: patch.dict restores os.environ when the block exits, so this test's settings
    # (including a UNET_CKPTDIR pointing at a deleted temp dir) don't leak into other tests
    from unittest.mock import patch

    num_steps = 42
    env = {"MODEL": "stable_diffusion", "TOTAL_CKPTS": "1", "CKPT_STEP_INTERVAL": str(num_steps), "GPUS": "8", "BS": "304"}
    # NOTE: update these based on where data/checkpoints are on your system
    if not getenv("DATADIR", ""): env["DATADIR"] = "/raid/datasets/stable_diffusion"
    if not getenv("CKPTDIR", ""): env["CKPTDIR"] = "/raid/weights/stable_diffusion"

    with TemporaryDirectory(prefix="test-train") as tmp, patch.dict(os.environ, {**env, "UNET_CKPTDIR": tmp}):
      with Tensor.train():
        saved_ckpts = train_stable_diffusion()
      expected_ckpt = f"{tmp}/{num_steps}.safetensors"
      # exactly one checkpoint, named after the step it was taken at
      self.assertEqual(saved_ckpts, [expected_ckpt])
if __name__ == "__main__":
  # run this file's tests when it is executed directly as a script
  unittest.main()