chenyu
2471b49e45
minor bert / llama change from grad acc branch ( #13622 )
* minor bert / llama change from grad acc branch
* revert those
2025-12-08 16:04:14 -05:00
chenyu
b981b6f89e
remove old llama grad_acc ( #13611 )
* remove old llama grad_acc
* GRADIENT_ACC_STEPS=1
2025-12-07 13:03:47 -05:00
chenyu
4562f217e1
more bert updates ( #13597 )
prep split jit
also lower BS to 72
2025-12-06 08:32:43 -05:00
chenyu
cb4c6324ef
revert bert grad accumulation ( #13596 )
prep for the new split jit style
2025-12-05 17:30:08 -05:00
chenyu
77b5e6774e
fix bert training config ( #12647 )
FREE_INTERMEDIATE=0 REWRITE_STACK_LIMIT=500000
2025-10-13 15:03:47 -04:00
chenyu
28edea5d67
delete FUSE_CONV_BW ( #12527 )
2025-10-08 10:41:38 -04:00
chenyu
e701106a64
remove FUSE_ARANGE ( #12511 )
it was the default already
2025-10-08 04:54:07 -04:00
hooved
1e8945a28c
Training loop for Stable Diffusion mlperf ( #12315 )
* add diff
* fix edit error
* match master
* point reference to specific commit
* simplify wandb logging
* remove lr test, dehardcode device
* increase stack size limit
2025-10-03 02:45:38 -04:00
Sieds Lykles
5b73076e48
assert benchmark times ( #12042 )
* assert jitted times in openpilot
* better error
* better error
* add ASSERT_MIN_STEP_TIME to more models
* t is step_times
* update benchmark times
* update times
2025-09-09 23:40:02 +02:00
wozeparrot
d16cc6c012
feat: resume ckpt ( #11970 )
2025-09-02 15:47:48 -07:00
wozeparrot
7c21271a5f
feat: end_lr envvar ( #11953 )
2025-09-01 14:53:07 -07:00
wozeparrot
7e68045fb2
feat: small llama3 training ( #11829 )
2025-08-31 13:41:47 -07:00
wozeparrot
b979162c5d
llama3 eval train ( #11706 )
2025-08-20 19:56:35 -04:00
chenyu
dbd3b67657
clamp GRAD_CLIP_NORM in llama ( #11761 )
2025-08-20 19:55:50 -04:00
chenyu
e9d0027591
llama MP realize weight after shard ( #11672 )
* llama MP realize weight after shard
prevents a memory spike on device 0
* empty weight for FAKEDATA
2025-08-14 16:17:46 -04:00
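A minimal sketch of the shard-then-realize pattern this commit describes, assuming tinygrad's `Tensor.shard`/`realize` API; the device list, shapes, and function name are illustrative, not the actual model_train.py code:

```python
# illustrative sketch: realize weights only after sharding so the full tensor is
# never materialized on a single device (avoids the memory spike on device 0)
from tinygrad import Tensor, Device

GPUS = tuple(f"{Device.DEFAULT}:{i}" for i in range(8))  # hypothetical 8-device setup

def load_and_shard(shape, axis):
  w = Tensor.empty(*shape)       # stand-in for loading a real weight
  w = w.shard(GPUS, axis=axis)   # split across devices (model parallel)
  return w.realize()             # realize after shard, not before
```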
chenyu
ef17af85c6
remove .float call in llama logit ( #11598 )
* remove .float call in llama logit
* bfloat item
2025-08-10 00:02:18 -04:00
chenyu
45baec1aab
model parallel llama ( #11588 )
MP=8 GRADIENT_ACC_STEPS=3 BS=1 DEFAULT_FLOAT=bfloat16 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=70B SEQLEN=512 PYTHONPATH=. MODEL=llama3 python3 examples/mlperf/model_train.py
2025-08-09 16:54:27 -04:00
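As a reading of that command's flags: MP=8 shards the 70B model across eight devices (model parallel), GRADIENT_ACC_STEPS=3 with BS=1 accumulates three micro-batches of one SEQLEN=512 sequence before each optimizer step, and DEFAULT_FLOAT / OPTIM_DTYPE presumably keep the compute and optimizer state in bfloat16.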
wozeparrot
2d5bdc939d
faster llama3 dataloader ( #11540 )
2025-08-06 18:25:57 -04:00
chenyu
f7965f85aa
Revert "feat: faster index building ( #11462 )" ( #11478 )
This reverts commit 3a4deb08d2.
2025-08-02 12:50:48 -04:00
wozeparrot
3a4deb08d2
feat: faster index building ( #11462 )
* feat: faster index building
* feat: correct training samples
2025-08-02 11:50:18 -04:00
chenyu
9e8e6b45ab
grad acc train llama ( #11467 )
* grad acc train llama
* log step time
2025-08-01 15:54:50 -04:00
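A minimal sketch of the gradient-accumulation pattern these grad-acc commits refer to, assuming tinygrad's optimizer API; the names, the loss function, and the `getenv` default are illustrative, not the actual training loop:

```python
# illustrative sketch of gradient accumulation (not the actual mlperf training loop)
from tinygrad import Tensor
from tinygrad.helpers import getenv

GRADIENT_ACC_STEPS = getenv("GRADIENT_ACC_STEPS", 1)

def train_step(model, opt, data_iter, loss_fn):
  Tensor.training = True
  opt.zero_grad()
  for _ in range(GRADIENT_ACC_STEPS):
    x, y = next(data_iter)
    loss = loss_fn(model(x), y) / GRADIENT_ACC_STEPS  # scale so the sum matches one large batch
    loss.backward()                                   # grads accumulate across micro-batches
  opt.step()
  return loss.realize()
```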
chenyu
7ad7329257
data parallel train llama ( #11466 )
2025-08-01 12:13:51 -04:00
George Hotz
8ff03806e8
add llama layers ( #11460 )
* add llama layers
* add contig bw for speed
2025-07-31 16:28:04 -07:00
wozeparrot
6252f7770e
feat: fake data ( #11447 )
2025-07-30 17:18:20 -07:00
chenyu
e300451f3a
update llama3 ( #11446 )
`LR=1e-4 TRAIN_ON_VAL=1 DEFAULT_FLOAT=bfloat16 FUSE_ARANGE=1 JITBEAM=2 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B WARMUP_STEPS=36 DECAY_STEPS=360 SEQLEN=512 PYTHONPATH=. AMD=1 AMD_LLVM=0 MODEL=llama3 python3 examples/mlperf/model_train.py` trained to 7
2025-07-30 19:34:21 -04:00
wozeparrot
5fb975351a
feat: flag for training on val ( #11441 )
2025-07-30 14:29:45 -07:00
wozeparrot
825b6a2505
feat: llama3 dataloader ( #11340 )
2025-07-30 13:27:55 -07:00
chenyu
c14c9a8eff
llama3 grad clip ( #11003 )
2025-06-27 19:14:12 -04:00
chenyu
f2548afeb5
bert grad clipping start with const 0 ( #11008 )
saved the init kernels
2025-06-27 18:02:23 -04:00
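A minimal sketch of global-norm gradient clipping with the norm accumulator starting from a constant 0, in the spirit of the bert and llama grad-clip commits above; the function name and max-norm handling are illustrative:

```python
# illustrative sketch: global-norm gradient clipping, norm accumulator starts at const 0
from tinygrad import Tensor

def clip_grad_norm(params, max_norm=1.0):
  norm_sq = Tensor(0.0)                       # const-0 start: no separate init kernel
  for p in params:
    norm_sq = norm_sq + p.grad.square().sum()
  global_norm = norm_sq.sqrt()
  scale = (max_norm / (global_norm + 1e-6)).minimum(1.0)  # only ever scale down
  for p in params:
    p.grad = p.grad * scale
  return global_norm
```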
chenyu
6ab5a5cb6c
llama3 mlperf train ( #10983 )
work in progress. it can now overfit small examples, and VRAM usage roughly matches
2025-06-26 20:24:27 -04:00
chenyu
b70c7d3631
bert grad accumulation ( #10863 )
* bert grad accumulation
* realize grad
2025-06-18 12:17:07 -04:00
chenyu
075a74cf25
add global_batch_size to mlperf bert ( #10852 )
global_batch_size = grad_acc_steps * batch_size. no-op change to prep grad acc for bert
2025-06-17 17:54:15 -04:00
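For example (illustrative numbers only): with grad_acc_steps=3 and batch_size=96, global_batch_size = 3 * 96 = 288 samples per optimizer step.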
chenyu
81e296d7b8
remove Tensor.test() in retinanet ( #10770 )
test was removed
2025-06-10 22:14:57 -04:00
George Hotz
b3b43a82c4
remove Tensor.no_grad, it's meaningless now [pr] ( #10556 )
2025-05-28 22:20:02 -07:00
chenyu
dc6309242d
WallTimeEvent for mlperf ci ( #10506 )
2025-05-24 10:56:03 -04:00
chenyu
485e80da69
run_and_time for resnet ci ( #10405 )
2025-05-18 23:39:57 -04:00
wozeparrot
1ed04f993b
move benchmark stat tracking to influxdb ( #10185 )
2025-05-15 16:14:56 -07:00
chenyu
610ee79b22
cherry pick mlperf5.0 branch to master ( #10089 )
2025-04-28 15:36:56 -04:00
chenyu
74c6cf8be3
lint mlperf model_train ( #10038 )
2025-04-24 16:19:44 -04:00
chenyu
a25abf55e3
retinanet only call postprocess_detections with RUNMLPERF ( #10017 )
during setup we only need to compile `_eval_step().numpy()`
2025-04-23 20:45:38 -04:00
chenyu
a3f938dbee
remove retinanet INITMLPERF from beam script ( #10011 )
it only controls logging; whether real data is loaded is controlled solely by RUNMLPERF
2025-04-23 14:32:54 -04:00
Francis Lata
5542aeb0e4
RetinaNet MLPerf flag updates ( #10009 )
* add RUNMLPERF and update INITMLPERF usage
* update scripts to use RUNMLPERF
2025-04-23 13:00:34 -04:00
George Hotz
de0504276b
pop 0 is slow [pr] ( #10007 )
2025-04-23 17:00:59 +01:00
chenyu
c39128133c
retinanet green scripts ( #9996 )
also removed the realize in data_get and used empty tensors for fake data. slightly bigger lr. https://wandb.ai/chenyuxyz/MLPerf-RetinaNet/runs/8skid0e8?nw=nwuserchenyuxyz
2025-04-23 08:28:03 -04:00
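A minimal sketch of the "empty tensors for fake data" idea from that commit, assuming tinygrad's `Tensor.empty` and a sharded input; shapes, dtypes, and names are illustrative:

```python
# illustrative sketch: fake batches from uninitialized tensors, no realize in the fetch path
from tinygrad import Tensor, Device, dtypes

GPUS = tuple(f"{Device.DEFAULT}:{i}" for i in range(6))  # hypothetical device list
BS = 6

def fake_batch():
  # uninitialized data is enough for compiling / BEAM-searching the step
  return Tensor.empty(BS, 3, 800, 800, dtype=dtypes.float16).shard(GPUS, axis=0)
```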
chenyu
fb89d9a584
retinanet eval combine output on GPUS[0] ( #9966 )
eval 35 sec -> 20 sec. it was spending 13 seconds assembling the output tensor on the CPU backend. GPUS[0] seems to have enough memory; otherwise we can lower EVAL_BS
2025-04-22 07:43:51 -04:00
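A minimal sketch of the device-0 gather that commit describes, assuming the eval output is sharded across GPUS and that `Tensor.to` can collect it onto a single device; names are illustrative:

```python
# illustrative sketch: assemble a sharded eval output on GPUS[0] instead of the CPU backend
from tinygrad import Tensor, Device

GPUS = tuple(f"{Device.DEFAULT}:{i}" for i in range(6))  # hypothetical device list

def gather_eval_output(out: Tensor):
  out = out.to(GPUS[0]).realize()  # combine on one device while still on GPU
  return out.numpy()               # host copy happens once, from GPUS[0]
```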
chenyu
5294c32279
dev scripts for retinanet ( #9968 )
also BASE_DIR -> BASEDIR for consistency, and move wandb up a bit for more accurate timing
2025-04-21 17:54:56 -04:00
Francis Lata
defa1e77f6
get the proper dataset count ( #9962 )
2025-04-21 12:11:37 -04:00
Francis Lata
d7e247f329
RetinaNet INITMLPERF support ( #9950 )
* fixes to make fake data work
* fix eval beam
* fix merge issue
2025-04-21 10:32:05 -04:00
Francis Lata
ea4cb2c715
small cleanups ( #9947 )
2025-04-20 20:33:20 -04:00
chenyu
e8024c8281
faster bert global_norm ( #9901 )
tinyamd 2% faster. also updated beam params, which is another 2-3% faster.
updated the mlperf doc and steps too
2025-04-15 18:24:44 -04:00