Stable Diffusion mlperf training (#11304)

* entrypoint for sd mlperf train development

* match sd-v2 mlperf reference unet

* implement dataloader from mlperf ref

* update dataloader reference

* implement LambdaLR scheduler from mlperf ref

* match tokenizer from mlperf reference

* sample latent

* add noise to latent

* complete training epoch

* run full training step

* jit training loop

* replicate mlperf ref. losses over 11 train steps

* save tinygrad loss checkpoints properly

* match out.2.bias.grad to reference

* match weights to ref after 1 step

* compare out.2.bias to ref over three train steps

* implement attn_mask; cleanup closeness testing

* correct mse loss

* update dev_run / dependencies

* setup validation config/checkpointing

* implement validation sampling

* test closeness of eval denoise step to mlperf ref

* test closeness of decoder to mlperf ref

* confirm inception matches mlperf ref

* resize w/ bicubic interpolation, test closeness

* confirm closeness of clip preprocess to mlperf ref

* confirm clip score matches mlperf ref

* confirm fid/clip scores match mlperf ref

* cleanup

* cleanup

* zero-init some unet params as in mlperf reference

* revert jit change

* uncomment dependencies

* move to tinybox red

* implement GradScaler from torch but jittable

* simplify lr_scheduler, ensure jittability

* instantiate GradScaler

* only check if grads are finite with fp16

* implement fp16 training loop

* refactor UNet: norm, gelu, mixed precision

* refactor clip_tokenizer to enable versioning

* make fp16 attention closer to torch

* remove comparisons to torch fp16 attention

* add globvars.py for reference

* confirm closeness of fp16 unet forward to mlperf

* test norm closeness to torch with precast

* remeasure e2e with master attention

* more detailed softmax upcast comparison to torch

* parameterize softmax upcast in attention and unet

* use fp32 weights with autocast to fp16

* cleanup

* add data/checkpoint download script

* debug kernel timeout on AMD

* fix finite grads check; start multigpu

* pass numpy arrays from dataloader

* include text encoder in jit train step

* use int32 for tokens instead of int64

* prevent multi bug in reshape within clip

* corealize more, del refs before

* add more logging and wandb

* use erf gelu in clip encoder

* minor changes to train step and logging

* save checkpoints for eval or resuming

* add eval-only logic to training script

* multigpu eval

* remove PARALLEL=0

* cleanup

* pad eval batches of size < EVAL_BS

* workaround silent multigpu bug in jit

* cleanup

* tokenize captions

* verify correctness of multigpu eval

* cleanup

* verify correctness of grads in train step

* verify correctness of training (20 steps)

* don't shard in the training jit

* training settings

* minor cleanup

* overfit train w/ eval on 6 samples

* offload to enable combined train and eval

* download to raid; use local rclone

* misc changes for mi300x / logging

* refactor eval for larger BS, verify correctness

* cleanup

* ckpt resuming and remove eval cats

* eval BEAM config on mi300x and red

* resume eval after crash

* confirm eval correctness (one iteration, 6 samples)

* verify eval correctness at full scale

* cleanup correctness testing

* training correctness (20 steps, BS=248 uniform)

* cleanup

* remove eval cache at end of run

* switch f16 for bf16, del grad scaler

* confirm bf16 training correctness

* timestamps, new jits

* merge jits in training

* realize loss/lr on CPU

* training correctness

* post-bf16 train/eval

* implement grad_acc with timing/logging

* beam offline; debug gradacc; use float32

* fix gradacc in jit, correctness test

* prepare f32 BS=512 gradacc=4 run

* workaround jit problem in diffusion eval

* scale lr by BS

* revert gradacc, prepare bf16 BS=336 lr*=BS train

* make checkpointing faster

* resume bf16 BS=336 base_lr=1.25e-7 run

* jit ckpt at beginning

* don't alloc more gpu mem in ckpt

* cleanup

* move script to mi300x dir

* cleanup

* cleanup unneeded files

* revert beam search to master

* minor changes

* fix regression: realize before assign in eval

* cleanup mlperf SD data/ckpt downloads

* workaround BEAM failure

* workaround bug in Tensor.stack

* minor changes

* revert gradscaler

* cleanup

* cleanup/validate dataloader

* ensure checksum of laion data

* simplify config

* load training state to jitted bufs

* simplify lr scheduler

* simplify train script

* cleanup comments

* refactor stable diffusion/unet init

* more refactoring of stable diffusion init

* fix import errors in tests

* refactor: separate train/eval

* fix import errors

* eval checkpoints in reverse chron. order

* save/load cycle in sd init

* refactor and verify eval

* verify training correctness

* prepare repro train run

* cleanup

* integrate beam retry, train, eval

* simplify wandb

* kill orphaned processes

* better logging

* train to 10 ckpts instead of 7

* remove optimizer/scheduler checkpointing/resume

* cleanup

* BEAM=2 7 ckpts

* add test to compare with torch softmax in amp

* cleanup

* stop eval early if checkpoint converged

* add test for lr scheduler

* add proper test method

* add test for training

* use venv name that is ignored by .gitignore

* linting

* add simple f32 softmax fxn

* revert change to scaled_dot_product_attention

* refactor gelu_erf init

* simplify mixed precision in unet

* add norm autocasting to fp32

* rm extra test

* test eval with NULL backend

* fix venv name

* simplify norm autocast

* use temp dir for training test

* actually add eval test

* remove parallel env variable from tests

* update clip with tests

* reorg init functions

* use np for testing

* remove unused var

* factor out GPUS

* add sd model init tests

* more unet tests

* match master

* rerun CI due to linux (remote) hang

* explain UNET_CKPTDIR

* rerun CI due to linux (remote) timeout

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
hooved
2025-10-05 07:56:05 -04:00
committed by GitHub
parent a976ace404
commit 69857d0ab0

View File

@@ -0,0 +1,72 @@
#!/usr/bin/env bash
DATETIME=${2:-$(date "+%m%d%H%M")}
LOGFILE="${HOME}/logs/sd_mi300x_${DATETIME}.log"
# UNET_CKPTDIR must be set: training saves checkpoints to this path, then a separate eval process scans this path to know which checkpoints to eval
export UNET_CKPTDIR="${HOME}/stable_diffusion/training_checkpoints/${DATETIME}"
mkdir -p "${HOME}/logs" "$UNET_CKPTDIR"
# run this script in isolation when using the --bg flag
if [[ "${1:-}" == "--bg" ]]; then
echo "logging output to $LOGFILE"
echo "saving UNet checkpoints to $UNET_CKPTDIR"
script_path="$(readlink -f "${BASH_SOURCE[0]}")"
nohup bash "$script_path" run "$DATETIME" >"$LOGFILE" 2>&1 & disown $!
exit 0
fi
# venv management
if [[ -d .venv-sd-mlperf ]]; then
. .venv-sd-mlperf/bin/activate
else
python3 -m venv .venv-sd-mlperf && . .venv-sd-mlperf/bin/activate
pip install --index-url https://download.pytorch.org/whl/cpu torch && pip install tqdm numpy ftfy regex pillow scipy wandb webdataset
fi
pip list
apt list --installed | grep amdgpu
rocm-smi --version
modinfo amdgpu | grep version
export BEAM=2 BEAM_UOPS_MAX=8000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5 IGNORE_JIT_FIRST_BEAM=1 HCQDEV_WAIT_TIMEOUT_MS=300000
export AMD_LLVM=0 # bf16 seems to require this
export DATADIR="/raid/datasets/stable_diffusion"
export CKPTDIR="/raid/weights/stable_diffusion"
export EVAL_CKPT_DIR=$UNET_CKPTDIR
export MODEL="stable_diffusion" PYTHONPATH="."
export GPUS=8 BS=304
export CONTEXT_BS=816 DENOISE_BS=600 DECODE_BS=384 INCEPTION_BS=560 CLIP_BS=240
export WANDB=1
export PARALLEL=4
export PYTHONUNBUFFERED=1
sudo rocm-smi -d 0 1 2 3 4 5 6 7 --setperfdeterminism 1500 || exit 1
# Retry BEAM search if script fails before BEAM COMPLETE is printed, but don't retry after that
run_retry(){ local try=0 max=5 code tmp py pgid kids
while :; do
tmp=$(mktemp)
setsid bash -c 'exec env "$@"' _ "$@" > >(tee -a "$LOGFILE" | tee "$tmp") 2>&1 &
py=$!; pgid=$(ps -o pgid= -p "$py" | tr -d ' ')
wait "$py"; code=$?
[[ -n "$pgid" ]] && { kill -TERM -"$pgid" 2>/dev/null; sleep 1; kill -KILL -"$pgid" 2>/dev/null; }
kids=$(pgrep -P "$py" || true)
while [[ -n "$kids" ]]; do
kill -TERM $kids 2>/dev/null; sleep 0.5
kids=$(for k in $kids; do pgrep -P "$k" || true; done)
done
grep -q 'BEAM COMPLETE' "$tmp" && { rm -f "$tmp"; return 1; }
rm -f "$tmp"
((code==0)) && return 0
((try>=max)) && return 2
((try++)); sleep 90; echo "try = ${try}"
done
}
# Power limiting to 400W is only needed if GPUs fall out of sync (causing 2.2x increased train time) at higher power, which has been observed at 450W
sudo rocm-smi -d 0 1 2 3 4 5 6 7 --setpoweroverdrive 750 && \
run_retry TOTAL_CKPTS=7 python3 examples/mlperf/model_train.py; (( $? == 2 )) && { echo "training failed before BEAM completion"; exit 2; }
sleep 90
run_retry EVAL_SAMPLES=600 python3 examples/mlperf/model_eval.py; (( $? == 2 )) && { echo "eval failed before BEAM completion"; exit 2; }
# Checkpoints will be evaluated in reverse chronological order, even if above training crashed early
# STOP_IF_CONVERGED=1: Stop the eval after the first time convergence is detected; no more checkpoints will be evaluated after that.
STOP_IF_CONVERGED=1 python3 examples/mlperf/model_eval.py