diff --git a/tinygrad/mixin/__init__.py b/tinygrad/mixin/__init__.py
index 83a2db99e7..38b8228042 100644
--- a/tinygrad/mixin/__init__.py
+++ b/tinygrad/mixin/__init__.py
@@ -1,10 +1,13 @@
-from typing import Self, Sequence
+import functools
+from typing import Self, Sequence, Literal, get_args
 from tinygrad.mixin.elementwise import ElementwiseMixin
 from tinygrad.mixin.reduce import ReduceMixin
 from tinygrad.uop.ops import _broadcast_shape, resolve
 from tinygrad.dtype import DTypeLike, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
 from tinygrad.helpers import argfix, prod
 
+ReductionStr = Literal["mean", "sum", "none"]
+
 class OpMixin(ElementwiseMixin, ReduceMixin):
   def _broadcasted(self, y, reverse=False) -> tuple[Self, Self]:
@@ -306,3 +309,72 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
     if weight is not None: x = x * weight.reshape(shape)
     ret = x.mul(invstd.reshape(shape) if len(invstd.shape) == len(axis_) else invstd)
     return (ret + bias.reshape(shape)) if bias is not None else ret
+
+  # ***** loss ops *****
+
+  def _do_reduction(self, reduction:ReductionStr="mean") -> Self:
+    if reduction == "none": return self
+    if reduction == "sum": return self.sum()
+    if reduction == "mean": return self.mean()
+    raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
+
+  def binary_crossentropy(self, Y:Self, reduction:ReductionStr="mean") -> Self:
+    """
+    Computes the binary cross-entropy loss between `self` and `Y`.
+
+    See: https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([0.1, 0.9, 0.2])
+    Y = Tensor([0, 1, 0])
+    print(t.binary_crossentropy(Y).item())
+    ```
+    """
+    return (-Y*self.log() - (1-Y)*(1-self).log())._do_reduction(reduction)
+
+  def binary_crossentropy_logits(self, Y:Self, reduction:ReductionStr="mean", pos_weight:Self|None=None) -> Self:
+    """
+    Computes the binary cross-entropy loss between `self` and `Y`, where `self` holds logits.
+
+    See: https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([-1, 2, -3])
+    Y = Tensor([0, 1, 0])
+    print(t.binary_crossentropy_logits(Y).item())
+    ```
+    """
+    log_p, log_1_minus_p = self.logsigmoid(), (-self).logsigmoid()
+    return (-((1 if pos_weight is None else pos_weight) * Y * log_p + (1-Y) * log_1_minus_p))._do_reduction(reduction)
+
+  # ***** matrix ops *****
+
+  def newton_schulz(self, steps:int, params:tuple[float, ...], eps:float=1.0e-7) -> Self:
+    """
+    Performs the Newton-Schulz algorithm for odd polynomials. The degree of the odd polynomial depends on the number of params.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.randn(4, 4)
+    print(t.newton_schulz(steps=5, params=(2,-1.5,0.5)).numpy())
+    ```
+    """
+    assert self.ndim > 1, "NS only works for two or more dims"
+    if self.shape[-2] > self.shape[-1]: return self.transpose(-2, -1).newton_schulz(steps, params, eps).transpose(-2, -1)
+    G = self / (self.square().sum(axis=(-2, -1), keepdim=True).sqrt() + eps)
+    for _ in range(steps):
+      G = functools.reduce(lambda a, b: a + b, (p * functools.reduce(lambda x, y: (y @ y.transpose(-2, -1)) @ x, [G]*i, G) # type: ignore[operator]
+                                                for i,p in enumerate(params)))
+    return G
+
+  # ***** tensor properties *****
+
+  def nbytes(self) -> int:
+    """
+    Returns the total number of bytes of all elements in the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([8, 9], dtype=dtypes.float)
+    print(t.nbytes())
+    ```
+    """
+    return int(self.numel()) * self.element_size()
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 8aecaa5cc5..23e2e0b83e 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -10,7 +10,7 @@ from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_u
 from tinygrad.helpers import IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, ceildiv, fetch, is_numpy_ndarray, TracingKey, cpu_profile
 from tinygrad.helpers import suppress_finalizing, disable_gc
 from tinygrad.gradient import compute_gradient
-from tinygrad.mixin import OpMixin
+from tinygrad.mixin import OpMixin, ReductionStr
 from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, Variable
 from tinygrad.uop.ops import _broadcast_shape
 from tinygrad.engine.schedule import ExecItem, complete_create_schedule_with_vars
@@ -90,8 +90,6 @@ def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, .
 # `(padding_left, padding_right, padding_top, padding_bottom, ...)` -> `(..., (padding_top, padding_bottom), (padding_left, padding_right))`
 def _flat_to_grouped(padding:Sequence[sint]) -> tuple[tuple[sint, sint], ...]: return tuple(zip(padding[-2::-2], padding[::-2]))
 
-ReductionStr = Literal["mean", "sum", "none"]
-
 class Tensor(OpMixin):
   """
   A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
@@ -2503,41 +2501,6 @@ class Tensor(OpMixin):
       qk = qk + attn_mask
     return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value
 
-  def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
-    if reduction == "none": return self
-    if reduction == "sum": return self.sum()
-    if reduction == "mean": return self.mean()
-    raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
-
-  def binary_crossentropy(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
-    """
-    Computes the binary cross-entropy loss between `self` and `Y`.
-
-    See: https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor([0.1, 0.9, 0.2])
-    Y = Tensor([0, 1, 0])
-    print(t.binary_crossentropy(Y).item())
-    ```
-    """
-    return (-Y*self.log() - (1-Y)*(1-self).log())._do_reduction(reduction)
-
-  def binary_crossentropy_logits(self, Y:Tensor, reduction:ReductionStr="mean", pos_weight:Tensor|None=None) -> Tensor:
-    """
-    Computes the binary cross-entropy loss between `self` and `Y` where `self` is logits.
-
-    See: https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor([-1, 2, -3])
-    Y = Tensor([0, 1, 0])
-    print(t.binary_crossentropy_logits(Y).item())
-    ```
-    """
-    log_p, log_1_minus_p = self.logsigmoid(), (-self).logsigmoid()
-    return (-((1 if pos_weight is None else pos_weight) * Y * log_p + (1-Y) * log_1_minus_p))._do_reduction(reduction)
-
   def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index:int=-1, label_smoothing=0.0, reduction:ReductionStr="mean") -> Tensor:
     """
     Computes the sparse categorical cross-entropy loss between `self` and `Y`.
@@ -2612,22 +2575,6 @@ class Tensor(OpMixin):
     nll = -self.gather(1, Y.unsqueeze(1)).squeeze(1) * masked_weight
     return nll.sum() / masked_weight.sum() if reduction == "mean" else nll._do_reduction(reduction)
 
-  def newton_schulz(self, steps:int, params:tuple[int, ...], eps:float=1.0e-7) -> Tensor:
-    """
-    Performs the newton-schulz algorithm for odd polynomials. The degree of the odd polynomial depends on the number of params.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.randn(4, 4)
-    print(t.newton_schulz(steps=5, params=(2,-1.5,0.5)).numpy())
-    ```
-    """
-    assert self.ndim > 1, "NS only works for two or more dims"
-    if self.shape[-2] > self.shape[-1]: return self.transpose(-2, -1).newton_schulz(steps, params, eps).transpose(-2, -1)
-    G = self / (self.square().sum(axis=(-2, -1), keepdim=True).sqrt() + eps)
-    for _ in range(steps):
-      G = cast(Tensor, sum(p * functools.reduce(lambda x, y: (y @ y.transpose(-2, -1)) @ x, [G]*i, G) for i,p in enumerate(params)))
-    return G
-
   def qr(self) -> tuple[Tensor, Tensor]:
     assert self.ndim > 1, f"expected two or more dimensions, got {self.ndim}"
     b_shape, m, n = self.shape[:-2], int(self.shape[-2]), int(self.shape[-1])
@@ -2701,17 +2648,6 @@ class Tensor(OpMixin):
 
   # ***** Tensor Properties *****
 
-  def nbytes(self) -> int:
-    """
-    Returns the total number of bytes of all elements in the tensor.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor([8, 9], dtype=dtypes.float)
-    print(t.nbytes())
-    ```
-    """
-    return int(self.numel()) * self.element_size()
-
  def size(self, dim:int|None=None) -> sint|tuple[sint, ...]:
     """
     Returns the size of the tensor. If `dim` is specified, return the length along dimension `dim`. Otherwise return the shape of the tensor.
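A note on the `binary_crossentropy_logits` formulation moved here: it evaluates `log p` and `log (1-p)` through `logsigmoid` rather than composing `sigmoid()` with `log()`, so the loss stays finite when the sigmoid saturates. A rough NumPy check of that idea (not part of the diff; `bce_logits_naive`/`bce_logits_stable` are made-up names for illustration):

```python
import numpy as np

def bce_logits_naive(x: np.ndarray, y: np.ndarray) -> float:
    # textbook formula: saturating sigmoid gives p == 0.0 or 1.0, then log(0) = -inf
    p = 1.0 / (1.0 + np.exp(-x))
    return float(-(y * np.log(p) + (1 - y) * np.log(1 - p)).mean())

def bce_logits_stable(x: np.ndarray, y: np.ndarray) -> float:
    # log sigmoid(x) = -log(1 + exp(-x)), computed stably via logaddexp
    log_p = -np.logaddexp(0.0, -x)
    log_1_minus_p = -np.logaddexp(0.0, x)  # log sigmoid(-x)
    return float(-(y * log_p + (1 - y) * log_1_minus_p).mean())

x = np.array([-1000.0, 1000.0, 3.0])  # extreme logits
y = np.array([1.0, 0.0, 1.0])
print(bce_logits_naive(x, y))   # inf, with overflow / log(0) warnings
print(bce_logits_stable(x, y))  # ~666.68, finite
```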
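On `newton_schulz`: for coefficients `params = (a, b, c)` the nested `functools.reduce` builds the odd matrix polynomial `a*G + b*(G Gᵀ)G + c*(G Gᵀ)²G`, which acts on each singular value `s` as `a·s + b·s³ + c·s⁵` while leaving the singular vectors unchanged; the explicit `lambda a, b: a + b` presumably replaces the old `sum(...)` + `cast` because `sum` starts from the int `0`, which doesn't fit the `Self`-typed mixin. A plain NumPy sketch of the same iteration for a 2-D input (`newton_schulz_ref` is a hypothetical reference name, not part of the diff):

```python
import numpy as np

def newton_schulz_ref(g: np.ndarray, steps: int, params: tuple[float, ...], eps: float = 1e-7) -> np.ndarray:
    # normalize by the Frobenius norm so every singular value starts in (0, 1]
    g = g / (np.linalg.norm(g) + eps)
    for _ in range(steps):
        acc, term = np.zeros_like(g), g
        for p in params:
            acc = acc + p * term
            term = (g @ g.T) @ term  # raise the odd power: s^k -> s^(k+2)
        g = acc  # reassign only after the full polynomial is built
    return g

out = newton_schulz_ref(np.random.randn(4, 4), steps=5, params=(2.0, -1.5, 0.5))
# with these coefficients p(1) = 2 - 1.5 + 0.5 = 1, so singular values converge toward 1
print(np.linalg.svd(out, compute_uv=False))
```

This is the orthogonalization primitive used by Muon-style optimizers: the iteration drives the matrix toward the nearest semi-orthogonal one without an explicit SVD.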