chenyu
7dadbf3697
insert float() in bert acc (#9726)
sum of bool by default uses default_float for the accumulator, so without an explicit float() cast it can overflow with a large BS and default_float=HALF.
fixed clsf_accuracy to not be inf in the mi300x bert run
2025-04-03 05:44:09 -04:00
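A minimal sketch of the overflow the commit above describes (hypothetical count and settings, not the actual bert script; assumes tinygrad's settable dtypes.default_float):

```python
from tinygrad import Tensor, dtypes

dtypes.default_float = dtypes.half                 # mimic a run with default_float=HALF
correct = Tensor.ones(70_000, dtype=dtypes.bool)   # pretend 70k correct predictions

# bool.sum() accumulates in default_float; half overflows past 65504, giving inf
print(correct.sum().item())           # inf under HALF
print(correct.float().sum().item())   # 70000.0: cast to float32 before summing
```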
chenyu
3ae66e59a3
least_upper_float is at least default_float (#9303)
* least_upper_float is at least default_float
en route to div rounding mode. the dtype of true int division will change from int32 to default_float, which matches torch too.
* fix bert acc
2025-02-28 10:41:56 -05:00
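A small illustration of the promotion rule described in the commit above, assuming current tinygrad defaults (not code from the PR):

```python
from tinygrad import Tensor, dtypes

a = Tensor([7, 8, 9], dtype=dtypes.int32)
b = Tensor([2, 2, 2], dtype=dtypes.int32)

# true division of two int32 tensors promotes to default_float (float32 unless
# overridden), which matches torch's int / int -> float behavior
print((a / b).dtype)     # dtypes.float32 with default settings
print((a / b).tolist())  # [3.5, 4.0, 4.5]
```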
chenyu
975c318dbc
bert use int32 for input ids (#9173)
original data was int32 for these. float might have caused precision issues
2025-02-19 08:17:27 -05:00
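For context on the precision concern mentioned above, a standalone numpy sketch (not the bert dataloader): float32 has a 24-bit mantissa, so sufficiently large integers stored as float stop being exact.

```python
import numpy as np

i = 2**24 + 1                 # first integer float32 cannot represent exactly
print(np.float32(i) == i)     # False: 16777217 rounds to 16777216.0
print(np.int32(i) == i)       # True: int32 keeps the exact value
```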
chenyu
49b914ee69
simpler bert acc [pr] (#8714)
logit.log_softmax().argmax(-1) is equivalent to logit.argmax(-1), since log_softmax is a monotonic transform along the reduced axis
2025-01-22 10:32:19 -05:00
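A quick check of that equivalence, as a sketch with made-up shapes (assumes tinygrad's Tensor.log_softmax/argmax/all):

```python
from tinygrad import Tensor

logits = Tensor.randn(4, 30522)   # e.g. per-position vocab logits
# log_softmax subtracts the same logsumexp from every entry along the axis,
# so it is monotonic there and cannot change which index is largest
a = logits.log_softmax(-1).argmax(-1)
b = logits.argmax(-1)
print((a == b).all().item())      # True
```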
chenyu
fb694a63eb
Tensor.erf (#7419)
the same one used in onnx and in bert.
2024-10-30 18:12:28 -04:00
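As a rough illustration of where erf shows up in bert (a sketch, not the model code): the exact, non-tanh gelu is written with it.

```python
from tinygrad import Tensor

x = Tensor.randn(8)
# exact gelu as used by bert: 0.5 * x * (1 + erf(x / sqrt(2)))
gelu = 0.5 * x * (1 + (x / 2**0.5).erf())
print(gelu.numpy())
```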
chenyu
01a2d7316d
dtype=float in bert log_softmax for loss and accuracy (#6916)
2024-10-06 11:15:56 -04:00
chenyu
396c96357b
update mlperf bert scripts (#6755)
removed DISABLE_DROPOUT=1.
updated BS to 54, which works on tinyboxes with dropout enabled.
used bert's sparse_categorical_crossentropy, which takes a Tensor ignore_index, in the accuracy method
2024-09-25 23:55:05 -04:00
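The ignore_index mechanism referenced above, sketched with tinygrad's built-in Tensor.sparse_categorical_crossentropy and made-up shapes (the bert script's own variant additionally accepts a Tensor ignore_index):

```python
from tinygrad import Tensor

logits = Tensor.randn(4, 30522)         # per-position vocab logits
labels = Tensor([101, 2054, -1, 2003])  # -1 marks positions to skip (e.g. unmasked tokens)

# positions whose label equals ignore_index contribute nothing to the mean loss
loss = logits.sparse_categorical_crossentropy(labels, ignore_index=-1)
print(loss.item())
```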
chenyu
e6c7c3e499
update pylint path to check indent/space for all (#6022)
also fixed many errors. it was not checking nested dirs. exclude autogen for now.
can we use ruff for this?
2024-08-10 14:41:09 -04:00
Elias Wahl
d2e3c391e8
Residual in MLM loss + Change default steps (#4935)
* Residual in mlm loss
* Reduce default steps to 160K * 24
* oops
* comment
2024-06-12 16:09:18 -04:00
Elias Wahl
04e237328b
Refactor to class style (#4804)
2024-06-04 14:08:31 -07:00
chenyu
31358cbea5
change Tensor.stack to method (#4719)
2024-05-24 17:04:19 -04:00
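For reference, the method form this commit switches to (a minimal sketch):

```python
from tinygrad import Tensor

a, b = Tensor([1, 2]), Tensor([3, 4])
# stack is now called as a method on the first tensor rather than the previous static-style call
print(a.stack(b, dim=0).shape)   # (2, 2)
```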
wozeparrot
d2c347fc74
faster gather for bert (#4526)
2024-05-10 22:28:48 -07:00
Elias Wahl
27613dd881
MLPerf BERT: Main training loop (#4288)
* BERT language modeling head + trunc normal initializers
* add train loop + helpers
* shuffle in dataloaders + slight changes in main loop
* beam change
* Minor changes
* random.shuffle
* HParam update
* Use deque for dataloader
* wandb bert project name
* half fixes
* BENCHMARK + remove epoch
* cast + print()
---------
Co-authored-by: chenyu <chenyu@fastmail.com>
2024-04-29 14:35:27 -04:00
Elias Wahl
2ecd61e3e2
monkey patching (#4214)
2024-04-18 19:20:52 -04:00
Elias Wahl
7db6dd725d
multilazybuffer fix (#3609)
2024-03-04 17:36:23 -05:00
George Hotz
d87a246439
move to new cached fetch (#2493)
* move to new cached fetch
* extra.utils is over
* loads
* bump download cache
* bump timeout
2023-11-28 17:36:55 -08:00
George Hotz
0cbf6c1811
move things, clean up extra (#2292)
* move things
* idk why pylint needs that now
* delete unused
2023-11-13 20:18:40 -08:00