mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-08 22:48:25 -05:00)
Stable Diffusion mlperf training (#11304)
* entrypoint for sd mlperf train development
* match sd-v2 mlperf reference unet
* implement dataloader from mlperf ref
* update dataloader reference
* implement LambdaLR scheduler from mlperf ref
* match tokenizer from mlperf reference
* sample latent
* add noise to latent
* complete training epoch
* run full training step
* jit training loop
* replicate mlperf ref. losses over 11 train steps
* save tinygrad loss checkpoints properly
* match out.2.bias.grad to reference
* match weights to ref after 1 step
* compare out.2.bias to ref over three train steps
* implement attn_mask; cleanup closeness testing
* correct mse loss
* update dev_run / dependencies
* setup validation config/checkpointing
* implement validation sampling
* test closeness of eval denoise step to mlperf ref
* test closeness of decoder to mlperf ref
* confirm inception matches mlperf ref
* resize w/ bicubic interpolation, test closeness
* confirm closeness of clip preprocess to mlperf ref
* confirm clip score matches mlperf ref
* confirm fid/clip scores match mlperf ref
* cleanup
* cleanup
* zero-init some unet params as in mlperf reference
* revert jit change
* uncomment dependencies
* move to tinybox red
* implement GradScaler from torch but jittable
* simplify lr_scheduler, ensure jittability
* instantiate GradScaler
* only check if grads are finite with fp16
* implement fp16 training loop
* refactor UNet: norm, gelu, mixed precision
* refactor clip_tokenizer to enable versioning
* make fp16 attention closer to torch
* remove comparisons to torch fp16 attention
* add globvars.py for reference
* confirm closeness of fp16 unet forward to mlperf
* test norm closeness to torch with precast
* remeasure e2e with master attention
* more detailed softmax upcast comparison to torch
* parameterize softmax upcast in attention and unet
* use fp32 weights with autocast to fp16
* cleanup
* add data/checkpoint download script
* debug kernel timeout on AMD
* fix finite grads check; start multigpu
* pass numpy arrays from dataloader
* include text encoder in jit train step
* use int32 for tokens instead of int64
* prevent multi bug in reshape within clip
* corealize more, del refs before
* add more logging and wandb
* use erf gelu in clip encoder
* minor changes to train step and logging
* save checkpoints for eval or resuming
* add eval-only logic to training script
* multigpu eval
* remove PARALLEL=0
* cleanup
* pad eval batches of size < EVAL_BS
* workaround silent multigpu bug in jit
* cleanup
* tokenize captions
* verify correctness of multigpu eval
* cleanup
* verify correctness of grads in train step
* verify correctness of training (20 steps)
* don't shard in the training jit
* training settings
* minor cleanup
* overfit train w/ eval on 6 samples
* offload to enable combined train and eval
* download to raid; use local rclone
* misc changes for mi300x / logging
* refactor eval for larger BS, verify correctness
* cleanup
* ckpt resuming and remove eval cats
* eval BEAM config on mi300x and red
* resume eval after crash
* confirm eval correctness (one iteration, 6 samples)
* verify eval correctness at full scale
* cleanup correctness testing
* training correctness (20 steps, BS=248 uniform)
* cleanup
* remove eval cache at end of run
* switch f16 for bf16, del grad scaler
* confirm bf16 training correctness
* timestamps, new jits
* merge jits in training
* realize loss/lr on CPU
* training correctness
* post-bf16 train/eval
* implement grad_acc with timing/logging
* beam offline; debug gradacc; use float32
* fix gradacc in jit, correctness test
* prepare f32 BS=512 gradacc=4 run
* workaround jit problem in diffusion eval
* scale lr by BS
* revert gradacc, prepare bf16 BS=336 lr*=BS train
* make checkpointing faster
* resume bf16 BS=336 base_lr=1.25e-7 run
* jit ckpt at beginning
* don't alloc more gpu mem in ckpt
* cleanup
* move script to mi300x dir
* cleanup
* cleanup unneeded files
* revert beam search to master
* minor changes
* fix regression: realize before assign in eval
* cleanup mlperf SD data/ckpt downloads
* workaround BEAM failure
* workaround bug in Tensor.stack
* minor changes
* revert gradscaler
* cleanup
* cleanup/validate dataloader
* ensure checksum of laion data
* simplify config
* load training state to jitted bufs
* simplify lr scheduler
* simplify train script
* cleanup comments
* refactor stable diffusion/unet init
* more refactoring of stable diffusion init
* fix import errors in tests
* refactor: separate train/eval
* fix import errors
* eval checkpoints in reverse chron. order
* save/load cycle in sd init
* refactor and verify eval
* verify training correctness
* prepare repro train run
* cleanup
* integrate beam retry, train, eval
* simplify wandb
* kill orphaned processes
* better logging
* train to 10 ckpts instead of 7
* remove optimizer/scheduler checkpointing/resume
* cleanup
* BEAM=2 7 ckpts
* add test to compare with torch softmax in amp
* cleanup
* stop eval early if checkpoint converged
* add test for lr scheduler
* add proper test method
* add test for training
* use venv name that is ignored by .gitignore
* linting
* add simple f32 softmax fxn
* revert change to scaled_dot_product_attention
* refactor gelu_erf init
* simplify mixed precision in unet
* add norm autocasting to fp32
* rm extra test
* test eval with NULL backend
* fix venv name
* simplify norm autocast
* use temp dir for training test
* actually add eval test
* remove parallel env variable from tests
* update clip with tests
* reorg init functions
* use np for testing
* remove unused var
* factor out GPUS
* add sd model init tests
* more unet tests
* match master
* rerun CI due to linux (remote) hang
* explain UNET_CKPTDIR
* rerun CI due to linux (remote) timeout
---------
Co-authored-by: chenyu <chenyu@fastmail.com>
@@ -0,0 +1,72 @@
#!/usr/bin/env bash
DATETIME=${2:-$(date "+%m%d%H%M")}
LOGFILE="${HOME}/logs/sd_mi300x_${DATETIME}.log"
# UNET_CKPTDIR must be set: training saves checkpoints to this path, and a separate eval process later scans it to determine which checkpoints to evaluate
export UNET_CKPTDIR="${HOME}/stable_diffusion/training_checkpoints/${DATETIME}"
mkdir -p "${HOME}/logs" "$UNET_CKPTDIR"
# with the --bg flag, re-exec this script with nohup so it keeps running in the background, then exit
if [[ "${1:-}" == "--bg" ]]; then
echo "logging output to $LOGFILE"
echo "saving UNet checkpoints to $UNET_CKPTDIR"
script_path="$(readlink -f "${BASH_SOURCE[0]}")"
nohup bash "$script_path" run "$DATETIME" >"$LOGFILE" 2>&1 & disown $!
exit 0
fi
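# Usage sketch (illustrative only, not executed; "run_sd_mi300x.sh" is a placeholder for whatever
# name this script is saved under):
#   bash run_sd_mi300x.sh --bg            # detach via nohup; DATETIME defaults to the current time
#   bash run_sd_mi300x.sh --bg 09151030   # detach and pin DATETIME (the second positional arg)
#   bash run_sd_mi300x.sh                 # run in the foreground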
# venv management
if [[ -d .venv-sd-mlperf ]]; then
. .venv-sd-mlperf/bin/activate
else
python3 -m venv .venv-sd-mlperf && . .venv-sd-mlperf/bin/activate
pip install --index-url https://download.pytorch.org/whl/cpu torch && pip install tqdm numpy ftfy regex pillow scipy wandb webdataset
fi
pip list
apt list --installed | grep amdgpu
rocm-smi --version
modinfo amdgpu | grep version
export BEAM=2 BEAM_UOPS_MAX=8000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 BEAM_MIN_PROGRESS=5 IGNORE_JIT_FIRST_BEAM=1 HCQDEV_WAIT_TIMEOUT_MS=300000
export AMD_LLVM=0 # bf16 seems to require this
export DATADIR="/raid/datasets/stable_diffusion"
export CKPTDIR="/raid/weights/stable_diffusion"
export EVAL_CKPT_DIR="$UNET_CKPTDIR"
export MODEL="stable_diffusion" PYTHONPATH="."
export GPUS=8 BS=304
export CONTEXT_BS=816 DENOISE_BS=600 DECODE_BS=384 INCEPTION_BS=560 CLIP_BS=240
export WANDB=1
export PARALLEL=4
export PYTHONUNBUFFERED=1
sudo rocm-smi -d 0 1 2 3 4 5 6 7 --setperfdeterminism 1500 || exit 1
# Retry the command if it fails before 'BEAM COMPLETE' is printed; once that line has appeared, don't retry
run_retry(){ local try=0 max=5 code tmp py pgid kids
while :; do
tmp=$(mktemp)
# launch the command in its own session so the whole process group can be torn down; tee output to the log and to a temp file for the BEAM COMPLETE check
setsid bash -c 'exec env "$@"' _ "$@" > >(tee -a "$LOGFILE" | tee "$tmp") 2>&1 &
py=$!; pgid=$(ps -o pgid= -p "$py" | tr -d ' ')
wait "$py"; code=$?
[[ -n "$pgid" ]] && { kill -TERM -"$pgid" 2>/dev/null; sleep 1; kill -KILL -"$pgid" 2>/dev/null; } # kill the whole process group: TERM first, then KILL
kids=$(pgrep -P "$py" || true)
while [[ -n "$kids" ]]; do
kill -TERM $kids 2>/dev/null; sleep 0.5
kids=$(for k in $kids; do pgrep -P "$k" || true; done)
done
grep -q 'BEAM COMPLETE' "$tmp" && { rm -f "$tmp"; return 1; }
rm -f "$tmp"
((code==0)) && return 0
((try>=max)) && return 2
((try++)); sleep 90; echo "try = ${try}"
done
}
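# For reference, run_retry's return codes as implemented above:
#   0 -- the command exited 0 without printing BEAM COMPLETE
#   1 -- BEAM COMPLETE was printed, so no retry is attempted regardless of exit status
#   2 -- the command kept failing before BEAM COMPLETE and all 5 retries were used up
# Env-style prefixes are allowed because the command is launched through `env "$@"`, e.g.
# (toy illustration, not part of the training flow): run_retry FOO=1 bash -c 'echo ok'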
# Limiting power to 400W is only needed if the GPUs fall out of sync at higher power (observed at 450W), which increases train time ~2.2x
sudo rocm-smi -d 0 1 2 3 4 5 6 7 --setpoweroverdrive 750 && \
run_retry TOTAL_CKPTS=7 python3 examples/mlperf/model_train.py; (( $? == 2 )) && { echo "training failed before BEAM completion"; exit 2; }
sleep 90
run_retry EVAL_SAMPLES=600 python3 examples/mlperf/model_eval.py; (( $? == 2 )) && { echo "eval failed before BEAM completion"; exit 2; }
# Checkpoints will be evaluated in reverse chronological order, even if the training above crashed early
# STOP_IF_CONVERGED=1: Stop the eval after the first time convergence is detected; no more checkpoints will be evaluated after that.
STOP_IF_CONVERGED=1 python3 examples/mlperf/model_eval.py
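# Note: to re-run the eval on its own against an earlier run's checkpoints (assuming the same env
# setup as above), you could point EVAL_CKPT_DIR at that run's UNET_CKPTDIR, e.g.:
#   EVAL_CKPT_DIR="${HOME}/stable_diffusion/training_checkpoints/<DATETIME>" \
#     STOP_IF_CONVERGED=1 python3 examples/mlperf/model_eval.py
# (<DATETIME> is a placeholder for the timestamp of that earlier run.)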