Commit Graph

170 Commits

wozeparrot
4b5d3bda1f llama3: data seed (#14681) 2026-02-11 19:04:40 -08:00
wozeparrot
a60220bed9 llama3: move dl to numpy & jit more (#14677)
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
2026-02-10 18:16:40 -08:00
wozeparrot
4845e42135 llama3 gradacc fixes (#14414) 2026-01-28 19:12:39 -08:00
nimlgen
aec1ae0de1 llama: set manual_seed (#14409) 2026-01-28 14:40:00 -08:00
George Hotz
0c6b3f50aa add marker to llama training (#14401) 2026-01-28 22:44:28 +08:00
wozeparrot
e496547720 llama3 gradacc (#14291) 2026-01-27 19:48:10 -08:00
wozeparrot
963c59ebdb fix: pull fixes from gradacc branch (#14296) 2026-01-22 23:07:54 -08:00
wozeparrot
c1d14ea832 llama8b train fixes (#14264) 2026-01-20 20:34:47 -08:00
b1tg
0fbc551622 train bert with fp8 (#13874)
* fp8 train

* clean

* lint

* test fix from #13439

* skip first/last layer

* rm __init__, restore unroll <=32 check

* tests

* clean test, remove unused

* multi-gpu test, clean quantize_to_fp8

* remove bert contiguous

* run script

* test: better check

* run script search

* add seed in bert data shuffle

* move script to mi350x folder

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
2026-01-09 09:21:59 -05:00
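A minimal sketch of the quantize-to-fp8 idea referenced in the commit above (per-tensor scaling into the e4m3 range, with first/last layers kept in higher precision). This is not the repo's `quantize_to_fp8`; it only simulates the scaling math in numpy, since the real fp8 cast happens on the accelerator.

```
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3

def quantize_to_fp8_sketch(x: np.ndarray):
  # hypothetical sketch: per-tensor scale so the largest magnitude lands at the fp8 range edge
  scale = float(np.abs(x).max()) / FP8_E4M3_MAX or 1.0   # guard against an all-zero tensor
  x_scaled = np.clip(x / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)  # on device this would be cast to fp8
  return x_scaled, scale  # keep the scale to dequantize: x ~= x_scaled * scale
```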
b1tg
241f0402b4 add seed in bert data shuffle (#14054) 2026-01-07 10:02:05 -05:00
chenyu
da1cb6a9ec update llama dataloader (#13825)
separate creating the dataset from iterating over it, so eval data is not recreated for each eval
2025-12-24 17:42:08 -05:00
chenyu
903753c60c llama wandb logging (#13822) 2025-12-24 10:24:59 -05:00
chenyu
27d899ce97 TRAIN=0 to only eval llama (#13804) 2025-12-22 11:55:46 -05:00
chenyu
39d962106f update llama logging (#13803)
```
REWRITE_STACK_LIMIT=1000000 SMALL=1 BASEDIR=/raid/datasets/c4-8b SAMPLES=1000 BS=8 DP=8 DEFAULT_FLOAT=bfloat16 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=8B SEQLEN=1024 PYTHONPATH=. MODEL=llama3 python3 examples/mlperf/model_train.py

    1 93.44 s run, 11.8750 loss, 0.000000000001 LR, 642.43 GB used,  19644.30 GFLOPS
    2 101.78 s run, 11.8750 loss, 0.000000000001 LR, 1454.57 GB used,  17039.35 GFLOPS
    3 7.34 s run, 11.8750 loss, 0.000000000002 LR, 1454.57 GB used, 236258.78 GFLOPS
    4 4.32 s run, 11.8750 loss, 0.000000000002 LR, 1454.57 GB used, 401488.40 GFLOPS
    5 4.36 s run, 11.9375 loss, 0.000000000003 LR, 1454.57 GB used, 398116.13 GFLOPS
    6 4.32 s run, 11.8750 loss, 0.000000000003 LR, 1454.57 GB used, 401878.60 GFLOPS
    7 4.34 s run, 11.8750 loss, 0.000000000004 LR, 1454.57 GB used, 399822.57 GFLOPS
    8 4.35 s run, 11.8750 loss, 0.000000000004 LR, 1454.57 GB used, 398512.24 GFLOPS
    9 4.36 s run, 11.8750 loss, 0.000000000005 LR, 1454.57 GB used, 397832.61 GFLOPS
   10 4.40 s run, 11.8750 loss, 0.000000000005 LR, 1454.57 GB used, 394520.83 GFLOPS
```
2025-12-22 11:28:29 -05:00
chenyu
e428fbfab6 verify dtype of llama model params (#13719) 2025-12-16 12:32:02 -05:00
chenyu
6cad622f59 don't FREE_INTERMEDIATE in bert (#13684)
it consistently hangs green HCQ after an hour of training
2025-12-14 14:27:42 -05:00
chenyu
01e9ad0d52 clean up bert next_data (#13650)
the train iterator was designed to never stop, for both real and fake data
2025-12-11 22:56:28 -05:00
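A minimal sketch of an iterator that never stops, as described in the commit above; `train_iter`, `batches`, and `fake_batch` are illustrative names, not the repo's implementation.

```
import itertools

def train_iter(batches=None, fake_batch=None):
  # never stops, for both real and fake data: the training loop decides when to end
  if batches is None:
    yield from itertools.repeat(fake_batch)   # FAKEDATA-style path: repeat one synthetic batch
  else:
    yield from itertools.cycle(batches)       # real data, cycled indefinitely
```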
chenyu
5034c6fb37 reenable FREE_INTERMEDIATE for bert (#13639)
* reenable FREE_INTERMEDIATE for bert

* comment
2025-12-10 12:08:09 -05:00
chenyu
2471b49e45 minor bert / llama change from grad acc branch (#13622)
* minor bert / llama change from grad acc branch

* revert those
2025-12-08 16:04:14 -05:00
chenyu
b981b6f89e remove old llama grad_acc (#13611)
* remove old llama grad_acc

* GRADIENT_ACC_STEPS=1
2025-12-07 13:03:47 -05:00
chenyu
4562f217e1 more bert updates (#13597)
prep split jit
also lower BS to 72
2025-12-06 08:32:43 -05:00
chenyu
cb4c6324ef revert bert grad accumulation (#13596)
prep for the new split jit style
2025-12-05 17:30:08 -05:00
chenyu
77b5e6774e fix bert training config (#12647)
`FREE_INTERMEDIATE=0 REWRITE_STACK_LIMIT=500000`
2025-10-13 15:03:47 -04:00
chenyu
28edea5d67 delete FUSE_CONV_BW (#12527) 2025-10-08 10:41:38 -04:00
chenyu
e701106a64 remove FUSE_ARANGE (#12511)
it was the default already
2025-10-08 04:54:07 -04:00
hooved
1e8945a28c Training loop for Stable Diffusion mlperf (#12315)
* add diff

* fix edit error

* match master

* point reference to specific commit

* simplify wandb logging

* remove lr test, dehardcode device

* increase stack size limit
2025-10-03 02:45:38 -04:00
Sieds Lykles
5b73076e48 assert benchmark times (#12042)
* assert jitted times in openpilot

* better error

* better error

* add ASSERT_MIN_STEP_TIME to more models

* t is step_times

* update benchmark times

* update times
2025-09-09 23:40:02 +02:00
wozeparrot
d16cc6c012 feat: resume ckpt (#11970) 2025-09-02 15:47:48 -07:00
wozeparrot
7c21271a5f feat: end_lr envvar (#11953) 2025-09-01 14:53:07 -07:00
wozeparrot
7e68045fb2 feat: small llama3 training (#11829) 2025-08-31 13:41:47 -07:00
wozeparrot
b979162c5d llama3 eval train (#11706) 2025-08-20 19:56:35 -04:00
chenyu
dbd3b67657 clamp GRAD_CLIP_NORM in llama (#11761) 2025-08-20 19:55:50 -04:00
chenyu
e9d0027591 llama MP realize weight after shard (#11672)
* llama MP realize weight after shard

prevents memory spike on device 0

* empty weight for FAKEDATA
2025-08-14 16:17:46 -04:00
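A minimal sketch of "realize weight after shard" from the commit above, assuming tinygrad's `Tensor.shard`/`realize` API; the device list and tensor shape are illustrative. Realizing only after the shard lets each device materialize just its slice, instead of materializing the full weight on device 0 first (the memory spike the commit mentions).

```
from tinygrad import Tensor

GPUS = tuple(f"GPU:{i}" for i in range(8))   # illustrative device names

w = Tensor.empty(8192, 8192)                 # weight is still lazy here
w = w.shard(GPUS, axis=0).realize()          # materialize per-device shards directly
```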
chenyu
ef17af85c6 remove .float call in llama logit (#11598)
* remove .float call in llama logit

* bfloat item
2025-08-10 00:02:18 -04:00
chenyu
45baec1aab model parallel llama (#11588)
`MP=8 GRADIENT_ACC_STEPS=3 BS=1 DEFAULT_FLOAT=bfloat16 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=70B SEQLEN=512 PYTHONPATH=. MODEL=llama3 python3 examples/mlperf/model_train.py`
2025-08-09 16:54:27 -04:00
wozeparrot
2d5bdc939d faster llama3 dataloader (#11540) 2025-08-06 18:25:57 -04:00
chenyu
f7965f85aa Revert "feat: faster index building (#11462)" (#11478)
This reverts commit 3a4deb08d2.
2025-08-02 12:50:48 -04:00
wozeparrot
3a4deb08d2 feat: faster index building (#11462)
* feat: faster index building

* feat: correct training samples
2025-08-02 11:50:18 -04:00
chenyu
9e8e6b45ab grad acc train llama (#11467)
* grad acc train llama

* log step time
2025-08-01 15:54:50 -04:00
chenyu
7ad7329257 data parallel train llama (#11466) 2025-08-01 12:13:51 -04:00
George Hotz
8ff03806e8 add llama layers (#11460)
* add llama layers

* add contig bw for speed
2025-07-31 16:28:04 -07:00
wozeparrot
6252f7770e feat: fake data (#11447) 2025-07-30 17:18:20 -07:00
chenyu
e300451f3a update llama3 (#11446)
`LR=1e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 FUSE_ARANGE=1 JITBEAM=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=512 PYTHONPATH=. AMD=1 AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py` trained to 7
2025-07-30 19:34:21 -04:00
wozeparrot
5fb975351a feat: flag for training on val (#11441) 2025-07-30 14:29:45 -07:00
wozeparrot
825b6a2505 feat: llama3 dataloader (#11340) 2025-07-30 13:27:55 -07:00
chenyu
c14c9a8eff llama3 grad clip (#11003) 2025-06-27 19:14:12 -04:00
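A minimal sketch of gradient clipping by global norm, in the spirit of the grad-clip commits above (llama3 grad clip, clamp GRAD_CLIP_NORM); the function name, parameter list, and `.grad`/`.item()` Tensor API are assumptions, not the repo's exact implementation.

```
def clip_grad_norm_(params, max_norm=1.0, eps=1e-6):
  # global L2 norm over all parameter gradients
  total_norm = sum(float((p.grad ** 2).sum().item()) for p in params) ** 0.5
  # scale <= 1 leaves gradients unchanged when already under max_norm
  scale = min(1.0, max_norm / (total_norm + eps))
  if scale < 1.0:
    for p in params:
      p.grad = p.grad * scale
  return total_norm
```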
chenyu
f2548afeb5 bert grad clipping start with const 0 (#11008)
saved the init kernels
2025-06-27 18:02:23 -04:00
chenyu
6ab5a5cb6c llama3 mlperf train (#10983)
work in progress: it can now overfit small examples, and VRAM usage roughly matches
2025-06-26 20:24:27 -04:00
chenyu
b70c7d3631 bert grad accumulation (#10863)
* bert grad accumulation

* realize grad
2025-06-18 12:17:07 -04:00
chenyu
075a74cf25 add global_batch_size to mlperf bert (#10852)
`global_batch_size = grad_acc_steps * batch_size`. No-op change to prep grad acc for bert
2025-06-17 17:54:15 -04:00