remove flatten and reshape in sparse_categorical_crossentropy [pr] (#11093)

not needed, directly operating on the classes dim is fine
This commit is contained in:
chenyu
2025-07-04 15:15:27 -04:00
committed by GitHub
parent 577afc9f05
commit 39b4d72687

View File

@@ -3920,9 +3920,9 @@ class Tensor(MathTrait):
"""
assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
assert reduction in get_args(ReductionStr), f"reduction must be one of {get_args(ReductionStr)}"
log_probs, loss_mask = self.log_softmax(), (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
y_counted = Y.to(self.device).flatten().reshape(-1, 1)._one_hot_along_dim(self.shape[-1])
y = (y_counted * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
log_probs = self.log_softmax()
loss_mask = (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
y = Y.to(self.device).unsqueeze(-1)._one_hot_along_dim(self.shape[-1], dim=-1) * loss_mask.unsqueeze(-1)
smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
unreduced = ((1 - label_smoothing) * (log_probs * y).sum(-1) + smoothing)
# NOTE: because of ignore_index, we can't use Tensor.mean (so can't use `_do_reduction` here)