mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
* add DICE loss and metrics * update dice to include reference implementation's link * remove unused imports * remove unnecessary test file and update pred + label for metrics and losses test * add tests to CI + add exclusion of mlperf_unet3d --------- Co-authored-by: chenyu <chenyu@fastmail.com>
7 lines
265 B
Python
7 lines
265 B
Python
from examples.mlperf.metrics import dice_score
|
|
|
|
def dice_ce_loss(pred, tgt):
  """Return the equal-weight average of a Dice loss term and a cross-entropy term.

  The cross-entropy term is computed on `pred` with axis 1 moved to the last
  position, against `tgt` with axis 1 squeezed out. The Dice term is the mean
  of `1 - dice_score(pred, tgt, argmax=False, to_one_hot_x=False)`.

  Args:
    pred: 5-axis prediction tensor (the permute below fixes the rank).
      Presumably laid out as (batch, classes, *spatial) -- TODO confirm.
    tgt: label tensor whose axis 1 is squeezed before the cross-entropy.
      Presumably integer class ids with a singleton axis 1 -- TODO confirm.

  Returns:
    `(dice + ce) / 2`, where both terms are computed as described above.
  """
  # Cross-entropy term: axis 1 of `pred` goes last, axis 1 of `tgt` is dropped.
  pred_axis1_last = pred.permute(0, 2, 3, 4, 1)
  squeezed_tgt = tgt.squeeze(1)
  ce_term = pred_axis1_last.sparse_categorical_crossentropy(squeezed_tgt)

  # Dice term: dice_score receives the untouched inputs; argmax and one-hot
  # conversion of the first argument are explicitly switched off.
  score = dice_score(pred, tgt, argmax=False, to_one_hot_x=False)
  dice_term = (1.0 - score).mean()

  # Both terms carry the same weight.
  return (dice_term + ce_term) / 2
|