From 1e8945a28c9a6c53c7a6560035e4e50f46ffed8d Mon Sep 17 00:00:00 2001
From: hooved <172129504+hooved@users.noreply.github.com>
Date: Fri, 3 Oct 2025 02:45:38 -0400
Subject: [PATCH] Training loop for Stable Diffusion mlperf (#12315)

* add diff

* fix edit error

* match master

* point reference to specific commit

* simplify wandb logging

* remove lr test, dehardcode device

* increase stack size limit
---
 examples/mlperf/dataloader.py  |  27 ++++
 examples/mlperf/model_train.py | 140 +++++++++++++++++-
 .../external_test_train.py     |  23 +++
 3 files changed, 189 insertions(+), 1 deletion(-)
 create mode 100644 test/external/mlperf_stable_diffusion/external_test_train.py

diff --git a/examples/mlperf/dataloader.py b/examples/mlperf/dataloader.py
index 09fb191539..67eae92ce7 100644
--- a/examples/mlperf/dataloader.py
+++ b/examples/mlperf/dataloader.py
@@ -511,6 +511,33 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
       # happens with BENCHMARK set
       pass
 
+# stable diffusion callbacks to match mlperf ref; declared here because they're pickled
+def filter_dataset(sample:dict): return {k:v for k,v in sample.items() if k in {'npy', 'txt'}}
+def collate(batch:list[dict]):
+  ret = {"npy": [], "txt": [], "__key__": []}
+  for sample in batch:
+    for k,v in sample.items():
+      ret[k].append(v)
+  return ret
+def collate_fn(batch): return batch
+
+# Reference (code): https://github.com/mlcommons/training/blob/2f4a93fb4888180755a8ef55f4b977ef8f60a89e/stable_diffusion/ldm/data/webdatasets.py, Line 55
+# Reference (params): https://github.com/mlcommons/training/blob/ab4ae1ca718d7fe62c369710a316dff18768d04b/stable_diffusion/configs/train_01x08x08.yaml, Line 107
+def batch_load_train_stable_diffusion(urls:str, BS:int):
+  import webdataset
+  dataset = webdataset.WebDataset(urls=urls, resampled=True, cache_size=-1, cache_dir=None)
+  dataset = dataset.shuffle(size=1000)
+  dataset = dataset.decode()
+  dataset = dataset.map(filter_dataset)
+  dataset = dataset.batched(BS, partial=False, collation_fn=collate)
+  dataset = webdataset.WebLoader(dataset, batch_size=None, shuffle=False, num_workers=1, persistent_workers=True, collate_fn=collate_fn)
+
+  for x in dataset:
+    assert isinstance(x, dict) and all(isinstance(k, str) for k in x.keys()) and all(isinstance(v, list) for v in x.values())
+    assert all(isinstance(moment_mean_logvar, np.ndarray) and moment_mean_logvar.shape==(1,8,64,64) for moment_mean_logvar in x["npy"])
+    assert all(isinstance(caption, str) for caption in x["txt"])
+    yield x
+
 # llama3
 
 class BinIdxDataset:
diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index c2e961e9cf..4b333918e3 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -1493,6 +1493,144 @@ def train_llama3():
         safe_save(get_state_dict(model), fn)
         break
 
+def train_stable_diffusion():
+  from extra.models.unet import UNetModel
+  from examples.mlperf.dataloader import batch_load_train_stable_diffusion
+  from examples.mlperf.lr_schedulers import LambdaLR, LambdaLinearScheduler
+  from examples.mlperf.initializers import init_stable_diffusion
+  from examples.mlperf.helpers import get_training_state
+  import numpy as np
+
+  config = {}
+  GPUS = config["GPUS"] = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
+  seed = config["seed"] = getenv("SEED", 12345)
+  # ** hyperparameters **
+  BS = config["BS"] = getenv("BS", 1 * len(GPUS))
+  BASE_LR = config["LEARNING_RATE"] = getenv("LEARNING_RATE", 2.5e-7)
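+  # NOTE: BASE_LR acts as a per-sample rate; the effective lr below is scaled linearly with the global batch size (lr = BS * BASE_LR)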
+  # https://github.com/mlcommons/training_policies/blob/cfa99da479b8d5931f7a3c67612d021dfb47510a/training_rules.adoc#benchmark_specific_rules
+  # "Checkpoint must be collected every 512,000 images. CEIL(512000 / global_batch_size) if 512000 is not divisible by GBS."
+  # NOTE: It's inferred that "steps" is the unit for the output of the CEIL formula, based on all other cases of CEIL in the rules
+  CKPT_STEP_INTERVAL = config["CKPT_STEP_INTERVAL"] = getenv("CKPT_STEP_INTERVAL", math.ceil(512_000 / BS))
+  CKPTDIR = config["CKPTDIR"] = Path(getenv("CKPTDIR", "./checkpoints"))
+  DATADIR = config["DATADIR"] = Path(getenv("DATADIR", "./datasets"))
+  UNET_CKPTDIR = config["UNET_CKPTDIR"] = Path(getenv("UNET_CKPTDIR", "./checkpoints"))
+  TOTAL_CKPTS = config["TOTAL_CKPTS"] = getenv("TOTAL_CKPTS", 0)
+
+  print(f"training on {GPUS}")
+  lr = BS * BASE_LR
+  print(f"BS={BS}, BASE_LR={BASE_LR}, lr={lr}")
+  print(f"CKPT_STEP_INTERVAL = {CKPT_STEP_INTERVAL}")
+  for x in GPUS: Device[x]
+  if (WANDB := getenv("WANDB", "")):
+    import wandb
+    wandb.init(config=config, project="MLPerf-Stable-Diffusion")
+
+  Tensor.manual_seed(seed)  # seed for weight initialization
+  model, unet, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod = init_stable_diffusion("v2-mlperf-train", CKPTDIR / "sd" / "512-base-ema.ckpt", GPUS)
+
+  optimizer = AdamW(get_parameters(unet))
+  lambda_lr_callback = LambdaLinearScheduler(1000, 1.0, 1.0, 1e-06, 10000000000000).schedule
+  lr_scheduler = LambdaLR(optimizer, Tensor(lr, dtype=dtypes.float, device=optimizer.device), lambda_lr_callback)
+
+  @TinyJit
+  def train_step(mean:Tensor, logvar:Tensor, tokens:Tensor, unet:UNetModel, optimizer:LAMB, lr_scheduler:LambdaLR) -> tuple[Tensor, Tensor]:
+    optimizer.zero_grad()
+
+    timestep = Tensor.randint(BS, low=0, high=model.alphas_cumprod.shape[0], dtype=dtypes.int, device=GPUS[0])
+    latent_randn = Tensor.randn(*mean.shape, device=GPUS[0])
+    noise = Tensor.randn(*mean.shape, device=GPUS[0])
+    for t in (mean, logvar, tokens, timestep, latent_randn, noise):
+      t.shard_(GPUS, axis=0)
+
+    # sample the latent from the VAE posterior, then apply the SD latent scale factor
+    std = Tensor.exp(0.5 * logvar.clamp(-30.0, 20.0))
+    latent = (mean + std * latent_randn) * 0.18215
+
+    # forward diffusion: noise the latent at the sampled timestep, and form the v-prediction target
+    sqrt_alphas_cumprod_t = sqrt_alphas_cumprod[timestep].reshape(timestep.shape[0], 1, 1, 1)
+    sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[timestep].reshape(timestep.shape[0], 1, 1, 1)
+    latent_with_noise = sqrt_alphas_cumprod_t * latent + sqrt_one_minus_alphas_cumprod_t * noise
+    v_true = sqrt_alphas_cumprod_t * noise - sqrt_one_minus_alphas_cumprod_t * latent
+
+    context = model.cond_stage_model.embed_tokens(tokens)
+
+    out = unet(latent_with_noise, timestep, context)
+    loss = ((out - v_true) ** 2).mean()
+    del mean, logvar, std, latent, noise, sqrt_alphas_cumprod_t, sqrt_one_minus_alphas_cumprod_t
+    del out, v_true, context, latent_randn, tokens, timestep
+    loss.backward()
+
+    optimizer.step()
+    lr_scheduler.step()
+    loss, out_lr = loss.detach().to("CPU"), optimizer.lr.to("CPU")
+    Tensor.realize(loss, out_lr)
+    return loss, out_lr
+
+  # checkpointing takes ~9 minutes without this, and ~1 minute with this
+  @TinyJit
+  def ckpt_to_cpu():
+    ckpt = get_training_state(unet, optimizer, lr_scheduler)
+    # move to CPU first so more GPU bufs aren't created (can trigger OOM)
+    for k,v in ckpt.items(): ckpt[k] = v.detach().to("CPU")
+    Tensor.realize(*[v for v in ckpt.values()])
+    for k,v in ckpt.items(): ckpt[k] = v.cast(v.dtype.base).contiguous()
+    Tensor.realize(*[v for v in ckpt.values()])
+    return ckpt
+
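+  # each "npy" entry from the dataloader is a (1,8,64,64) array of precomputed VAE moments; after concatenation,
+  # np.split below separates it into a (BS,4,64,64) latent mean and logvar, and "txt" captions are tokenized to length-77 sequences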
+  # training loop
+  dl = batch_load_train_stable_diffusion(f'{DATADIR}/laion-400m/webdataset-moments-filtered/{{00000..00831}}.tar', BS)
+  # for tests
+  saved_checkpoints = []
+
+  train_start_time = time.perf_counter()
+  t0 = t6 = time.perf_counter()
+  for i, batch in enumerate(dl, start=1):
+    loop_time = time.perf_counter() - t0
+    t0 = time.perf_counter()
+    dl_time = t0 - t6
+    GlobalCounters.reset()
+
+    mean, logvar = np.split(np.concatenate(batch["npy"], axis=0), 2, axis=1)
+    mean, logvar = Tensor(mean, dtype=dtypes.float32, device="CPU"), Tensor(logvar, dtype=dtypes.float32, device="CPU")
+    tokens = []
+    for text in batch['txt']: tokens += model.cond_stage_model.tokenizer.encode(text, pad_with_zeros=True)
+    tokens = Tensor(tokens, dtype=dtypes.int32, device="CPU").reshape(-1, 77)
+
+    t1 = time.perf_counter()
+    loss, lr = train_step(mean, logvar, tokens, unet, optimizer, lr_scheduler)
+    loss_item, lr_item = loss.item(), lr.item()
+    t2 = time.perf_counter()
+
+    if i == 3:
+      for _ in range(3): ckpt_to_cpu()  # do this at the beginning of the run to prevent OOM surprises when checkpointing
+      print("BEAM COMPLETE", flush=True)  # allows a wrapper script to detect BEAM search completion and retry if it failed
+
+    total_train_time = time.perf_counter() - train_start_time
+    if WANDB:
+      wandb.log({"train/loss": loss_item, "train/lr": lr_item, "train/loop_time_prev": loop_time, "train/dl_time": dl_time, "train/step": i,
+                 "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (t2-t1), "train/input_prep_time": t1-t0,
+                 "train/train_step_time": t2-t1, "train/total_time": total_train_time})
+
+      if i == 1 and wandb.run is not None:
+        with open(f"{UNET_CKPTDIR}/wandb_run_id_{wandb.run.id}", "w") as f:
+          f.write(f"wandb.run.id = {wandb.run.id}")
+
+    if i % CKPT_STEP_INTERVAL == 0:
+      # https://github.com/mlcommons/training_policies/blob/cfa99da479b8d5931f7a3c67612d021dfb47510a/training_rules.adoc#benchmark_specific_rules
+      # "evaluation is done offline, the time is not counted towards the submission time."
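+      # only the unet weights are persisted: the training-state dict is filtered to keys with the "model." prefix,
+      # so optimizer and lr_scheduler state are dropped from the saved checkpoint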
+      fn = f"{UNET_CKPTDIR}/{i}.safetensors"
+      print(f"saving unet checkpoint at {fn}")
+      saved_checkpoints.append(fn)
+      safe_save({k.replace("model.", ""):v for k,v in ckpt_to_cpu().items() if k.startswith("model.")}, fn)
+      if TOTAL_CKPTS and i == TOTAL_CKPTS * CKPT_STEP_INTERVAL:
+        print(f"ending run after {i} steps ({TOTAL_CKPTS} checkpoints collected)")
+        return saved_checkpoints
+
+    t3 = time.perf_counter()
+    print(f"""step {i}: {GlobalCounters.global_ops * 1e-9 / (t2-t1):9.2f} GFLOPS, mem_used: {GlobalCounters.mem_used / 1e9:.2f} GB,
+      loop_time_prev: {loop_time:.2f}, dl_time: {dl_time:.2f}, input_prep_time: {t1-t0:.2f}, train_step_time: {t2-t1:.2f},
+      t3-t2: {t3-t2:.4f}, loss:{loss_item:.5f}, lr:{lr_item:.3e}, total_train_time:{total_train_time:.2f}
+    """)
+    t6 = time.perf_counter()
+
 if __name__ == "__main__":
   multiprocessing.set_start_method('spawn')
@@ -1501,7 +1639,7 @@ if __name__ == "__main__":
   else: bench_log_manager = contextlib.nullcontext()
   with Tensor.train():
-    for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
+    for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn,stable_diffusion").split(","):
       nm = f"train_{m}"
       if nm in globals():
         print(f"training {m}")
diff --git a/test/external/mlperf_stable_diffusion/external_test_train.py b/test/external/mlperf_stable_diffusion/external_test_train.py
new file mode 100644
index 0000000000..009e442da1
--- /dev/null
+++ b/test/external/mlperf_stable_diffusion/external_test_train.py
@@ -0,0 +1,23 @@
+import unittest, os
+from tempfile import TemporaryDirectory
+from tinygrad import Tensor
+from tinygrad.helpers import getenv
+from examples.mlperf.model_train import train_stable_diffusion
+
+class TestTrain(unittest.TestCase):
+  def test_train_to_ckpt(self):
+    # train for num_steps, save checkpoint, and stop training
+    num_steps = 42
+    os.environ.update({"MODEL": "stable_diffusion", "TOTAL_CKPTS": "1", "CKPT_STEP_INTERVAL": str(num_steps), "GPUS": "8", "BS": "304"})
+    # NOTE: update these based on where data/checkpoints are on your system
+    if not getenv("DATADIR", ""): os.environ["DATADIR"] = "/raid/datasets/stable_diffusion"
+    if not getenv("CKPTDIR", ""): os.environ["CKPTDIR"] = "/raid/weights/stable_diffusion"
+    with TemporaryDirectory(prefix="test-train") as tmp:
+      os.environ["UNET_CKPTDIR"] = tmp
+      with Tensor.train():
+        saved_ckpts = train_stable_diffusion()
+      expected_ckpt = f"{tmp}/{num_steps}.safetensors"
+      assert len(saved_ckpts) == 1 and saved_ckpts[0] == expected_ckpt
+
+if __name__ == "__main__":
+  unittest.main()
\ No newline at end of file
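
Note on the schedule: LambdaLinearScheduler(1000, 1.0, 1.0, 1e-06, 10000000000000) mirrors the warm_up_steps/f_min/f_max/f_start/cycle_lengths params of the referenced mlperf config. Below is a minimal sketch of the lr multiplier this implies, assuming the reference's linear-warmup semantics; lr_multiplier is a hypothetical name for illustration, not the examples/mlperf/lr_schedulers API.

def lr_multiplier(step:int, warm_up_steps:int=1000, f_min:float=1.0, f_max:float=1.0, f_start:float=1e-6) -> float:
  # ramp linearly from f_start to f_max over the warmup window
  if step < warm_up_steps: return f_start + (f_max - f_start) * step / warm_up_steps
  # past warmup, with f_min == f_max == 1.0 the multiplier stays constant at 1.0
  return f_min

At a given step, the optimizer lr would then be lr_multiplier(step) * BS * BASE_LR, which is what LambdaLR applies through the schedule callback.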