diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 35a308cb06..51c7321e09 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -3351,9 +3351,10 @@ class Tensor(OpMixin):
     return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value
 
   def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
-    if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
-    reductions: dict[str, Callable[[Tensor], Tensor]] = {"mean": Tensor.mean, "sum": Tensor.sum, "none": lambda x: x}
-    return reductions[reduction](self)
+    if reduction == "none": return self
+    if reduction == "sum": return self.sum()
+    if reduction == "mean": return self.mean()
+    raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
 
   def binary_crossentropy(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
     """
@@ -3400,14 +3401,12 @@ class Tensor(OpMixin):
     ```
     """
     assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
-    assert reduction in get_args(ReductionStr), f"reduction must be one of {get_args(ReductionStr)}"
     log_probs = self.log_softmax()
     loss_mask = (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
     y = Y.to(self.device).unsqueeze(-1)._one_hot_along_dim(self.shape[-1], dim=-1) * loss_mask.unsqueeze(-1)
     smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
     unreduced = ((1 - label_smoothing) * (log_probs * y).sum(-1) + smoothing)
-    # NOTE: because of ignore_index, we can't use Tensor.mean (so can't use `_do_reduction` here)
-    return -(unreduced.sum() / loss_mask.sum() if reduction == "mean" else (unreduced.sum() if reduction == "sum" else unreduced))
+    return -unreduced.sum() / loss_mask.sum() if reduction == "mean" else -unreduced._do_reduction(reduction)
 
   def cross_entropy(self, Y:Tensor, reduction:ReductionStr="mean", label_smoothing:float=0.0) -> Tensor:
     """
@@ -3455,7 +3454,7 @@ class Tensor(OpMixin):
     print(t.log_softmax().nll_loss(Y, reduction='none').numpy())
     ```
     """
-    weight = Tensor.ones_like(Y, requires_grad=False) if weight is None else weight[Y]
+    weight = Y.ones_like(requires_grad=False) if weight is None else weight[Y]
     masked_weight = weight if ignore_index is None else weight * (Y != ignore_index)
     nll = -self.gather(1, Y.unsqueeze(1)).squeeze(1) * masked_weight
     return nll.sum() / masked_weight.sum() if reduction == "mean" else nll._do_reduction(reduction)