From 69857d0ab081e293825433d5d58fe4f9a309f084 Mon Sep 17 00:00:00 2001
From: hooved <172129504+hooved@users.noreply.github.com>
Date: Sun, 5 Oct 2025 07:56:05 -0400
Subject: [PATCH] Stable Diffusion mlperf training (#11304)

* entrypoint for sd mlperf train development
* match sd-v2 mlperf reference unet
* implement dataloader from mlperf ref
* update dataloader reference
* implement LambdaLR scheduler from mlperf ref
* match tokenizer from mlperf reference
* sample latent
* add noise to latent
* complete training epoch
* run full training step
* jit training loop
* replicate mlperf ref. losses over 11 train steps
* save tinygrad loss checkpoints properly
* match out.2.bias.grad to reference
* match weights to ref after 1 step
* compare out.2.bias to ref over three train steps
* implement attn_mask; cleanup closeness testing
* correct mse loss
* update dev_run / dependencies
* setup validation config/checkpointing
* implement validation sampling
* test closeness of eval denoise step to mlperf ref
* test closeness of decoder to mlperf ref
* confirm inception matches mlperf ref
* resize w/ bicubic interpolation, test closeness
* confirm closeness of clip preprocess to mlperf ref
* confirm clip score matches mlperf ref
* confirm fid/clip scores match mlperf ref
* cleanup
* cleanup
* zero-init some unet params as in mlperf reference
* revert jit change
* uncomment dependencies
* move to tinybox red
* implement GradScaler from torch but jittable
* simplify lr_scheduler, ensure jittability
* instantiate GradScaler
* only check if grads are finite with fp16
* implement fp16 training loop
* refactor UNet: norm, gelu, mixed precision
* refactor clip_tokenizer to enable versioning
* make fp16 attention closer to torch
* remove comparisons to torch fp16 attention
* add globvars.py for reference
* confirm closeness of fp16 unet forward to mlperf
* test norm closeness to torch with precast
* remeasure e2e with master attention
* more detailed softmax upcast comparison to torch
* parameterize softmax upcast in attention and unet
* use fp32 weights with autocast to fp16
* cleanup
* add data/checkpoint download script
* debug kernel timeout on AMD
* fix finite grads check; start multigpu
* pass numpy arrays from dataloader
* include text encoder in jit train step
* use int32 for tokens instead of int64
* prevent multi bug in reshape within clip
* corealize more, del refs before
* add more logging and wandb
* use erf gelu in clip encoder
* minor changes to train step and logging
* save checkpoints for eval or resuming
* add eval-only logic to training script
* multigpu eval
* remove PARALLEL=0
* cleanup
* pad eval batches of size < EVAL_BS
* workaround silent multigpu bug in jit
* cleanup
* tokenize captions
* verify correctness of multigpu eval
* cleanup
* verify correctness of grads in train step
* verify correctness of training (20 steps)
* don't shard in the training jit
* training settings
* minor cleanup
* overfit train w/ eval on 6 samples
* offload to enable combined train and eval
* download to raid; use local rclone
* misc changes for mi300x / logging
* refactor eval for larger BS, verify correctness
* cleanup
* ckpt resuming and remove eval cats
* eval BEAM config on mi300x and red
* resume eval after crash
* confirm eval correctness (one iteration, 6 samples)
* verify eval correctness at full scale
* cleanup correctness testing
* training correctness (20 steps, BS=248 uniform)
* cleanup
* remove eval cache at end of run
* switch f16 for bf16, del grad scaler
* confirm bf16 training correctness
* timestamps, new jits
* merge jits in training
* realize loss/lr on CPU
* training correctness
* post-bf16 train/eval
* implement grad_acc with timing/logging
* beam offline; debug gradacc; use float32
* fix gradacc in jit, correctness test
* prepare f32 BS=512 gradacc=4 run
* workaround jit problem in diffusion eval
* scale lr by BS
* revert gradacc, prepare bf16 BS=336 lr*=BS train
* make checkpointing faster
* resume bf16 BS=336 base_lr=1.25e-7 run
* jit ckpt at beginning
* don't alloc more gpu mem in ckpt
* cleanup
* move script to mi300x dir
* cleanup
* cleanup unneeded files
* revert beam search to master
* minor changes
* fix regression: realize before assign in eval
* cleanup mlperf SD data/ckpt downloads
* workaround BEAM failure
* workaround bug in Tensor.stack
* minor changes
* revert gradscaler
* cleanup
* cleanup/validate dataloader
* ensure checksum of laion data
* simplify config
* load training state to jitted bufs
* simplify lr scheduler
* simplify train script
* cleanup comments
* refactor stable diffusion/unet init
* more refactoring of stable diffusion init
* fix import errors in tests
* refactor: separate train/eval
* fix import errors
* eval checkpoints in reverse chron. order
* save/load cycle in sd init
* refactor and verify eval
* verify training correctness
* prepare repro train run
* cleanup
* integrate beam retry, train, eval
* simplify wandb
* kill orphaned processes
* better logging
* train to 10 ckpts instead of 7
* remove optimizer/scheduler checkpointing/resume
* cleanup
* BEAM=2 7 ckpts
* add test to compare with torch softmax in amp
* cleanup
* stop eval early if checkpoint converged
* add test for lr scheduler
* add proper test method
* add test for training
* use venv name that is ignored by .gitignore
* linting
* add simple f32 softmax fxn
* revert change to scaled_dot_product_attention
* refactor gelu_erf init
* simplify mixed precision in unet
* add norm autocasting to fp32
* rm extra test
* test eval with NULL backend
* fix venv name
* simplify norm autocast
* use temp dir for training test
* actually add eval test
* remove parallel env variable from tests
* update clip with tests
* reorg init functions
* use np for testing
* remove unused var
* factor out GPUS
* add sd model init tests
* more unet tests
* match master
* rerun CI due to linux (remote) hang
* explain UNET_CKPTDIR
* rerun CI due to linux (remote) timeout

---------

Co-authored-by: chenyu
---
 .../tinybox_8xMI300X/dev_run.sh               | 72 +++++++++++++++++++
 1 file changed, 72 insertions(+)
 create mode 100755 examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/stable_diffusion/implementations/tinybox_8xMI300X/dev_run.sh

diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/stable_diffusion/implementations/tinybox_8xMI300X/dev_run.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/stable_diffusion/implementations/tinybox_8xMI300X/dev_run.sh
new file mode 100755
index 0000000000..5e35ff65a4
--- /dev/null
+++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/stable_diffusion/implementations/tinybox_8xMI300X/dev_run.sh
@@ -0,0 +1,72 @@
+#!/usr/bin/env bash
+
+DATETIME=${2:-$(date "+%m%d%H%M")}
+LOGFILE="${HOME}/logs/sd_mi300x_${DATETIME}.log"
+# UNET_CKPTDIR must be set: training saves checkpoints to this path, then a separate eval process scans this path to know which checkpoints to eval
+export UNET_CKPTDIR="${HOME}/stable_diffusion/training_checkpoints/${DATETIME}"
+mkdir -p "${HOME}/logs" "$UNET_CKPTDIR"
+
+# run this script in isolation when using the --bg flag
+if [[ "${1:-}" == "--bg" ]]; then
+  echo "logging output to $LOGFILE"
+  echo "saving UNet checkpoints to $UNET_CKPTDIR"
+  script_path="$(readlink -f "${BASH_SOURCE[0]}")"
+  nohup bash "$script_path" run "$DATETIME" >"$LOGFILE" 2>&1 & disown $!
+  exit 0
+fi
+
+# venv management
+if [[ -d .venv-sd-mlperf ]]; then
+  . .venv-sd-mlperf/bin/activate
+else
+  python3 -m venv .venv-sd-mlperf && . .venv-sd-mlperf/bin/activate
+  pip install --index-url https://download.pytorch.org/whl/cpu torch && pip install tqdm numpy ftfy regex pillow scipy wandb webdataset
+fi
+pip list
+apt list --installed | grep amdgpu
+rocm-smi --version
+modinfo amdgpu | grep version
+
+export BEAM=2 BEAM_UOPS_MAX=8000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5 IGNORE_JIT_FIRST_BEAM=1 HCQDEV_WAIT_TIMEOUT_MS=300000
+export AMD_LLVM=0 # bf16 seems to require this
+export DATADIR="/raid/datasets/stable_diffusion"
+export CKPTDIR="/raid/weights/stable_diffusion"
+export EVAL_CKPT_DIR=$UNET_CKPTDIR
+export MODEL="stable_diffusion" PYTHONPATH="."
+export GPUS=8 BS=304
+export CONTEXT_BS=816 DENOISE_BS=600 DECODE_BS=384 INCEPTION_BS=560 CLIP_BS=240
+export WANDB=1
+export PARALLEL=4
+export PYTHONUNBUFFERED=1
+sudo rocm-smi -d 0 1 2 3 4 5 6 7 --setperfdeterminism 1500 || exit 1
+
+# Retry BEAM search if script fails before BEAM COMPLETE is printed, but don't retry after that
+run_retry(){ local try=0 max=5 code tmp py pgid kids
+  while :; do
+    tmp=$(mktemp)
+    setsid bash -c 'exec env "$@"' _ "$@" > >(tee -a "$LOGFILE" | tee "$tmp") 2>&1 &
+    py=$!; pgid=$(ps -o pgid= -p "$py" | tr -d ' ')
+    wait "$py"; code=$?
+    [[ -n "$pgid" ]] && { kill -TERM -"$pgid" 2>/dev/null; sleep 1; kill -KILL -"$pgid" 2>/dev/null; }
+    kids=$(pgrep -P "$py" || true)
+    while [[ -n "$kids" ]]; do
+      kill -TERM $kids 2>/dev/null; sleep 0.5
+      kids=$(for k in $kids; do pgrep -P "$k" || true; done)
+    done
+    grep -q 'BEAM COMPLETE' "$tmp" && { rm -f "$tmp"; return 1; }
+    rm -f "$tmp"
+    ((code==0)) && return 0
+    ((try>=max)) && return 2
+    ((try++)); sleep 90; echo "try = ${try}"
+  done
+}
+
+# Power limiting to 400W is only needed if GPUs fall out of sync (causing 2.2x increased train time) at higher power, which has been observed at 450W
+sudo rocm-smi -d 0 1 2 3 4 5 6 7 --setpoweroverdrive 750 && \
+run_retry TOTAL_CKPTS=7 python3 examples/mlperf/model_train.py; (( $? == 2 )) && { echo "training failed before BEAM completion"; exit 2; }
+sleep 90
+
+run_retry EVAL_SAMPLES=600 python3 examples/mlperf/model_eval.py; (( $? == 2 )) && { echo "eval failed before BEAM completion"; exit 2; }
+# Checkpoints will be evaluated in reverse chronological order, even if above training crashed early
+# STOP_IF_CONVERGED=1: Stop the eval after the first time convergence is detected; no more checkpoints will be evaluated after that.
+STOP_IF_CONVERGED=1 python3 examples/mlperf/model_eval.py
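
A possible way to launch dev_run.sh, as a sketch rather than part of the patch: the script appears to assume the tinygrad repository root as the working directory, since it creates .venv-sd-mlperf, sets PYTHONPATH=".", and calls examples/mlperf/model_train.py and model_eval.py by relative path; the tag "10050756" below is an arbitrary example value for the second positional argument (DATETIME).

  # background launch: the script re-executes itself under nohup and logs to ~/logs/sd_mi300x_<DATETIME>.log
  bash examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/stable_diffusion/implementations/tinybox_8xMI300X/dev_run.sh --bg

  # foreground run with an explicit timestamp tag (anything other than --bg as $1 skips the nohup re-exec)
  bash examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/stable_diffusion/implementations/tinybox_8xMI300X/dev_run.sh run 10050756

  # follow progress of a background run
  tail -f "${HOME}/logs/sd_mi300x_"*.log

UNet checkpoints land under ${HOME}/stable_diffusion/training_checkpoints/<DATETIME>, which is also where the eval invocations (via EVAL_CKPT_DIR) look for checkpoints to score.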