mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
revert losses changes
This commit is contained in:
@@ -7,18 +7,18 @@ def dice_ce_loss(pred, tgt):
|
||||
return (dice + ce) / 2
|
||||
|
||||
def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float=0.25, gamma:float=2.0, reduction:str="none") -> Tensor:
  """Focal loss on raw logits (Lin et al., "Focal Loss for Dense Object Detection").

  Down-weights the loss of well-classified elements so training focuses on
  hard examples.

  Args:
    pred: raw (pre-sigmoid) logits.
    tgt: binary targets; assumed same shape as pred — TODO confirm against callers.
    alpha: class-balance weight applied as alpha for positives and (1 - alpha)
      for negatives; a negative value disables the alpha weighting entirely.
    gamma: focusing exponent; gamma == 0 reduces this to plain BCE-with-logits.
    reduction: one of "none", "mean", "sum".

  Returns:
    Per-element loss tensor when reduction == "none", otherwise its mean or sum.

  Raises:
    AssertionError: if reduction is not one of the supported modes.
  """
  assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
  p, ce_loss = pred.sigmoid(), pred.binary_crossentropy_logits(tgt, reduction="none")
  # p_t is the predicted probability of the *true* class for each element:
  # p where tgt == 1, (1 - p) where tgt == 0.
  p_t = p * tgt + (1 - p) * (1 - tgt)
  # (1 - p_t) ** gamma shrinks the loss of confident correct predictions.
  loss = ce_loss * ((1 - p_t) ** gamma)
  if alpha >= 0:
    alpha_t = alpha * tgt + (1 - alpha) * (1 - tgt)
    loss = loss * alpha_t
  if reduction == "mean": loss = loss.mean()
  elif reduction == "sum": loss = loss.sum()
  return loss
|
||||
|
||||
def l1_loss(pred:Tensor, tgt:Tensor, reduction:str="none") -> Tensor:
|
||||
assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
|
||||
|
||||
Reference in New Issue
Block a user