cosmetic Tensor._do_reduction cleanups (#15357)

This commit is contained in:
chenyu
2026-03-18 22:27:50 -04:00
committed by GitHub
parent 6aebf95dac
commit e407ee410c

View File

@@ -3351,9 +3351,10 @@ class Tensor(OpMixin):
return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value
def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
- if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
- reductions: dict[str, Callable[[Tensor], Tensor]] = {"mean": Tensor.mean, "sum": Tensor.sum, "none": lambda x: x}
- return reductions[reduction](self)
+ if reduction == "none": return self
+ if reduction == "sum": return self.sum()
+ if reduction == "mean": return self.mean()
+ raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
def binary_crossentropy(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
"""
@@ -3400,14 +3401,12 @@ class Tensor(OpMixin):
```
"""
assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
- assert reduction in get_args(ReductionStr), f"reduction must be one of {get_args(ReductionStr)}"
log_probs = self.log_softmax()
loss_mask = (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
y = Y.to(self.device).unsqueeze(-1)._one_hot_along_dim(self.shape[-1], dim=-1) * loss_mask.unsqueeze(-1)
smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
unreduced = ((1 - label_smoothing) * (log_probs * y).sum(-1) + smoothing)
- # NOTE: because of ignore_index, we can't use Tensor.mean (so can't use `_do_reduction` here)
- return -(unreduced.sum() / loss_mask.sum() if reduction == "mean" else (unreduced.sum() if reduction == "sum" else unreduced))
+ return -unreduced.sum() / loss_mask.sum() if reduction == "mean" else -unreduced._do_reduction(reduction)
def cross_entropy(self, Y:Tensor, reduction:ReductionStr="mean", label_smoothing:float=0.0) -> Tensor:
"""
@@ -3455,7 +3454,7 @@ class Tensor(OpMixin):
print(t.log_softmax().nll_loss(Y, reduction='none').numpy())
```
"""
- weight = Tensor.ones_like(Y, requires_grad=False) if weight is None else weight[Y]
+ weight = Y.ones_like(requires_grad=False) if weight is None else weight[Y]
masked_weight = weight if ignore_index is None else weight * (Y != ignore_index)
nll = -self.gather(1, Y.unsqueeze(1)).squeeze(1) * masked_weight
return nll.sum() / masked_weight.sum() if reduction == "mean" else nll._do_reduction(reduction)