Mirror of https://github.com/tinygrad/tinygrad.git
refactor sparse_categorical_crossentropy (#4406)
Factor the -1 * and the / loss_mask.sum() out of both the smoothing and non-smoothing terms.
@@ -1324,9 +1324,9 @@ class Tensor:
     # NOTE: self is a logits input
     log_probs, loss_mask = self.log_softmax(), (Y != ignore_index)
     y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
-    y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
-    smoothing = -1 * label_smoothing * (log_probs.mean(-1) * loss_mask).sum() / loss_mask.sum()
-    return (1 - label_smoothing) * (log_probs * y).sum() / loss_mask.sum() + smoothing
+    y = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
+    smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask).sum()
+    return -((1 - label_smoothing) * (log_probs * y).sum() + smoothing) / loss_mask.sum()
 
   # ***** cast ops *****
 
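The two formulations are algebraically identical: the old code folded the -1 into y via .where(-1, 0) and divided each term by loss_mask.sum() separately, while the new code pulls the negation and the division out front. Below is a minimal sketch (not part of the commit) that checks the equivalence numerically; it assumes tinygrad is installed, and logits, Y, ignore_index, label_smoothing, y_old, y_new, smoothing_old, smoothing_new are made-up example values and helper names introduced here for illustration.

from tinygrad.tensor import Tensor

logits = Tensor.randn(4, 10)                 # stand-in for self: a logits input
Y = Tensor([1, 3, -1, 7])                    # -1 marks a position to ignore
ignore_index, label_smoothing = -1, 0.1      # example values only

log_probs, loss_mask = logits.log_softmax(), (Y != ignore_index)
y_counter = Tensor.arange(logits.shape[-1], requires_grad=False).unsqueeze(0).expand(Y.numel(), logits.shape[-1])

# old form: -1 baked into y, each term divided by loss_mask.sum()
y_old = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, logits.shape[-1])
smoothing_old = -1 * label_smoothing * (log_probs.mean(-1) * loss_mask).sum() / loss_mask.sum()
old = (1 - label_smoothing) * (log_probs * y_old).sum() / loss_mask.sum() + smoothing_old

# new form: negation and division factored out of both terms
y_new = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, logits.shape[-1])
smoothing_new = label_smoothing * (log_probs.mean(-1) * loss_mask).sum()
new = -((1 - label_smoothing) * (log_probs * y_new).sum() + smoothing_new) / loss_mask.sum()

print(old.item(), new.item())                # should agree up to floating-point error

Multiplying out, both versions compute -((1 - label_smoothing) * masked target log-prob sum + label_smoothing * masked mean log-prob sum) / loss_mask.sum(); only the grouping changes, so the loss value and its gradients are unaffected.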