Commit Graph

18 Commits

Author SHA1 Message Date
chenyu
041e9a41c9 add contiguous in BertIntermediate (#13713)
faster step with a lot less recomputation
2025-12-15 22:37:36 -05:00
chenyu
7dadbf3697 insert float() in bert acc (#9726)
sum of bool by default uses default_float for acc. So without float, it might overflow with a large BS and default_float=HALF.

fixed clsf_accuracy to not be inf in mi300x bert
2025-04-03 05:44:09 -04:00
chenyu
3ae66e59a3 least_upper_float is at least default_float (#9303)
* least_upper_float is at least default_float

en route for div rounding mode. dtype of true int division would change from int32 to default_float, which matches torch too.

* fix bert acc
2025-02-28 10:41:56 -05:00
chenyu
975c318dbc bert use int32 for input ids (#9173)
original data was int32 for these. float might have caused precision issues
2025-02-19 08:17:27 -05:00
chenyu
49b914ee69 simpler bert acc [pr] (#8714)
logit.log_softmax().argmax(-1) is equivalent to logit.argmax(-1)
2025-01-22 10:32:19 -05:00
chenyu
fb694a63eb Tensor.erf (#7419)
the same one used in onnx and the one in bert.
2024-10-30 18:12:28 -04:00
chenyu
01a2d7316d dtype=float in bert log_softmax for loss and accuracy (#6916) 2024-10-06 11:15:56 -04:00
chenyu
396c96357b update mlperf bert scripts (#6755)
removed DISABLE_DROPOUT=1.
updated BS to 54 that works on tinyboxes with dropouts.
used bert's sparse_categorical_crossentropy that takes Tensor ignore_index in accuracy method
2024-09-25 23:55:05 -04:00
chenyu
e6c7c3e499 update pylint path to check indent/space for all (#6022)
also fixed many errors. it was not checking nested dirs. exclude autogen for now.

can we use ruff for this?
2024-08-10 14:41:09 -04:00
Elias Wahl
d2e3c391e8 Residual in MLM loss + Change default steps (#4935)
* Residual in mlm loss

* Reduce default steps to 160K * 24

* oops

* comment
2024-06-12 16:09:18 -04:00
Elias Wahl
04e237328b Refactor to class style (#4804) 2024-06-04 14:08:31 -07:00
chenyu
31358cbea5 change Tensor.stack to method (#4719) 2024-05-24 17:04:19 -04:00
wozeparrot
d2c347fc74 faster gather for bert (#4526) 2024-05-10 22:28:48 -07:00
Elias Wahl
27613dd881 MLPerf BERT: Main training loop (#4288)
* BERT language modeling head + trunc normal initializers

* add train loop + helpers

* shuffle in dataloaders + slight changes in main loop

* beam change

* Minor changes

* random.shuffle

* HParam update

* Use deque for dataloader

* wandb bert project name

* half fixes

* BENCHMARK + remove epoch

* cast + print()

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
2024-04-29 14:35:27 -04:00
Elias Wahl
2ecd61e3e2 monkey patching (#4214) 2024-04-18 19:20:52 -04:00
Elias Wahl
7db6dd725d multilazybuffer fix (#3609) 2024-03-04 17:36:23 -05:00
George Hotz
d87a246439 move to new cached fetch (#2493)
* move to new cached fetch

* extra.utils is over

* loads

* bump download cache

* bump timeout
2023-11-28 17:36:55 -08:00
George Hotz
0cbf6c1811 move things, clean up extra (#2292)
* move things

* idk why pylint needs that now

* delete unused
2023-11-13 20:18:40 -08:00