* update test_clone_doesnt_dedup to use base
* new_flat_buffer passes
* fix test_reorder_expand
* remove the view stuff
* remove that test, we don't want this view const behavior
* test_setitem_becomes_subbuffer is good
* remove skipping cast in simplify_valid [pr]
Unsupported statements are already handled in uop_given_valid; the test failed because (100 % x) was somehow getting simplified (see the arithmetic check below).
* better test
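A quick plain-Python check of why that fold is invalid (illustrative only, not the UOp/simplify_valid API): a valid that merely bounds x does not pin 100 % x to a single constant.

```python
# plain-Python illustration, assuming the valid only bounds x (e.g. 0 < x < 100):
# 100 % x still takes many different values, so it must not fold to a constant.
values = {100 % x for x in range(1, 100)}
assert len(values) > 1   # e.g. 100 % 3 == 1 while 100 % 7 == 2
```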
* entrypoint for sd mlperf train development
* match sd-v2 mlperf reference unet
* implement dataloader from mlperf ref
* update dataloader reference
* implement LambdaLR scheduler from mlperf ref (see the warmup/LR sketch after this list)
* match tokenizer from mlperf reference
* sample latent
* add noise to latent
* complete training epoch
* run full training step
* jit training loop
* replicate mlperf reference losses over 11 train steps
* save tinygrad loss checkpoints properly
* match out.2.bias.grad to reference
* match weights to ref after 1 step
* compare out.2.bias to ref over three train steps
* implement attn_mask; cleanup closeness testing
* correct mse loss
* update dev_run / dependencies
* setup validation config/checkpointing
* implement validation sampling
* test closeness of eval denoise step to mlperf ref
* test closeness of decoder to mlperf ref
* confirm inception matches mlperf ref
* resize w/ bicubic interpolation, test closeness
* confirm closeness of clip preprocess to mlperf ref
* confirm clip score matches mlperf ref
* confirm fid/clip scores match mlperf ref
* cleanup
* cleanup
* zero-init some unet params as in mlperf reference
* revert jit change
* uncomment dependencies
* move to tinybox red
* implement GradScaler from torch but jittable (see the loss-scaling sketch after this list)
* simplify lr_scheduler, ensure jittability
* instantiate GradScaler
* only check if grads are finite with fp16
* implement fp16 training loop
* refactor UNet: norm, gelu, mixed precision
* refactor clip_tokenizer to enable versioning
* make fp16 attention closer to torch
* remove comparisons to torch fp16 attention
* add globvars.py for reference
* confirm closeness of fp16 unet forward to mlperf
* test norm closeness to torch with precast
* remeasure e2e with master attention
* more detailed softmax upcast comparison to torch
* parameterize softmax upcast in attention and unet
* use fp32 weights with autocast to fp16
* cleanup
* add data/checkpoint download script
* debug kernel timeout on AMD
* fix finite grads check; start multigpu
* pass numpy arrays from dataloader
* include text encoder in jit train step
* use int32 for tokens instead of int64
* prevent multi-device bug in reshape within clip
* corealize more, del refs before
* add more logging and wandb
* use erf gelu in clip encoder
* minor changes to train step and logging
* save checkpoints for eval or resuming
* add eval-only logic to training script
* multigpu eval
* remove PARALLEL=0
* cleanup
* pad eval batches of size < EVAL_BS
* work around silent multigpu bug in jit
* cleanup
* tokenize captions
* verify correctness of multigpu eval
* cleanup
* verify correctness of grads in train step
* verify correctness of training (20 steps)
* don't shard in the training jit
* training settings
* minor cleanup
* overfit train w/ eval on 6 samples
* offload to enable combined train and eval
* download to raid; use local rclone
* misc changes for mi300x / logging
* refactor eval for larger BS, verify correctness
* cleanup
* ckpt resuming and remove eval cats
* eval BEAM config on mi300x and red
* resume eval after crash
* confirm eval correctness (one iteration, 6 samples)
* verify eval correctness at full scale
* cleanup correctness testing
* training correctness (20 steps, BS=248 uniform)
* cleanup
* remove eval cache at end of run
* switch from f16 to bf16, del grad scaler
* confirm bf16 training correctness
* timestamps, new jits
* merge jits in training
* realize loss/lr on CPU
* training correctness
* post-bf16 train/eval
* implement grad_acc with timing/logging
* beam offline; debug gradacc; use float32
* fix gradacc in jit, correctness test
* prepare f32 BS=512 gradacc=4 run
* work around jit problem in diffusion eval
* scale lr by BS
* revert gradacc, prepare bf16 BS=336 lr*=BS train
* make checkpointing faster
* resume bf16 BS=336 base_lr=1.25e-7 run
* jit ckpt at beginning
* don't alloc more gpu mem in ckpt
* cleanup
* move script to mi300x dir
* cleanup
* cleanup unneeded files
* revert beam search to master
* minor changes
* fix regression: realize before assign in eval
* cleanup mlperf SD data/ckpt downloads
* work around BEAM failure
* work around bug in Tensor.stack
* minor changes
* revert gradscaler
* cleanup
* cleanup/validate dataloader
* ensure checksum of laion data
* simplify config
* load training state to jitted bufs
* simplify lr scheduler
* simplify train script
* cleanup comments
* refactor stable diffusion/unet init
* more refactoring of stable diffusion init
* fix import errors in tests
* refactor: separate train/eval
* fix import errors
* eval checkpoints in reverse chronological order
* save/load cycle in sd init
* refactor and verify eval
* verify training correctness
* prepare repro train run
* cleanup
* integrate beam retry, train, eval
* simplify wandb
* kill orphaned processes
* better logging
* train to 10 ckpts instead of 7
* remove optimizer/scheduler checkpointing/resume
* cleanup
* BEAM=2 7 ckpts
* add test to compare with torch softmax in amp
* cleanup
* stop eval early if checkpoint converged
* add test for lr scheduler
* add proper test method
* add test for training
* use venv name that is ignored by .gitignore
* linting
* add simple f32 softmax fxn (see the softmax sketch after this list)
* revert change to scaled_dot_product_attention
* refactor gelu_erf init
* simplify mixed precision in unet
* add norm autocasting to fp32
* rm extra test
* test eval with NULL backend
* fix venv name
* simplify norm autocast
* use temp dir for training test
* actually add eval test
* remove parallel env variable from tests
* update clip with tests
* reorg init functions
* use np for testing
* remove unused var
* factor out GPUS
* add sd model init tests
* more unet tests
* match master
* rerun CI due to linux (remote) hang
* explain UNET_CKPTDIR
* rerun CI due to linux (remote) timeout
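Hedged sketch for the LambdaLR/warmup and "scale lr by BS" entries above: a plain-Python schedule with linear warmup and the base LR scaled by the global batch size. BASE_LR and BS echo the bf16 BS=336 base_lr=1.25e-7 run noted above; WARMUP_STEPS is an assumed placeholder, not the script's actual constant.

```python
BASE_LR, BS, WARMUP_STEPS = 1.25e-7, 336, 1000   # WARMUP_STEPS is illustrative

def lr_at(step: int) -> float:
  peak = BASE_LR * BS                              # linear LR scaling with batch size
  warmup = min(step, WARMUP_STEPS) / WARMUP_STEPS  # ramp 0 -> 1, then hold at peak
  return peak * warmup

assert lr_at(0) == 0.0
assert lr_at(WARMUP_STEPS) == lr_at(10 * WARMUP_STEPS) == BASE_LR * BS
```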
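Hedged sketch for the "implement GradScaler from torch but jittable" / "only check if grads are finite with fp16" entries (the scaler was later dropped when switching to bf16). Framework-free and with illustrative constants; not the script's actual class.

```python
class SimpleGradScaler:
  """Torch-style dynamic loss scaling, reduced to its control flow."""
  def __init__(self, init_scale=2.0**16, growth=2.0, backoff=0.5, growth_interval=2000):
    self.scale, self.growth, self.backoff = init_scale, growth, backoff
    self.growth_interval, self.good_steps = growth_interval, 0

  def scale_loss(self, loss: float) -> float:
    return loss * self.scale                 # backward then yields scaled grads

  def update(self, grads_finite: bool) -> bool:
    # returns True if the (unscaled) optimizer step should be taken
    if not grads_finite:                     # fp16 overflow: skip step, back off scale
      self.scale *= self.backoff
      self.good_steps = 0
      return False
    self.good_steps += 1
    if self.good_steps >= self.growth_interval:   # stable long enough: grow the scale
      self.scale *= self.growth
      self.good_steps = 0
    return True
```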
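Hedged sketch for the softmax-upcast / "add simple f32 softmax fxn" entries: a NumPy illustration of running softmax in float32 on fp16 attention logits and casting back, not the repo's actual function.

```python
import numpy as np

def softmax_f32(x: np.ndarray) -> np.ndarray:
  # upcast, stabilize with max-subtraction, exponentiate and normalize in fp32, cast back
  x32 = x.astype(np.float32)
  x32 = x32 - x32.max(axis=-1, keepdims=True)
  e = np.exp(x32)
  return (e / e.sum(axis=-1, keepdims=True)).astype(x.dtype)

scores = np.random.randn(2, 8, 77, 77).astype(np.float16)  # attention-like fp16 logits
probs = softmax_f32(scores)
assert probs.dtype == np.float16
assert np.allclose(probs.astype(np.float32).sum(axis=-1), 1.0, atol=1e-2)
```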
---------
Co-authored-by: chenyu <chenyu@fastmail.com>
* fix bmnist torch with RANGEIFY=1
* alt
* test and comment
* this was always wrong
* simple failing test for rangeify
* simple upat to match the old behavior
* add ordering
* fix some tests
* fix more tests
* shorten comment
* update test
* add rule and test
* add rule and test
* remove check
* use fold_divmod_congruence instead of simplify (see the divmod arithmetic note at the end of this list)
* adjust tests
* shorten line
* new algo
* add test
* add function to un-nest the div
* add UOp.factor
* test UOp.factor
* uop_given_valid tries to factor simple expressions
* shorten line
* symbolic_flat is back
* change that back
* fix those new tests
* new rule for ordering
* factor multiple factors
* no symbolic_flat
* move symbolic_flat there
* move that back
* fix imports
* merge correctly
* linter happy
* add rule
* add a test
* cleanup
* revert that for now
* UOp.factor returns self instead of None
* try all_candidates
* remove or_else
* post index symbolic
* add test
* make this closer to the original
* increase mac hlb_cifar min step time
* add some ordering tests
* cleanup
* increase pytest timeout time
* check dtype
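A note on the plain integer identities behind the fold_divmod_congruence / div-un-nesting entries above (arithmetic only, not the UOp rewrite rules themselves): multiples of k drop out of % and distribute over floor division.

```python
import itertools

# for integer q, r and k > 0: (q*k + r) // k == q + r // k  and  (q*k + r) % k == r % k
for q, r, k in itertools.product(range(-6, 7), range(-6, 7), range(1, 7)):
  assert (q * k + r) // k == q + r // k   # un-nest the div
  assert (q * k + r) % k == r % k         # congruence mod k
```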