Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-08 22:48:25 -05:00)
* entrypoint for sd mlperf train development
* match sd-v2 mlperf reference unet
* implement dataloader from mlperf ref
* update dataloader reference
* implement LambdaLR scheduler from mlperf ref
* match tokenizer from mlperf reference
* sample latent
* add noise to latent
* complete training epoch
* run full training step
* jit training loop
* replicate mlperf ref. losses over 11 train steps
* save tinygrad loss checkpoints properly
* match out.2.bias.grad to reference
* match weights to ref after 1 step
* compare out.2.bias to ref over three train steps
* implement attn_mask; cleanup closeness testing
* correct mse loss
* update dev_run / dependencies
* setup validation config/checkpointing
* implement validation sampling
* test closeness of eval denoise step to mlperf ref
* test closeness of decoder to mlperf ref
* confirm inception matches mlperf ref
* resize w/ bicubic interpolation, test closeness
* confirm closeness of clip preprocess to mlperf ref
* confirm clip score matches mlperf ref
* confirm fid/clip scores match mlperf ref
* cleanup
* cleanup
* zero-init some unet params as in mlperf reference
* revert jit change
* uncomment dependencies
* move to tinybox red
* implement GradScaler from torch but jittable
* simplify lr_scheduler, ensure jittability
* instantiate GradScaler
* only check if grads are finite with fp16
* implement fp16 training loop
* refactor UNet: norm, gelu, mixed precision
* refactor clip_tokenizer to enable versioning
* make fp16 attention closer to torch
* remove comparisons to torch fp16 attention
* add globvars.py for reference
* confirm closeness of fp16 unet forward to mlperf
* test norm closeness to torch with precast
* remeasure e2e with master attention
* more detailed softmax upcast comparison to torch
* parameterize softmax upcast in attention and unet
* use fp32 weights with autocast to fp16
* cleanup
* add data/checkpoint download script
* debug kernel timeout on AMD
* fix finite grads check; start multigpu
* pass numpy arrays from dataloader
* include text encoder in jit train step
* use int32 for tokens instead of int64
* prevent multi bug in reshape within clip
* corealize more, del refs before
* add more logging and wandb
* use erf gelu in clip encoder
* minor changes to train step and logging
* save checkpoints for eval or resuming
* add eval-only logic to training script
* multigpu eval
* remove PARALLEL=0
* cleanup
* pad eval batches of size < EVAL_BS
* workaround silent multigpu bug in jit
* cleanup
* tokenize captions
* verify correctness of multigpu eval
* cleanup
* verify correctness of grads in train step
* verify correctness of training (20 steps)
* don't shard in the training jit
* training settings
* minor cleanup
* overfit train w/ eval on 6 samples
* offload to enable combined train and eval
* download to raid; use local rclone
* misc changes for mi300x / logging
* refactor eval for larger BS, verify correctness
* cleanup
* ckpt resuming and remove eval cats
* eval BEAM config on mi300x and red
* resume eval after crash
* confirm eval correctness (one iteration, 6 samples)
* verify eval correctness at full scale
* cleanup correctness testing
* training correctness (20 steps, BS=248 uniform)
* cleanup
* remove eval cache at end of run
* switch f16 for bf16, del grad scaler
* confirm bf16 training correctness
* timestamps, new jits
* merge jits in training
* realize loss/lr on CPU
* training correctness
* post-bf16 train/eval
* implement grad_acc with timing/logging
* beam offline; debug gradacc; use float32
* fix gradacc in jit, correctness test
* prepare f32 BS=512 gradacc=4 run
* workaround jit problem in diffusion eval
* scale lr by BS
* revert gradacc, prepare bf16 BS=336 lr*=BS train
* make checkpointing faster
* resume bf16 BS=336 base_lr=1.25e-7 run
* jit ckpt at beginning
* don't alloc more gpu mem in ckpt
* cleanup
* move script to mi300x dir
* cleanup
* cleanup unneeded files
* revert beam search to master
* minor changes
* fix regression: realize before assign in eval
* cleanup mlperf SD data/ckpt downloads
* workaround BEAM failure
* workaround bug in Tensor.stack
* minor changes
* revert gradscaler
* cleanup
* cleanup/validate dataloader
* ensure checksum of laion data
* simplify config
* load training state to jitted bufs
* simplify lr scheduler
* simplify train script
* cleanup comments
* refactor stable diffusion/unet init
* more refactoring of stable diffusion init
* fix import errors in tests
* refactor: separate train/eval
* fix import errors
* eval checkpoints in reverse chron. order
* save/load cycle in sd init
* refactor and verify eval
* verify training correctness
* prepare repro train run
* cleanup
* integrate beam retry, train, eval
* simplify wandb
* kill orphaned processes
* better logging
* train to 10 ckpts instead of 7
* remove optimizer/scheduler checkpointing/resume
* cleanup
* BEAM=2 7 ckpts
* add test to compare with torch softmax in amp
* cleanup
* stop eval early if checkpoint converged
* add test for lr scheduler
* add proper test method
* add test for training
* use venv name that is ignored by .gitignore
* linting
* add simple f32 softmax fxn
* revert change to scaled_dot_product_attention
* refactor gelu_erf init
* simplify mixed precision in unet
* add norm autocasting to fp32
* rm extra test
* test eval with NULL backend
* fix venv name
* simplify norm autocast
* use temp dir for training test
* actually add eval test
* remove parallel env variable from tests
* update clip with tests
* reorg init functions
* use np for testing
* remove unused var
* factor out GPUS
* add sd model init tests
* more unet tests
* match master
* rerun CI due to linux (remote) hang
* explain UNET_CKPTDIR
* rerun CI due to linux (remote) timeout

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
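The "implement LambdaLR scheduler", "scale lr by BS" and "resume bf16 BS=336 base_lr=1.25e-7 run" items describe a learning rate formed as a per-step multiplier on a batch-size-scaled peak rate. The sketch below shows that shape in plain Python. It is not the tinygrad training script or the MLPerf reference: the linear warmup, `WARMUP_STEPS`, and the helper names `lr_lambda`/`lr_at` are hypothetical; only `BASE_LR` and the batch size come from the items above.

```python
# Hypothetical sketch of a LambdaLR-style schedule: lr(step) = peak_lr * lr_lambda(step).
# ASSUMPTIONS: peak_lr = base_lr * global batch size ("scale lr by BS");
# a linear warmup over a hypothetical WARMUP_STEPS, then a constant rate.

BASE_LR = 1.25e-7      # from the "base_lr=1.25e-7" item
WARMUP_STEPS = 1000    # hypothetical warmup length, not taken from the source

def lr_lambda(step: int) -> float:
    """Multiplier on the peak LR: linear ramp to 1.0, then flat.
    Uses (step + 1) so the very first step does not get a zero LR (an assumption)."""
    return min(1.0, (step + 1) / WARMUP_STEPS)

def lr_at(step: int, batch_size: int, base_lr: float = BASE_LR) -> float:
    """Learning rate for a given optimizer step and global batch size."""
    peak_lr = base_lr * batch_size   # "lr*=BS"
    return peak_lr * lr_lambda(step)
```

With `batch_size=336` the peak rate is 1.25e-7 × 336 = 4.2e-5, reached at step 999 under this assumed warmup. Keeping the schedule a pure function of the step count also makes it easy to evaluate outside a jitted train step.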
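The "implement GradScaler from torch but jittable" and "only check if grads are finite with fp16" items refer to dynamic loss scaling. The scaler was later dropped when training switched to bf16 ("switch f16 for bf16, del grad scaler"). A jitted step cannot branch in Python on whether the gradients overflowed, so the scale update has to be written as arithmetic on a 0/1 flag. The sketch below is a plain-Python illustration of torch's documented `GradScaler` update rule (defaults `growth_factor=2.0`, `backoff_factor=0.5`, `growth_interval=2000`) in that branch-free form. The function names are hypothetical, and this is not the code from the PR.

```python
import math

# torch.amp.GradScaler default hyperparameters
GROWTH_FACTOR = 2.0
BACKOFF_FACTOR = 0.5
GROWTH_INTERVAL = 2000

def found_inf(grads) -> float:
    """1.0 if any gradient is inf or nan, else 0.0 (a flag, not a Python branch in the step)."""
    return float(not all(math.isfinite(g) for g in grads))

def update_scale(scale: float, growth_tracker: float, inf_flag: float):
    """Branch-free version of GradScaler's update:
    - overflow: scale *= backoff, tracker reset to 0
    - no overflow: tracker += 1; once it reaches growth_interval, scale *= growth and tracker resets."""
    ok = 1.0 - inf_flag
    successful = growth_tracker + 1.0
    grow = ok * float(successful >= GROWTH_INTERVAL)    # 1.0 only on a growth step
    keep = ok - grow                                     # 1.0 on an ordinary finite step
    new_scale = scale * (inf_flag * BACKOFF_FACTOR + grow * GROWTH_FACTOR + keep * 1.0)
    new_tracker = ok * (1.0 - grow) * successful         # 0 after overflow or growth
    return new_scale, new_tracker
```

Skipping the optimizer update on overflow needs an elementwise select (for example a `where`), not a multiply by the flag. Multiplying an infinite gradient by 0 gives nan, which would then leak into the weights.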
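The "implement grad_acc" and "prepare f32 BS=512 gradacc=4 run" items use gradient accumulation: a global batch of 512 is processed as 4 micro-batches of 128, and the per-micro-batch gradients are combined before one optimizer step. The toy sketch below uses a one-parameter linear model with a mean-squared-error loss. It is not the UNet training step. It shows why each micro-batch gradient is divided by the accumulation count: with equal-sized micro-batches, the sum then equals the full-batch gradient of the mean loss. The helper names are hypothetical.

```python
def mse_grad(w: float, xs, ys) -> float:
    """d/dw of mean((w*x - y)^2) over one batch."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def accumulated_grad(w: float, xs, ys, grad_acc: int) -> float:
    """Split the batch into grad_acc equal micro-batches and accumulate grad / grad_acc."""
    assert len(xs) % grad_acc == 0, "equal micro-batches keep the mean-loss gradient exact"
    mb = len(xs) // grad_acc
    total = 0.0
    for i in range(grad_acc):
        sl = slice(i * mb, (i + 1) * mb)
        total += mse_grad(w, xs[sl], ys[sl]) / grad_acc
    return total
```

For example, with 8 samples and `grad_acc=4`, `accumulated_grad(w, xs, ys, 4)` matches `mse_grad(w, xs, ys)` up to floating-point rounding. Peak memory is then set by the micro-batch size rather than the global batch.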
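The "sample latent" and "add noise to latent" items correspond to the forward diffusion step of training. A clean latent $x_0$ and Gaussian noise $\epsilon$ are mixed at a random timestep $t$ as $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$. The plain-Python sketch below shows that standard DDPM step. The "scaled_linear" beta schedule with `BETA_START=0.00085`, `BETA_END=0.012` and 1000 timesteps is an assumption based on common Stable Diffusion configurations and is not confirmed by this commit. The helper names are hypothetical.

```python
import math

# ASSUMED schedule ("scaled_linear"): betas = linspace(sqrt(start), sqrt(end), T) ** 2
NUM_TRAIN_TIMESTEPS = 1000
BETA_START, BETA_END = 0.00085, 0.012

def alphas_cumprod(n: int = NUM_TRAIN_TIMESTEPS,
                   beta_start: float = BETA_START, beta_end: float = BETA_END):
    """Cumulative products alpha_bar_t = prod_{s<=t} (1 - beta_s)."""
    s0, s1 = math.sqrt(beta_start), math.sqrt(beta_end)
    out, prod = [], 1.0
    for i in range(n):
        beta = (s0 + (s1 - s0) * i / (n - 1)) ** 2
        prod *= 1.0 - beta
        out.append(prod)
    return out

def add_noise(latent, noise, t: int, acp):
    """x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, elementwise."""
    a, b = math.sqrt(acp[t]), math.sqrt(1.0 - acp[t])
    return [a * x + b * e for x, e in zip(latent, noise)]
```

Because $\bar\alpha_t$ decreases monotonically, a larger $t$ leaves less of the clean latent in $x_t$. The model is then trained with the MSE loss from the "correct mse loss" item against its prediction target.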