Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-07 22:23:55 -05:00
Training loop for Stable Diffusion mlperf (#12315)
* add diff
* fix edit error
* match master
* point reference to specific commit
* simplify wandb logging
* remove lr test, dehardcode device
* increase stack size limit
@@ -511,6 +511,33 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
    # happens with BENCHMARK set
    pass

# stable diffusion callbacks to match mlperf ref; declared here because they're pickled
def filter_dataset(sample:dict): return {k:v for k,v in sample.items() if k in {'npy', 'txt'}}
def collate(batch:list[dict]):
  ret = {"npy": [], "txt": [], "__key__": []}
  for sample in batch:
    for k,v in sample.items():
      ret[k].append(v)
  return ret
def collate_fn(batch): return batch

# Reference (code): https://github.com/mlcommons/training/blob/2f4a93fb4888180755a8ef55f4b977ef8f60a89e/stable_diffusion/ldm/data/webdatasets.py, Line 55
# Reference (params): https://github.com/mlcommons/training/blob/ab4ae1ca718d7fe62c369710a316dff18768d04b/stable_diffusion/configs/train_01x08x08.yaml, Line 107
def batch_load_train_stable_diffusion(urls:str, BS:int):
  import webdataset
  dataset = webdataset.WebDataset(urls=urls, resampled=True, cache_size=-1, cache_dir=None)
  dataset = dataset.shuffle(size=1000)
  dataset = dataset.decode()
  dataset = dataset.map(filter_dataset)
  dataset = dataset.batched(BS, partial=False, collation_fn=collate)
  dataset = webdataset.WebLoader(dataset, batch_size=None, shuffle=False, num_workers=1, persistent_workers=True, collate_fn=collate_fn)

  for x in dataset:
    assert isinstance(x, dict) and all(isinstance(k, str) for k in x.keys()) and all(isinstance(v, list) for v in x.values())
    assert all(isinstance(moment_mean_logvar, np.ndarray) and moment_mean_logvar.shape==(1,8,64,64) for moment_mean_logvar in x["npy"])
    assert all(isinstance(caption, str) for caption in x["txt"])
    yield x

# llama3

class BinIdxDataset:
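A minimal usage sketch of the generator above (illustrative, not part of the diff): the shard pattern follows the default DATADIR layout used by the training loop below, BS=8 is an arbitrary choice, and the webdataset package must be installed.

from examples.mlperf.dataloader import batch_load_train_stable_diffusion

# each yielded batch is a dict of plain python lists: BS latent-moment arrays and BS caption strings
dl = batch_load_train_stable_diffusion("./datasets/laion-400m/webdataset-moments-filtered/{00000..00831}.tar", BS=8)
batch = next(dl)
assert len(batch["npy"]) == 8 and batch["npy"][0].shape == (1, 8, 64, 64)
assert all(isinstance(caption, str) for caption in batch["txt"])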
@@ -1493,6 +1493,144 @@ def train_llama3():
    safe_save(get_state_dict(model), fn)
    break

def train_stable_diffusion():
  from extra.models.unet import UNetModel
  from examples.mlperf.dataloader import batch_load_train_stable_diffusion
  from examples.mlperf.lr_schedulers import LambdaLR, LambdaLinearScheduler
  from examples.mlperf.initializers import init_stable_diffusion
  from examples.mlperf.helpers import get_training_state
  import numpy as np

  config = {}
  GPUS = config["GPUS"] = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
  seed = config["seed"] = getenv("SEED", 12345)
  # ** hyperparameters **
  BS = config["BS"] = getenv("BS", 1 * len(GPUS))
  BASE_LR = config["LEARNING_RATE"] = getenv("LEARNING_RATE", 2.5e-7)
  # https://github.com/mlcommons/training_policies/blob/cfa99da479b8d5931f7a3c67612d021dfb47510a/training_rules.adoc#benchmark_specific_rules
  # "Checkpoint must be collected every 512,000 images. CEIL(512000 / global_batch_size) if 512000 is not divisible by GBS."
  # NOTE: It's inferred that "steps" is the unit for the output of the CEIL formula, based on all other cases of CEIL in the rules
  CKPT_STEP_INTERVAL = config["CKPT_STEP_INTERVAL"] = getenv("CKPT_STEP_INTERVAL", math.ceil(512_000 / BS))
  CKPTDIR = config["CKPTDIR"] = Path(getenv("CKPTDIR", "./checkpoints"))
  DATADIR = config["DATADIR"] = Path(getenv("DATADIR", "./datasets"))
  UNET_CKPTDIR = config["UNET_CKPTDIR"] = Path(getenv("UNET_CKPTDIR", "./checkpoints"))
  TOTAL_CKPTS = config["TOTAL_CKPTS"] = getenv("TOTAL_CKPTS", 0)

  print(f"training on {GPUS}")
  lr = BS * BASE_LR
  print(f"BS={BS}, BASE_LR={BASE_LR}, lr={lr}")
  print(f"CKPT_STEP_INTERVAL = {CKPT_STEP_INTERVAL}")
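  # Worked example (illustrative): at a global batch size of 304 (as in the external test below),
  # lr = 304 * 2.5e-7 = 7.6e-5 and the default CKPT_STEP_INTERVAL would be ceil(512000 / 304) = 1685 steps.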
  for x in GPUS: Device[x]
  if (WANDB := getenv("WANDB", "")):
    import wandb
    wandb.init(config=config, project="MLPerf-Stable-Diffusion")

  Tensor.manual_seed(seed) # seed for weight initialization
  model, unet, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod = init_stable_diffusion("v2-mlperf-train", CKPTDIR / "sd" / "512-base-ema.ckpt", GPUS)

  optimizer = AdamW(get_parameters(unet))
  lambda_lr_callback = LambdaLinearScheduler(1000, 1.0, 1.0, 1e-06, 10000000000000).schedule
  lr_scheduler = LambdaLR(optimizer, Tensor(lr, dtype=dtypes.float, device=optimizer.device), lambda_lr_callback)
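  # Assumption about the LambdaLinearScheduler args above: they appear to mirror the mlperf reference yaml
  # (warm_up_steps=1000, f_min=1.0, f_max=1.0, f_start=1e-6, cycle_lengths=1e13), i.e. a linear warmup
  # from ~0 to the full lr over the first 1000 steps, then held constant for an effectively infinite cycle.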

  @TinyJit
  def train_step(mean:Tensor, logvar:Tensor, tokens:Tensor, unet:UNetModel, optimizer:LAMB, lr_scheduler:LambdaLR) -> Tensor:
    optimizer.zero_grad()

    timestep = Tensor.randint(BS, low=0, high=model.alphas_cumprod.shape[0], dtype=dtypes.int, device=GPUS[0])
    latent_randn = Tensor.randn(*mean.shape, device=GPUS[0])
    noise = Tensor.randn(*mean.shape, device=GPUS[0])
    for t in (mean, logvar, tokens, timestep, latent_randn, noise):
      t.shard_(GPUS, axis=0)

    std = Tensor.exp(0.5 * logvar.clamp(-30.0, 20.0))
    latent = (mean + std * latent_randn) * 0.18215
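    # mean/logvar are the precomputed VAE posterior moments, so this reparameterizes a latent sample
    # from them; 0.18215 is the standard SD latent scale factor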

    sqrt_alphas_cumprod_t = sqrt_alphas_cumprod[timestep].reshape(timestep.shape[0], 1, 1, 1)
    sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[timestep].reshape(timestep.shape[0], 1, 1, 1)
    latent_with_noise = sqrt_alphas_cumprod_t * latent + sqrt_one_minus_alphas_cumprod_t * noise
    v_true = sqrt_alphas_cumprod_t * noise - sqrt_one_minus_alphas_cumprod_t * latent
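    # v-parameterization: the UNet is trained to predict v = sqrt(alpha_bar_t)*noise - sqrt(1-alpha_bar_t)*latent
    # rather than the raw noise, which is exactly the target constructed above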

    context = model.cond_stage_model.embed_tokens(tokens)

    out = unet(latent_with_noise, timestep, context)
    loss = ((out - v_true) ** 2).mean()
    del mean, logvar, std, latent, noise, sqrt_alphas_cumprod_t, sqrt_one_minus_alphas_cumprod_t
    del out, v_true, context, latent_randn, tokens, timestep
    loss.backward()

    optimizer.step()
    lr_scheduler.step()
    loss, out_lr = loss.detach().to("CPU"), optimizer.lr.to("CPU")
    Tensor.realize(loss, out_lr)
    return loss, out_lr

  # checkpointing takes ~9 minutes without this, and ~1 minute with this
  @TinyJit
  def ckpt_to_cpu():
    ckpt = get_training_state(unet, optimizer, lr_scheduler)
    # move to CPU first so more GPU bufs aren't created (can trigger OOM)
    for k,v in ckpt.items(): ckpt[k] = v.detach().to("CPU")
    Tensor.realize(*[v for v in ckpt.values()])
    for k,v in ckpt.items(): ckpt[k] = v.cast(v.dtype.base).contiguous()
    Tensor.realize(*[v for v in ckpt.values()])
    return ckpt

  # training loop
  dl = batch_load_train_stable_diffusion(f'{DATADIR}/laion-400m/webdataset-moments-filtered/{{00000..00831}}.tar', BS)
  # for tests
  saved_checkpoints = []

  train_start_time = time.perf_counter()
  t0 = t6 = time.perf_counter()
  for i, batch in enumerate(dl, start=1):
    loop_time = time.perf_counter() - t0
    t0 = time.perf_counter()
    dl_time = t0 - t6
    GlobalCounters.reset()

    mean, logvar = np.split(np.concatenate(batch["npy"], axis=0), 2, axis=1)
    mean, logvar = Tensor(mean, dtype=dtypes.float32, device="CPU"), Tensor(logvar, dtype=dtypes.float32, device="CPU")
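    # the (1,8,64,64) "npy" moments stack into (BS,8,64,64); channels 0-3 are the latent mean and 4-7
    # the logvar, recovered by the axis=1 split above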
    tokens = []
    for text in batch['txt']: tokens += model.cond_stage_model.tokenizer.encode(text, pad_with_zeros=True)
    tokens = Tensor(tokens, dtype=dtypes.int32, device="CPU").reshape(-1, 77)
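    # 77 is the CLIP text-encoder context length; zero-padding each caption keeps every row the same width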

    t1 = time.perf_counter()
    loss, lr = train_step(mean, logvar, tokens, unet, optimizer, lr_scheduler)
    loss_item, lr_item = loss.item(), lr.item()
    t2 = time.perf_counter()

    if i == 3:
      for _ in range(3): ckpt_to_cpu() # do this at the beginning of run to prevent OOM surprises when checkpointing
      print("BEAM COMPLETE", flush=True) # allows wrapper script to detect BEAM search completion and retry if it failed

    total_train_time = time.perf_counter() - train_start_time
    if WANDB:
      wandb.log({"train/loss": loss_item, "train/lr": lr_item, "train/loop_time_prev": loop_time, "train/dl_time": dl_time, "train/step": i,
                 "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (t2-t1), "train/input_prep_time": t1-t0,
                 "train/train_step_time": t2-t1, "train/total_time": total_train_time})

      if i == 1 and wandb.run is not None:
        with open(f"{UNET_CKPTDIR}/wandb_run_id_{wandb.run.id}", "w") as f:
          f.write(f"wandb.run.id = {wandb.run.id}")

    if i % CKPT_STEP_INTERVAL == 0:
      # https://github.com/mlcommons/training_policies/blob/cfa99da479b8d5931f7a3c67612d021dfb47510a/training_rules.adoc#benchmark_specific_rules
      # "evaluation is done offline, the time is not counted towards the submission time."
      fn = f"{UNET_CKPTDIR}/{i}.safetensors"
      print(f"saving unet checkpoint at {fn}")
      saved_checkpoints.append(fn)
      safe_save({k.replace("model.", ""):v for k,v in ckpt_to_cpu().items() if k.startswith("model.")}, fn)
      if TOTAL_CKPTS and i == TOTAL_CKPTS * CKPT_STEP_INTERVAL:
        print(f"ending run after {i} steps ({TOTAL_CKPTS} checkpoints collected)")
        return saved_checkpoints

    t3 = time.perf_counter()
    print(f"""step {i}: {GlobalCounters.global_ops * 1e-9 / (t2-t1):9.2f} GFLOPS, mem_used: {GlobalCounters.mem_used / 1e9:.2f} GB,
      loop_time_prev: {loop_time:.2f}, dl_time: {dl_time:.2f}, input_prep_time: {t1-t0:.2f}, train_step_time: {t2-t1:.2f},
      t3-t2: {t3-t2:.4f}, loss:{loss_item:.5f}, lr:{lr_item:.3e}, total_train_time:{total_train_time:.2f}
    """)
    t6 = time.perf_counter()

if __name__ == "__main__":
  multiprocessing.set_start_method('spawn')
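A sketch (not part of the diff) of loading one of the checkpoints written above back, e.g. for the offline evaluation the rules reference. The step number in the filename is illustrative, and `unet` is assumed to be a UNetModel built the same way init_stable_diffusion builds it:

from tinygrad.nn.state import safe_load, load_state_dict

# keys were saved with the "model." prefix stripped, so they should line up with the UNet's own state dict
ckpt = safe_load("./checkpoints/1685.safetensors")  # 1685 = ceil(512000/304), the first checkpoint step at BS=304
load_state_dict(unet, ckpt)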
@@ -1501,7 +1639,7 @@ if __name__ == "__main__":
  else: bench_log_manager = contextlib.nullcontext()

  with Tensor.train():
-    for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
+    for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn,stable_diffusion").split(","):
      nm = f"train_{m}"
      if nm in globals():
        print(f"training {m}")
test/external/mlperf_stable_diffusion/external_test_train.py (new file, 23 lines)
@@ -0,0 +1,23 @@
import unittest, os
from tempfile import TemporaryDirectory
from tinygrad import Tensor
from tinygrad.helpers import getenv
from examples.mlperf.model_train import train_stable_diffusion

class TestTrain(unittest.TestCase):
  def test_train_to_ckpt(self):
    # train for num_steps, save checkpoint, and stop training
    num_steps = 42
    os.environ.update({"MODEL": "stable_diffusion", "TOTAL_CKPTS": "1", "CKPT_STEP_INTERVAL": str(num_steps), "GPUS": "8", "BS": "304"})
    # NOTE: update these based on where data/checkpoints are on your system
    if not getenv("DATADIR", ""): os.environ["DATADIR"] = "/raid/datasets/stable_diffusion"
    if not getenv("CKPTDIR", ""): os.environ["CKPTDIR"] = "/raid/weights/stable_diffusion"
    with TemporaryDirectory(prefix="test-train") as tmp:
      os.environ["UNET_CKPTDIR"] = tmp
      with Tensor.train():
        saved_ckpts = train_stable_diffusion()
      expected_ckpt = f"{tmp}/{num_steps}.safetensors"
      assert len(saved_ckpts) == 1 and saved_ckpts[0] == expected_ckpt

if __name__=="__main__":
  unittest.main()