more trivial stuff to mixin (#15693)

chenyu
2026-04-12 15:17:16 -04:00
committed by GitHub
parent ff1de5ae13
commit 77385ccb37
2 changed files with 74 additions and 66 deletions

View File

@@ -1,10 +1,13 @@
-from typing import Self, Sequence
+import functools
+from typing import Self, Sequence, Literal, get_args
 from tinygrad.mixin.elementwise import ElementwiseMixin
 from tinygrad.mixin.reduce import ReduceMixin
 from tinygrad.uop.ops import _broadcast_shape, resolve
 from tinygrad.dtype import DTypeLike, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
 from tinygrad.helpers import argfix, prod
+
+ReductionStr = Literal["mean", "sum", "none"]
 
 class OpMixin(ElementwiseMixin, ReduceMixin):
   def _broadcasted(self, y, reverse=False) -> tuple[Self, Self]:
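`ReductionStr` now lives in the mixin and types the shared `reduction` argument of the loss ops added below. A quick sketch of the three modes on the public `Tensor` API (behavior follows `_do_reduction` in the next hunk; outputs not verified here):

```python
from tinygrad import Tensor

t = Tensor([0.1, 0.9, 0.2])
Y = Tensor([0, 1, 0])
# "none" keeps the per-element loss, "sum" and "mean" collapse it to a scalar
print(t.binary_crossentropy(Y, reduction="none").numpy())
print(t.binary_crossentropy(Y, reduction="sum").item())
print(t.binary_crossentropy(Y, reduction="mean").item())  # "mean" is the default
```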
@@ -306,3 +309,72 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
     if weight is not None: x = x * weight.reshape(shape)
     ret = x.mul(invstd.reshape(shape) if len(invstd.shape) == len(axis_) else invstd)
     return (ret + bias.reshape(shape)) if bias is not None else ret
+
+  # ***** loss ops *****
+
+  def _do_reduction(self, reduction:ReductionStr="mean") -> Self:
+    if reduction == "none": return self
+    if reduction == "sum": return self.sum()
+    if reduction == "mean": return self.mean()
+    raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
+
+  def binary_crossentropy(self, Y:Self, reduction:ReductionStr="mean") -> Self:
+    """
+    Computes the binary cross-entropy loss between `self` and `Y`.
+
+    See: https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([0.1, 0.9, 0.2])
+    Y = Tensor([0, 1, 0])
+    print(t.binary_crossentropy(Y).item())
+    ```
+    """
+    return (-Y*self.log() - (1-Y)*(1-self).log())._do_reduction(reduction)
+
+  def binary_crossentropy_logits(self, Y:Self, reduction:ReductionStr="mean", pos_weight:Self|None=None) -> Self:
+    """
+    Computes the binary cross-entropy loss between `self` and `Y` where `self` is logits.
+
+    See: https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([-1, 2, -3])
+    Y = Tensor([0, 1, 0])
+    print(t.binary_crossentropy_logits(Y).item())
+    ```
+    """
+    log_p, log_1_minus_p = self.logsigmoid(), (-self).logsigmoid()
+    return (-((1 if pos_weight is None else pos_weight) * Y * log_p + (1-Y) * log_1_minus_p))._do_reduction(reduction)
+
+  # ***** matrix ops *****
+
+  def newton_schulz(self, steps:int, params:tuple[float, ...], eps:float=1.0e-7) -> Self:
+    """
+    Performs the Newton-Schulz algorithm for odd polynomials. The degree of the odd polynomial depends on the number of params.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.randn(4, 4)
+    print(t.newton_schulz(steps=5, params=(2,-1.5,0.5)).numpy())
+    ```
+    """
+    assert self.ndim > 1, "NS only works for two or more dims"
+    if self.shape[-2] > self.shape[-1]: return self.transpose(-2, -1).newton_schulz(steps, params, eps).transpose(-2, -1)
+    G = self / (self.square().sum(axis=(-2, -1), keepdim=True).sqrt() + eps)
+    for _ in range(steps):
+      G = functools.reduce(lambda a, b: a + b, (p * functools.reduce(lambda x, y: (y @ y.transpose(-2, -1)) @ x, [G]*i, G)  # type: ignore[operator]
+                                                for i,p in enumerate(params)))
+    return G
+
+  # ***** tensor properties *****
+
+  def nbytes(self) -> int:
+    """
+    Returns the total number of bytes of all elements in the tensor.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([8, 9], dtype=dtypes.float)
+    print(t.nbytes())
+    ```
+    """
+    return int(self.numel()) * self.element_size()
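The two BCE variants are mathematically equivalent: the logits version just fuses the sigmoid via `logsigmoid` for numerical stability. A hedged sanity-check sketch, not part of the diff:

```python
from tinygrad import Tensor

logits = Tensor([-1.0, 2.0, -3.0])
Y = Tensor([0, 1, 0])
# binary_crossentropy_logits should agree with sigmoid followed by binary_crossentropy
a = logits.binary_crossentropy_logits(Y).item()
b = logits.sigmoid().binary_crossentropy(Y).item()
print(a, b)  # expected to agree up to floating-point error
```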

View File

@@ -10,7 +10,7 @@ from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up
 from tinygrad.helpers import IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, ceildiv, fetch, is_numpy_ndarray, TracingKey, cpu_profile
 from tinygrad.helpers import suppress_finalizing, disable_gc
 from tinygrad.gradient import compute_gradient
-from tinygrad.mixin import OpMixin
+from tinygrad.mixin import OpMixin, ReductionStr
 from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, Variable
 from tinygrad.uop.ops import _broadcast_shape
 from tinygrad.engine.schedule import ExecItem, complete_create_schedule_with_vars
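Since `tensor.py` now imports `ReductionStr` from `tinygrad.mixin`, downstream code can reuse the same `Literal` alias. A hypothetical user-side sketch (the `mse` helper is made up for illustration):

```python
from tinygrad import Tensor
from tinygrad.mixin import ReductionStr

# a hypothetical user-defined loss reusing the shared Literal alias
def mse(pred: Tensor, target: Tensor, reduction: ReductionStr = "mean") -> Tensor:
  se = (pred - target).square()
  return {"none": se, "sum": se.sum(), "mean": se.mean()}[reduction]
```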
@@ -90,8 +90,6 @@ def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]):
 # `(padding_left, padding_right, padding_top, padding_bottom, ...)` -> `(..., (padding_top, padding_bottom), (padding_left, padding_right))`
 def _flat_to_grouped(padding:Sequence[sint]) -> tuple[tuple[sint, sint], ...]: return tuple(zip(padding[-2::-2], padding[::-2]))
-
-ReductionStr = Literal["mean", "sum", "none"]
 
 class Tensor(OpMixin):
   """
   A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
@@ -2503,41 +2501,6 @@ class Tensor(OpMixin):
     qk = qk + attn_mask
     return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value
-
-  def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
-    if reduction == "none": return self
-    if reduction == "sum": return self.sum()
-    if reduction == "mean": return self.mean()
-    raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
-
-  def binary_crossentropy(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
-    """
-    Computes the binary cross-entropy loss between `self` and `Y`.
-
-    See: https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor([0.1, 0.9, 0.2])
-    Y = Tensor([0, 1, 0])
-    print(t.binary_crossentropy(Y).item())
-    ```
-    """
-    return (-Y*self.log() - (1-Y)*(1-self).log())._do_reduction(reduction)
-
-  def binary_crossentropy_logits(self, Y:Tensor, reduction:ReductionStr="mean", pos_weight:Tensor|None=None) -> Tensor:
-    """
-    Computes the binary cross-entropy loss between `self` and `Y` where `self` is logits.
-
-    See: https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor([-1, 2, -3])
-    Y = Tensor([0, 1, 0])
-    print(t.binary_crossentropy_logits(Y).item())
-    ```
-    """
-    log_p, log_1_minus_p = self.logsigmoid(), (-self).logsigmoid()
-    return (-((1 if pos_weight is None else pos_weight) * Y * log_p + (1-Y) * log_1_minus_p))._do_reduction(reduction)
 
   def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index:int=-1, label_smoothing=0.0, reduction:ReductionStr="mean") -> Tensor:
     """
     Computes the sparse categorical cross-entropy loss between `self` and `Y`.
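The methods deleted here are not gone: `Tensor` subclasses `OpMixin`, so it inherits them unchanged and the public API is identical before and after this commit. A quick check, not part of the diff:

```python
from tinygrad import Tensor
from tinygrad.mixin import OpMixin

assert issubclass(Tensor, OpMixin)
# the loss ops moved to the mixin are still reachable from Tensor
print(Tensor([0.1, 0.9, 0.2]).binary_crossentropy(Tensor([0, 1, 0])).item())
```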
@@ -2612,22 +2575,6 @@ class Tensor(OpMixin):
     nll = -self.gather(1, Y.unsqueeze(1)).squeeze(1) * masked_weight
     return nll.sum() / masked_weight.sum() if reduction == "mean" else nll._do_reduction(reduction)
-
-  def newton_schulz(self, steps:int, params:tuple[float, ...], eps:float=1.0e-7) -> Tensor:
-    """
-    Performs the Newton-Schulz algorithm for odd polynomials. The degree of the odd polynomial depends on the number of params.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.randn(4, 4)
-    print(t.newton_schulz(steps=5, params=(2,-1.5,0.5)).numpy())
-    ```
-    """
-    assert self.ndim > 1, "NS only works for two or more dims"
-    if self.shape[-2] > self.shape[-1]: return self.transpose(-2, -1).newton_schulz(steps, params, eps).transpose(-2, -1)
-    G = self / (self.square().sum(axis=(-2, -1), keepdim=True).sqrt() + eps)
-    for _ in range(steps):
-      G = cast(Tensor, sum(p * functools.reduce(lambda x, y: (y @ y.transpose(-2, -1)) @ x, [G]*i, G) for i,p in enumerate(params)))
-    return G
 
   def qr(self) -> tuple[Tensor, Tensor]:
     assert self.ndim > 1, f"expected two or more dimensions, got {self.ndim}"
     b_shape, m, n = self.shape[:-2], int(self.shape[-2]), int(self.shape[-1])
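For readers decoding the dense `functools.reduce` one-liner: each step applies the odd matrix polynomial `params[0]*G + params[1]*(G @ G.T) @ G + params[2]*(G @ G.T)^2 @ G + ...` after normalizing `G` by its Frobenius norm. An unrolled sketch of the same iteration for the 2-D case (illustrative only; the function name is made up here):

```python
from tinygrad import Tensor

def newton_schulz_unrolled(G: Tensor, steps: int, params: tuple[float, ...], eps: float = 1e-7) -> Tensor:
  # normalize by the Frobenius norm so the iteration starts in its convergence region
  G = G / (G.square().sum().sqrt() + eps)
  for _ in range(steps):
    A = G @ G.transpose(-2, -1)
    acc, term = G * 0.0, G
    for p in params:      # params[i] multiplies (G @ G.T)^i @ G
      acc = acc + p * term
      term = A @ term
    G = acc
  return G

print(newton_schulz_unrolled(Tensor.randn(4, 4), steps=5, params=(2, -1.5, 0.5)).numpy())
```

With suitable coefficients and enough steps this drives the singular values of `G` toward 1, i.e. approximate orthogonalization, which is how Muon-style optimizers use it.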
@@ -2701,17 +2648,6 @@ class Tensor(OpMixin):
   # ***** Tensor Properties *****
-
-  def nbytes(self) -> int:
-    """
-    Returns the total number of bytes of all elements in the tensor.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor([8, 9], dtype=dtypes.float)
-    print(t.nbytes())
-    ```
-    """
-    return int(self.numel()) * self.element_size()
 
   def size(self, dim:int|None=None) -> sint|tuple[sint, ...]:
     """
     Returns the size of the tensor. If `dim` is specified, return the length along dimension `dim`. Otherwise return the shape of the tensor.
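`nbytes` is just element count times per-element size. A quick illustration of the values involved (expected output noted in comments):

```python
from tinygrad import Tensor, dtypes

t = Tensor([8, 9], dtype=dtypes.float)          # 2 elements, float32 = 4 bytes each
print(t.numel(), t.element_size(), t.nbytes())  # -> 2 4 8
```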