mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-29 03:00:14 -04:00)
cosmetic Tensor._do_reduction cleanups (#15357)
tinygrad/tensor.py
@@ -3351,9 +3351,10 @@ class Tensor(OpMixin):
     return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value
 
   def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
-    if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
-    reductions: dict[str, Callable[[Tensor], Tensor]] = {"mean": Tensor.mean, "sum": Tensor.sum, "none": lambda x: x}
-    return reductions[reduction](self)
+    if reduction == "none": return self
+    if reduction == "sum": return self.sum()
+    if reduction == "mean": return self.mean()
+    raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
 
   def binary_crossentropy(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
     """
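The new version trades the dict-of-callables for explicit branches; behavior is unchanged, including the ValueError on a bad reduction string (the old code validated up front, the new code raises only after the three valid cases fail to match, which is equivalent for this Literal type). A minimal standalone sketch of the same dispatch, with a plain list standing in for Tensor (illustration only, not tinygrad code):

```python
from typing import Literal, get_args

ReductionStr = Literal["mean", "sum", "none"]

def do_reduction(xs: list[float], reduction: ReductionStr = "mean"):
  # same control flow as the new Tensor._do_reduction
  if reduction == "none": return xs
  if reduction == "sum": return sum(xs)
  if reduction == "mean": return sum(xs) / len(xs)
  raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")

assert do_reduction([1.0, 2.0, 3.0], "sum") == 6.0
assert do_reduction([1.0, 2.0, 3.0], "mean") == 2.0
```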
@@ -3400,14 +3401,12 @@ class Tensor(OpMixin):
     ```
     """
     assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
-    assert reduction in get_args(ReductionStr), f"reduction must be one of {get_args(ReductionStr)}"
     log_probs = self.log_softmax()
     loss_mask = (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
     y = Y.to(self.device).unsqueeze(-1)._one_hot_along_dim(self.shape[-1], dim=-1) * loss_mask.unsqueeze(-1)
     smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
     unreduced = ((1 - label_smoothing) * (log_probs * y).sum(-1) + smoothing)
-    # NOTE: because of ignore_index, we can't use Tensor.mean (so can't use `_do_reduction` here)
-    return -(unreduced.sum() / loss_mask.sum() if reduction == "mean" else (unreduced.sum() if reduction == "sum" else unreduced))
+    return -unreduced.sum() / loss_mask.sum() if reduction == "mean" else -unreduced._do_reduction(reduction)
 
   def cross_entropy(self, Y:Tensor, reduction:ReductionStr="mean", label_smoothing:float=0.0) -> Tensor:
     """
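The dropped assert is now redundant, since _do_reduction itself raises on a bad reduction string, and the "sum"/"none" cases route through _do_reduction directly. Only "mean" stays special-cased: with ignore_index, the mean must divide by the count of non-ignored positions (loss_mask.sum()), not the total element count that Tensor.mean would use. A NumPy sketch of that difference, with made-up values:

```python
import numpy as np

# per-position losses; position 1 had target == ignore_index, so its loss is 0
unreduced = np.array([0.5, 0.0, 1.5])
loss_mask = np.array([1, 0, 1])

masked_mean = unreduced.sum() / loss_mask.sum()  # 2.0 / 2 = 1.0, correct
naive_mean = unreduced.mean()                    # 2.0 / 3 ≈ 0.67, wrongly counts the ignored slot
```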
@@ -3455,7 +3454,7 @@ class Tensor(OpMixin):
     print(t.log_softmax().nll_loss(Y, reduction='none').numpy())
     ```
     """
-    weight = Tensor.ones_like(Y, requires_grad=False) if weight is None else weight[Y]
+    weight = Y.ones_like(requires_grad=False) if weight is None else weight[Y]
     masked_weight = weight if ignore_index is None else weight * (Y != ignore_index)
     nll = -self.gather(1, Y.unsqueeze(1)).squeeze(1) * masked_weight
     return nll.sum() / masked_weight.sum() if reduction == "mean" else nll._do_reduction(reduction)
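The nll_loss change is purely cosmetic: Tensor.ones_like(Y, ...) calls the method through the class with Y as self, which resolves to the same call as Y.ones_like(...). A toy illustration of that Python equivalence (generic sketch, not tinygrad code):

```python
class T:
  def __init__(self, n): self.n = n
  def ones_like(self, requires_grad=False): return [1.0] * self.n

y = T(4)
# class-style and instance-style calls invoke the same method on y
assert T.ones_like(y, requires_grad=False) == y.ones_like(requires_grad=False)
```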