refactor sparse_categorical_crossentropy (#4406)

Factor out the common `-1 *` and `/ loss_mask.sum()` from both the smoothing and non-smoothing terms, so the negation and normalization are applied once to their combined sum.
This commit is contained in:
chenyu
2024-05-03 14:28:36 -04:00
committed by GitHub
parent 3401734e54
commit c7368515d2

View File

@@ -1324,9 +1324,9 @@ class Tensor:
# NOTE: self is a logits input
log_probs, loss_mask = self.log_softmax(), (Y != ignore_index)
y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
smoothing = -1 * label_smoothing * (log_probs.mean(-1) * loss_mask).sum() / loss_mask.sum()
return (1 - label_smoothing) * (log_probs * y).sum() / loss_mask.sum() + smoothing
y = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask).sum()
return -((1 - label_smoothing) * (log_probs * y).sum() + smoothing) / loss_mask.sum()
# ***** cast ops *****