Commit Graph

294 Commits

Author SHA1 Message Date
chenyu
74db65cf72 update mlperf bert LOGMLPERF (#13065) 2025-11-02 15:26:37 -05:00
chenyu
70dd297a05 BS=96 for bert (#12675)
96 trains fine now
2025-10-14 09:07:43 -04:00
chenyu
77b5e6774e fix bert training config (#12647)
FREE_INTERMEDIATE=0 REWRITE_STACK_LIMIT=500000
2025-10-13 15:03:47 -04:00
chenyu
0f776c6e46 examples/mlperf/training_submission_v6.0 (#12644)
copied from v5.1
2025-10-13 09:58:25 -04:00
chenyu
28edea5d67 delete FUSE_CONV_BW (#12527) 2025-10-08 10:41:38 -04:00
chenyu
e701106a64 remove FUSE_ARANGE (#12511)
it was the default already
2025-10-08 04:54:07 -04:00
hooved
69857d0ab0 Stable Diffusion mlperf training (#11304)
* entrypoint for sd mlperf train development

* match sd-v2 mlperf reference unet

* implement dataloader from mlperf ref

* update dataloader reference

* implement LambdaLR scheduler from mlperf ref

* match tokenizer from mlperf reference

* sample latent

* add noise to latent

* complete training epoch

* run full training step

* jit training loop

* replicate mlperf ref. losses over 11 train steps

* save tinygrad loss checkpoints properly

* match out.2.bias.grad to reference

* match weights to ref after 1 step

* compare out.2.bias to ref over three train steps

* implement attn_mask; cleanup closeness testing

* correct mse loss

* update dev_run / dependencies

* setup validation config/checkpointing

* implement validation sampling

* test closeness of eval denoise step to mlperf ref

* test closeness of decoder to mlperf ref

* confirm inception matches mlperf ref

* resize w/ bicubic interpolation, test closeness

* confirm closeness of clip preprocess to mlperf ref

* confirm clip score matches mlperf ref

* confirm fid/clip scores match mlperf ref

* cleanup

* cleanup

* zero-init some unet params as in mlperf reference

* revert jit change

* uncomment dependencies

* move to tinybox red

* implement GradScaler from torch but jittable

* simplify lr_scheduler, ensure jittability

* instantiate GradScaler

* only check if grads are finite with fp16

* implement fp16 training loop

* refactor UNet: norm, gelu, mixed precision

* refactor clip_tokenizer to enable versioning

* make fp16 attention closer to torch

* remove comparisons to torch fp16 attention

* add globvars.py for reference

* confirm closeness of fp16 unet forward to mlperf

* test norm closeness to torch with precast

* remeasure e2e with master attention

* more detailed softmax upcast comparison to torch

* parameterize softmax upcast in attention and unet

* use fp32 weights with autocast to fp16

* cleanup

* add data/checkpoint download script

* debug kernel timeout on AMD

* fix finite grads check; start multigpu

* pass numpy arrays from dataloader

* include text encoder in jit train step

* use int32 for tokens instead of int64

* prevent multi bug in reshape within clip

* corealize more, del refs before

* add more logging and wandb

* use erf gelu in clip encoder

* minor changes to train step and logging

* save checkpoints for eval or resuming

* add eval-only logic to training script

* multigpu eval

* remove PARALLEL=0

* cleanup

* pad eval batches of size < EVAL_BS

* workaround silent multigpu bug in jit

* cleanup

* tokenize captions

* verify correctness of multigpu eval

* cleanup

* verify correctness of grads in train step

* verify correctness of training (20 steps)

* don't shard in the training jit

* training settings

* minor cleanup

* overfit train w/ eval on 6 samples

* offload to enable combined train and eval

* download to raid; use local rclone

* misc changes for mi300x / logging

* refactor eval for larger BS, verify correctness

* cleanup

* ckpt resuming and remove eval cats

* eval BEAM config on mi300x and red

* resume eval after crash

* confirm eval correctness (one iteration, 6 samples)

* verify eval correctness at full scale

* cleanup correctness testing

* training correctness (20 steps, BS=248 uniform)

* cleanup

* remove eval cache at end of run

* switch f16 for bf16, del grad scaler

* confirm bf16 training correctness

* timestamps, new jits

* merge jits in training

* realize loss/lr on CPU

* training correctness

* post-bf16 train/eval

* implement grad_acc with timing/logging

* beam offline; debug gradacc; use float32

* fix gradacc in jit, correctness test

* prepare f32 BS=512 gradacc=4 run

* workaround jit problem in diffusion eval

* scale lr by BS

* revert gradacc, prepare bf16 BS=336 lr*=BS train

* make checkpointing faster

* resume bf16 BS=336 base_lr=1.25e-7 run

* jit ckpt at beginning

* don't alloc more gpu mem in ckpt

* cleanup

* move script to mi300x dir

* cleanup

* cleanup unneeded files

* revert beam search to master

* minor changes

* fix regression: realize before assign in eval

* cleanup mlperf SD data/ckpt downloads

* workaround BEAM failure

* workaround bug in Tensor.stack

* minor changes

* revert gradscaler

* cleanup

* cleanup/validate dataloader

* ensure checksum of laion data

* simplify config

* load training state to jitted bufs

* simplify lr scheduler

* simplify train script

* cleanup comments

* refactor stable diffusion/unet init

* more refactoring of stable diffusion init

* fix import errors in tests

* refactor: separate train/eval

* fix import errors

* eval checkpoints in reverse chron. order

* save/load cycle in sd init

* refactor and verify eval

* verify training correctness

* prepare repro train run

* cleanup

* integrate beam retry, train, eval

* simplify wandb

* kill orphaned processes

* better logging

* train to 10 ckpts instead of 7

* remove optimizer/scheduler checkpointing/resume

* cleanup

* BEAM=2 7 ckpts

* add test to compare with torch softmax in amp

* cleanup

* stop eval early if checkpoint converged

* add test for lr scheduler

* add proper test method

* add test for training

* use venv name that is ignored by .gitignore

* linting

* add simple f32 softmax fxn

* revert change to scaled_dot_product_attention

* refactor gelu_erf init

* simplify mixed precision in unet

* add norm autocasting to fp32

* rm extra test

* test eval with NULL backend

* fix venv name

* simplify norm autocast

* use temp dir for training test

* actually add eval test

* remove parallel env variable from tests

* update clip with tests

* reorg init functions

* use np for testing

* remove unused var

* factor out GPUS

* add sd model init tests

* more unet tests

* match master

* rerun CI due to linux (remote) hang

* explain UNET_CKPTDIR

* rerun CI due to linux (remote) timeout

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
2025-10-05 07:56:05 -04:00
hooved
1e8945a28c Training loop for Stable Diffusion mlperf (#12315)
* add diff

* fix edit error

* match master

* point reference to specific commit

* simplify wandb logging

* remove lr test, dehardcode device

* increase stack size limit
2025-10-03 02:45:38 -04:00
hooved
5d9035f5a6 Eval for Stable Diffusion mlperf (#12316)
* add diff

* rerun ci

* refactor beam workaround, add test

* fix conflict

* linting
2025-10-02 02:35:38 -04:00
hooved
0f804c9a83 Stable Diffusion model init for mlperf (#12314)
* include clip pr diff

* updated unet and sd init

* dehardcode default device

* revert beam hang workaround

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
2025-10-02 02:28:41 -04:00
hooved
969a1b35ca LR scheduler for Stable Diffusion mlperf training (#12201)
* add lr scheduler for stable diffusion training

* add lr scheduler test

* rerun ci

* rerun CI

* use np for testing

* move test to CI path

* remove unneeded copy
2025-09-30 21:21:08 -04:00
hooved
c2689c505e Clip model updates for Stable Diffusion mlperf training (#12313)
* stable diffusion mlperf clip changes

* add clip tests

* set gelu as attribute

* add more tests

* factor out GPUS

* rerun CI

* add imports to if blocks

* remove unneeded axis

* add clip tests to CI

* move clip tests

* add deps, disable max buf size
2025-09-29 21:50:14 -04:00
hooved
3a9db08b49 download data and ckpts for sd train/eval (#12170) 2025-09-15 00:31:45 -04:00
Sieds Lykles
5b73076e48 assert benchmark times (#12042)
* assert jitted times in openpilot

* better error

* better error

* add ASSERT_MIN_STEP_TIME to more models

* t is step_times

* update benchmark times

* update times
2025-09-09 23:40:02 +02:00
wozeparrot
d16cc6c012 feat: resume ckpt (#11970) 2025-09-02 15:47:48 -07:00
wozeparrot
7c21271a5f feat: end_lr envvar (#11953) 2025-09-01 14:53:07 -07:00
wozeparrot
7e68045fb2 feat: small llama3 training (#11829) 2025-08-31 13:41:47 -07:00
wozeparrot
b979162c5d llama3 eval train (#11706) 2025-08-20 19:56:35 -04:00
chenyu
dbd3b67657 clamp GRAD_CLIP_NORM in llama (#11761) 2025-08-20 19:55:50 -04:00
chenyu
e9d0027591 llama MP realize weight after shard (#11672)
* llama MP realize weight after shard

prevents memory spike on device 0

* empty weight for FAKEDATA
2025-08-14 16:17:46 -04:00
chenyu
ef17af85c6 remove .float call in llama logit (#11598)
* remove .float call in llama logit

* bfloat item
2025-08-10 00:02:18 -04:00
chenyu
45baec1aab model parallel llama (#11588)
MP=8 GRADIENT_ACC_STEPS=3 BS=1 DEFAULT_FLOAT=bfloat16 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=70B SEQLEN=512 PYTHONPATH=. MODEL=llama3 python3 examples/mlperf/model_train.py
2025-08-09 16:54:27 -04:00
chenyu
702e38dc19 remove FUSE_ARANGE_UINT (#11567)
also add IGNORE_OOB=1 to bert runs. lowered BS on tinybox to 90 since 96 oom during eval without reset
2025-08-07 16:49:06 -04:00
wozeparrot
7ae4335127 feat: generate blend index (#11566) 2025-08-07 14:20:28 -04:00
wozeparrot
2d5bdc939d faster llama3 dataloader (#11540) 2025-08-06 18:25:57 -04:00
chenyu
f7965f85aa Revert "feat: faster index building (#11462)" (#11478)
This reverts commit 3a4deb08d2.
2025-08-02 12:50:48 -04:00
wozeparrot
3a4deb08d2 feat: faster index building (#11462)
* feat: faster index building

* feat: correct training samples
2025-08-02 11:50:18 -04:00
chenyu
9e8e6b45ab grad acc train llama (#11467)
* grad acc train llama

* log step time
2025-08-01 15:54:50 -04:00
chenyu
7ad7329257 data parallel train llama (#11466) 2025-08-01 12:13:51 -04:00
George Hotz
8ff03806e8 add llama layers (#11460)
* add llama layers

* add contig bw for speed
2025-07-31 16:28:04 -07:00
wozeparrot
6252f7770e feat: fake data (#11447) 2025-07-30 17:18:20 -07:00
chenyu
e300451f3a update llama3 (#11446)
`LR=1e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 FUSE_ARANGE=1 JITBEAM=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=512 PYTHONPATH=. AMD=1 AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py` trained to 7
2025-07-30 19:34:21 -04:00
wozeparrot
5fb975351a feat: flag for training on val (#11441) 2025-07-30 14:29:45 -07:00
wozeparrot
825b6a2505 feat: llama3 dataloader (#11340) 2025-07-30 13:27:55 -07:00
chenyu
c14c9a8eff llama3 grad clip (#11003) 2025-06-27 19:14:12 -04:00
chenyu
f2548afeb5 bert grad clipping start with const 0 (#11008)
saved the init kernels
2025-06-27 18:02:23 -04:00
chenyu
6ab5a5cb6c llama3 mlperf train (#10983)
work in progress. now it can overfit small examples and vram roughly matches
2025-06-26 20:24:27 -04:00
chenyu
8751d47985 CosineAnnealingLRWithWarmup (#10981) 2025-06-25 17:45:21 -04:00
chenyu
efad567ebd ruff check whole examples/mlperf/ (#10979) 2025-06-25 12:57:48 -04:00
chenyu
0480139def log_perplexity metrics (#10912) 2025-06-21 10:44:47 -04:00
chenyu
62a540066e remove DEBUG=2 in mi300x bert setup (#10886)
seems fine now, not sure what the issue was
2025-06-19 13:28:53 -04:00
chenyu
f377cc19cd use AM for bert (#10882)
have triained 3 runs and all seem fine
2025-06-19 09:48:54 -04:00
chenyu
b70c7d3631 bert grad accumulation (#10863)
* bert grad accumulation

* realize grad
2025-06-18 12:17:07 -04:00
chenyu
075a74cf25 add global_batch_size to mlperf bert (#10852)
global_batch_size = grad_acc_steps * batch_size. no-op change to prep grad acc for bert
2025-06-17 17:54:15 -04:00
chenyu
81e296d7b8 remove Tensor.test() in retinanet (#10770)
test was removed
2025-06-10 22:14:57 -04:00
George Hotz
32e9949052 rename lazydata to uop (#10698) 2025-06-08 08:42:22 -07:00
chenyu
4ab3391e6f set -o pipefail for mlperf run_and_time (#10577)
also run the 5.1 script in ci cron job
2025-05-30 16:36:44 -04:00
chenyu
baf482d314 copy mlperf stuff to 5.1 (#10576)
5.0 is finalized, new changes go to 5.1
2025-05-30 16:12:39 -04:00
George Hotz
b3b43a82c4 remove Tensor.no_grad, it's meaningless now [pr] (#10556) 2025-05-28 22:20:02 -07:00
chenyu
74cf5dbd9e mlperf system updates (#10550)
standardized processor and accelerator names
2025-05-28 16:15:46 -04:00