cleanups on losses and dataset tests (#9538)

This commit is contained in:
Francis Lata
2025-03-21 17:03:18 -04:00
committed by GitHub
parent 8cbe4009fc
commit 1a1087e3a0
2 changed files with 11 additions and 11 deletions

View File

@@ -7,18 +7,18 @@ def dice_ce_loss(pred, tgt):
return (dice + ce) / 2
def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float=0.25, gamma:float=2.0, reduction:str="none") -> Tensor:
assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
p, ce_loss = pred.sigmoid(), pred.binary_crossentropy_logits(tgt, reduction="none")
p_t = p * tgt + (1 - p) * (1 - tgt)
loss = ce_loss * ((1 - p_t) ** gamma)
assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
p, ce_loss = pred.sigmoid(), pred.binary_crossentropy_logits(tgt, reduction="none")
p_t = p * tgt + (1 - p) * (1 - tgt)
loss = ce_loss * ((1 - p_t) ** gamma)
if alpha >= 0:
alpha_t = alpha * tgt + (1 - alpha) * (1 - tgt)
loss = loss * alpha_t
if alpha >= 0:
alpha_t = alpha * tgt + (1 - alpha) * (1 - tgt)
loss = loss * alpha_t
if reduction == "mean": loss = loss.mean()
elif reduction == "sum": loss = loss.sum()
return loss
if reduction == "mean": loss = loss.mean()
elif reduction == "sum": loss = loss.sum()
return loss
def l1_loss(pred:Tensor, tgt:Tensor, reduction:str="none") -> Tensor:
assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"

View File

@@ -60,7 +60,7 @@ class TestKiTS19Dataset(ExternalTestDatasets):
if use_old_dataloader:
dataset = iterate(list(Path(tempfile.gettempdir()).glob("case_*")), preprocessed_dir=preproc_pth, val=val, shuffle=shuffle, bs=batch_size)
else:
dataset = iter(batch_load_unet3d(preproc_pth, batch_size=batch_size, val=val, shuffle=shuffle, seed=seed))
dataset = batch_load_unet3d(preproc_pth, batch_size=batch_size, val=val, shuffle=shuffle, seed=seed)
return iter(dataset)