move some methods to mixins (#13725)

* move some methods to mixins

* a few more

* math trunc
This commit is contained in:
George Hotz
2025-12-16 19:20:04 -04:00
committed by GitHub
parent c6ba016da6
commit 0fb645cc4c
2 changed files with 239 additions and 242 deletions

View File

@@ -1,3 +1,4 @@
import math
from typing import Self
from tinygrad.uop import Ops
from tinygrad.dtype import dtypes, ConstType
@@ -258,6 +259,13 @@ class MathMixin:
return self.alu(Ops.RECIPROCAL)
def trunc(self):
  """
  Truncates the tensor element-wise.

  Dispatches the `Ops.TRUNC` op through `self.alu`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).trunc().numpy())
  ```
  """
  truncated = self.alu(Ops.TRUNC)
  return truncated
def sqrt(self):
@@ -277,3 +285,232 @@ class MathMixin:
def __pow__(self, x: Self | ConstType):
  """Implements the `**` operator by forwarding to `self.pow(x)`."""
  powered = self.pow(x)
  return powered
def square(self):
  """
  Squares the tensor element-wise.

  The result is exactly `self*self`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).square().numpy())
  ```
  """
  squared = self * self
  return squared
def clamp(self, min_=None, max_=None):
  """
  Clips (clamps) the values in the tensor between `min_` and `max_` element-wise.

  A bound left as `None` is not applied: `min_=None` means no lower bound, `max_=None` means no upper bound.
  Raises `RuntimeError` when both bounds are `None`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).clip(-1, 1).numpy())
  ```
  """
  if min_ is None and max_ is None:
    raise RuntimeError("at least one of 'min_' or 'max_' must not be None")
  bounded = self
  # lower bound first, then the upper bound is applied to the already lower-bounded values
  if min_ is not None:
    bounded = (bounded < min_).where(min_, bounded)
  if max_ is not None:
    bounded = (bounded > max_).where(max_, bounded)
  return bounded
def clip(self, min_=None, max_=None):
  """
  Alias for `Tensor.clamp`.

  Both bounds are forwarded positionally to `self.clamp`.
  """
  clipped = self.clamp(min_, max_)
  return clipped
def isnan(self):
  """
  Checks the tensor element-wise to return True where the element is NaN, otherwise returns False

  Implemented as `self != self`: NaN is the only value that compares unequal to itself.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isnan().numpy())
  ```
  """
  differs_from_itself = self != self
  return differs_from_itself
def isinf(self, detect_positive: bool = True, detect_negative: bool = True):
  """
  Checks the tensor element-wise to return True where the element is infinity, otherwise returns False

  `detect_positive` and `detect_negative` gate the `+inf` and `-inf` matches respectively.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isinf().numpy())
  ```
  """
  # each equality mask is multiplied by its flag, so a False flag zeroes that side out
  positive_hits = self.eq(float("inf")) * detect_positive
  negative_hits = self.eq(float("-inf")) * detect_negative
  return positive_hits + negative_hits
def isfinite(self):
  """
  Checks the tensor element-wise to return True where the element is finite, otherwise returns False

  An element counts as finite when it is neither infinite (`isinf`) nor NaN (`isnan`).

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isfinite().numpy())
  ```
  """
  non_finite = self.isinf() | self.isnan()
  return non_finite.logical_not()
def ceil(self):
  """
  Rounds the tensor element-wise towards positive infinity.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).ceil().numpy())
  ```
  """
  whole = self.trunc()
  # where the value is above its truncation, step up by one; otherwise the truncation is the answer
  above_truncation = self > whole
  return above_truncation.where(whole + 1, whole)
def floor(self):
  """
  Rounds the tensor element-wise towards negative infinity.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).floor().numpy())
  ```
  """
  whole = self.trunc()
  # where the value is below its truncation, step down by one; otherwise the truncation is the answer
  below_truncation = self < whole
  return below_truncation.where(whole - 1, whole)
def relu(self):
  """
  Applies the Rectified Linear Unit (ReLU) function element-wise.

  Strictly positive elements pass through; every other element becomes 0.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy())
  ```
  """
  # NOTE: if you write this as self.maximum(0) the gradient is wrong, passing through half when self is 0
  is_positive = self > 0
  return is_positive.where(self, 0)
def sigmoid(self):
  """
  Applies the Sigmoid function element-wise.

  - Described: https://en.wikipedia.org/wiki/Sigmoid_function

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sigmoid().numpy())
  ```
  """
  # exp(-x) is expressed as exp2(x * (-1/ln 2)) so only the exp2 op is needed
  neg_inv_ln2 = -1/math.log(2)
  denominator = 1 + (self * neg_inv_ln2).exp2()
  return denominator.reciprocal()
def relu6(self):
  """
  Applies the ReLU6 function element-wise.

  - Paper: https://arxiv.org/abs/1704.04861v1

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-9., -6., -3., 0., 3., 6., 9.]).relu6().numpy())
  ```
  """
  # relu(x) - relu(x-6): the second term cancels any growth past 6
  rectified = self.relu()
  overshoot = (self - 6).relu()
  return rectified - overshoot
def hardswish(self):
  """
  Applies the Hardswish function element-wise.

  - Paper: https://arxiv.org/abs/1905.02244v5

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).hardswish().numpy())
  ```
  """
  gate = (self + 3).relu6()
  return (self * gate) * (1/6)
def hardsigmoid(self, alpha: float = 1/6, beta: float = 0.5):
  """
  Applies the Hardsigmoid function element-wise.

  NOTE: default `alpha` and `beta` values are taken from torch

  - See: https://pytorch.org/docs/stable/generated/torch.nn.functional.hardsigmoid.html

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).hardsigmoid().numpy())
  ```
  """
  # relu(a*x+b) - relu(a*x+b-1): the subtracted term cancels growth once a*x+b exceeds 1
  rising = (alpha * self + beta).relu()
  excess = (alpha * self + beta - 1).relu()
  return rising - excess
def hardtanh(self, min_val=-1, max_val=1):
  """
  Applies the Hardtanh function element-wise.

  Equivalent to clipping the tensor to `[min_val, max_val]` with `self.clip`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).hardtanh().numpy())
  ```
  """
  clipped = self.clip(min_val, max_val)
  return clipped
def leaky_relu(self, neg_slope=0.01):
  """
  Applies the Leaky ReLU function element-wise.

  Negative elements are scaled by `neg_slope`; all other elements are returned unchanged.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu(neg_slope=0.42).numpy())
  ```
  """
  is_negative = self < 0
  scaled = neg_slope*self
  return is_negative.where(scaled, self)
def tanh(self):
  """
  Applies the Hyperbolic Tangent (tanh) function element-wise.

  - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Tanh

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).tanh().numpy())
  ```
  """
  # computed from sigmoid: 2*sigmoid(2x) - 1
  doubled = 2.0 * self
  return 2.0 * doubled.sigmoid() - 1.0
def quick_gelu(self):
  """
  Applies the Sigmoid GELU approximation element-wise.

  Computed as `x * sigmoid(1.702 * x)`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).quick_gelu().numpy())
  ```
  """
  gate = (self * 1.702).sigmoid()
  return self * gate
def gelu(self):
  """
  Applies the Gaussian Error Linear Unit (GELU) function element-wise.

  - Paper: https://arxiv.org/abs/1606.08415v5

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).gelu().numpy())
  ```
  """
  # tanh-based form: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
  cubic = self + 0.044715 * self ** 3
  inner = math.sqrt(2 / math.pi) * cubic
  return 0.5 * self * (1 + inner.tanh())
def swish(self):
  """
  See `.silu()`

  Computed as `x * sigmoid(x)`.

  - Paper: https://arxiv.org/abs/1710.05941v1

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).swish().numpy())
  ```
  """
  gate = self.sigmoid()
  return self * gate
def silu(self):
  """
  Applies the Sigmoid Linear Unit (SiLU) function element-wise.

  - Paper: https://arxiv.org/abs/1606.08415

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).silu().numpy())
  ```
  """
  # The SiLU function is also known as the swish function.
  activated = self.swish()
  return activated
def rsqrt(self):
  """
  Computes the reciprocal of the square root of the tensor element-wise.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([1., 2., 3., 4.]).rsqrt().numpy())
  ```
  """
  root = self.sqrt()
  return root.reciprocal()

View File

@@ -190,8 +190,9 @@ class Tensor(OpMixin):
lhs,rhs = self._broadcasted(x, reverse)
return lhs._apply_uop(fxn, rhs)
# _binop is used by MathTrait
# _binop and alu are used by MathMixin
def _binop(self, op, x, reverse):
  # Builds a `UOp.alu` call for `op` and hands it, with `x` and `reverse`, to `_apply_broadcasted_uop`.
  def _emit_alu(*uops):
    first, rest = uops[0], uops[1:]
    return UOp.alu(first, op, *rest)
  return self._apply_broadcasted_uop(_emit_alu, x, reverse)
def alu(self, op: Ops, *src: Tensor) -> Tensor:
  # Applies `op` via `_apply_uop`: the first uop's `.alu` is called with `op` and the remaining uops.
  def _emit_alu(*uops):
    first, rest = uops[0], uops[1:]
    return first.alu(op, *rest)
  return self._apply_uop(_emit_alu, *src)
def requires_grad_(self, requires_grad=True) -> Tensor:
self.requires_grad = requires_grad
@@ -2790,29 +2791,6 @@ class Tensor(OpMixin):
"""
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.exp2)
def relu(self) -> Tensor:
  """
  Applies the Rectified Linear Unit (ReLU) function element-wise.

  Elements greater than 0 are kept; all others are replaced with 0.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy())
  ```
  """
  # NOTE: if you write this as self.maximum(0) the gradient is wrong, passing through half when self is 0
  keep_mask = self>0
  return keep_mask.where(self, 0)
def sigmoid(self) -> Tensor:
  """
  Applies the Sigmoid function element-wise.

  - Described: https://en.wikipedia.org/wiki/Sigmoid_function

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sigmoid().numpy())
  ```
  """
  # 1 / (1 + exp(-x)), with exp(-x) written as exp2(x * (-1/ln 2))
  exp_neg = (self * (-1/math.log(2))).exp2()
  return (1 + exp_neg).reciprocal()
def logsigmoid(self) -> Tensor:
"""
Applies the LogSigmoid function element-wise.
@@ -2825,19 +2803,6 @@ class Tensor(OpMixin):
"""
return -(-self).softplus()
def hardsigmoid(self, alpha:float=1/6, beta:float=0.5) -> Tensor:
  """
  Applies the Hardsigmoid function element-wise.

  NOTE: default `alpha` and `beta` values are taken from torch

  - See: https://pytorch.org/docs/stable/generated/torch.nn.functional.hardsigmoid.html

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).hardsigmoid().numpy())
  ```
  """
  # difference of two ReLUs of the affine map alpha*x+beta, the second shifted down by 1
  lower_ramp = (alpha * self + beta).relu()
  upper_ramp = (alpha * self + beta - 1).relu()
  return lower_ramp - upper_ramp
def sqrt(self) -> Tensor:
"""
Computes the square root of the tensor element-wise.
@@ -2848,16 +2813,6 @@ class Tensor(OpMixin):
"""
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sqrt)
def rsqrt(self) -> Tensor:
  """
  Computes the reciprocal of the square root of the tensor element-wise.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([1., 2., 3., 4.]).rsqrt().numpy())
  ```
  """
  sqrt_value = self.sqrt()
  return sqrt_value.reciprocal()
def sin(self) -> Tensor:
"""
Computes the sine of the tensor element-wise.
@@ -2924,36 +2879,6 @@ class Tensor(OpMixin):
# ***** math functions *****
def trunc(self: Tensor) -> Tensor:
  """
  Truncates the tensor element-wise.

  Applies `UOp.trunc` through `_apply_uop`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).trunc().numpy())
  ```
  """
  truncated = self._apply_uop(UOp.trunc)
  return truncated
def ceil(self: Tensor) -> Tensor:
  """
  Rounds the tensor element-wise towards positive infinity.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).ceil().numpy())
  ```
  """
  base = self.trunc()
  # bump by one only where the truncated value fell below the input
  needs_bump = self > base
  return needs_bump.where(base+1, base)
def floor(self: Tensor) -> Tensor:
  """
  Rounds the tensor element-wise towards negative infinity.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).floor().numpy())
  ```
  """
  base = self.trunc()
  # drop by one only where the truncated value ended up above the input
  needs_drop = self < base
  return needs_drop.where(base-1, base)
def round(self: Tensor) -> Tensor:
"""
Rounds the tensor element-wise with rounding half to even.
@@ -2964,36 +2889,6 @@ class Tensor(OpMixin):
"""
return ((self > 0) == ((b := self.trunc() / 2.0).trunc() == b)).where((self - 0.5).ceil(), (self + 0.5).floor())
def isinf(self:Tensor, detect_positive:bool=True, detect_negative:bool=True) -> Tensor:
  """
  Checks the tensor element-wise to return True where the element is infinity, otherwise returns False

  `detect_positive` / `detect_negative` switch the `+inf` / `-inf` checks on or off.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isinf().numpy())
  ```
  """
  # multiplying each mask by its flag disables that side when the flag is False
  pos_mask = (self == float("inf")) * detect_positive
  neg_mask = (self == float("-inf")) * detect_negative
  return pos_mask + neg_mask
def isnan(self:Tensor) -> Tensor:
  """
  Checks the tensor element-wise to return True where the element is NaN, otherwise returns False

  Relies on NaN being the only value for which `x != x` holds.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isnan().numpy())
  ```
  """
  nan_mask = self != self
  return nan_mask
def isfinite(self:Tensor) -> Tensor:
  """
  Checks the tensor element-wise to return True where the element is finite, otherwise returns False

  Finite means neither `isinf()` nor `isnan()` is set for the element.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isfinite().numpy())
  ```
  """
  inf_or_nan = self.isinf()|self.isnan()
  return inf_or_nan.logical_not()
def lerp(self, end:Tensor, weight:Tensor|float) -> Tensor:
"""
Linearly interpolates between `self` and `end` by `weight`.
@@ -3007,36 +2902,6 @@ class Tensor(OpMixin):
return (self+(((end - self).cast(dtypes.int8) * w_i + (1<<W_PREC-1)).cast(dtypes.uint16) >> W_PREC)).cast(dtypes.uint8)
return self + (end - self) * weight
def square(self) -> Tensor:
  """
  Squares the tensor element-wise.

  Equivalent to `self*self`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).square().numpy())
  ```
  """
  product = self*self
  return product
def clamp(self, min_=None, max_=None) -> Tensor:
  """
  Clips (clamps) the values in the tensor between `min_` and `max_` element-wise.

  If `min_` is `None`, there is no lower bound. If `max_` is `None`, there is no upper bound.
  At least one bound must be given, otherwise a `RuntimeError` is raised.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).clip(-1, 1).numpy())
  ```
  """
  if min_ is None and max_ is None:
    raise RuntimeError("at least one of 'min_' or 'max_' must not be None")
  result = self
  # the lower bound is applied before the upper bound
  if min_ is not None:
    result = (result < min_).where(min_, result)
  if max_ is not None:
    result = (result > max_).where(max_, result)
  return result
def clip(self, min_=None, max_=None) -> Tensor:
  """
  Alias for `Tensor.clamp`.

  Forwards `min_` and `max_` unchanged.
  """
  clamped = self.clamp(min_, max_)
  return clamped
def sign(self) -> Tensor:
"""
Returns the sign of the tensor element-wise.
@@ -3105,66 +2970,6 @@ class Tensor(OpMixin):
"""
return gamma * (self >= 0).detach().where(self, alpha * (self.exp() - 1))
def swish(self) -> Tensor:
  """
  See `.silu()`

  Multiplies the tensor by its own sigmoid.

  - Paper: https://arxiv.org/abs/1710.05941v1

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).swish().numpy())
  ```
  """
  sig = self.sigmoid()
  return self * sig
def silu(self) -> Tensor:
  """
  Applies the Sigmoid Linear Unit (SiLU) function element-wise.

  - Paper: https://arxiv.org/abs/1606.08415

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).silu().numpy())
  ```
  """
  # The SiLU function is also known as the swish function.
  result = self.swish()
  return result
def relu6(self) -> Tensor:
  """
  Applies the ReLU6 function element-wise.

  - Paper: https://arxiv.org/abs/1704.04861v1

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-9., -6., -3., 0., 3., 6., 9.]).relu6().numpy())
  ```
  """
  # subtracting relu(x-6) caps the output at 6
  base = self.relu()
  above_six = (self-6).relu()
  return base - above_six
def hardswish(self) -> Tensor:
  """
  Applies the Hardswish function element-wise.

  - Paper: https://arxiv.org/abs/1905.02244v5

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).hardswish().numpy())
  ```
  """
  # x * relu6(x+3) / 6
  shifted_gate = (self+3).relu6()
  return (self * shifted_gate) * (1/6)
def tanh(self) -> Tensor:
  """
  Applies the Hyperbolic Tangent (tanh) function element-wise.

  - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Tanh

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).tanh().numpy())
  ```
  """
  # expressed through sigmoid: 2*sigmoid(2x) - 1
  sig_of_double = (2.0 * self).sigmoid()
  return 2.0 * sig_of_double - 1.0
def sinh(self) -> Tensor:
"""
Applies the Hyperbolic Sine (sinh) function element-wise.
@@ -3225,16 +3030,6 @@ class Tensor(OpMixin):
"""
return (self + (self.square() - 1).sqrt()).log()
def hardtanh(self, min_val=-1, max_val=1) -> Tensor:
  """
  Applies the Hardtanh function element-wise.

  Delegates to `self.clip(min_val, max_val)`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).hardtanh().numpy())
  ```
  """
  limited = self.clip(min_val, max_val)
  return limited
def erf(self) -> Tensor:
"""
Applies error function element-wise.
@@ -3249,41 +3044,6 @@ class Tensor(OpMixin):
t = 1.0 / (1.0 + 0.3275911 * self.abs())
return self.sign() * (1.0 - t * polyN(t, [1.061405429, -1.453152027, 1.421413741, -0.284496736, 0.254829592]) * (-self.square()).exp())
def gelu(self) -> Tensor:
  """
  Applies the Gaussian Error Linear Unit (GELU) function element-wise.

  - Paper: https://arxiv.org/abs/1606.08415v5

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).gelu().numpy())
  ```
  """
  # tanh formulation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
  poly = self + 0.044715 * self ** 3
  tanh_term = (math.sqrt(2 / math.pi) * poly).tanh()
  return 0.5 * self * (1 + tanh_term)
def quick_gelu(self) -> Tensor:
  """
  Applies the Sigmoid GELU approximation element-wise.

  The tensor is multiplied by the sigmoid of itself scaled by 1.702.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).quick_gelu().numpy())
  ```
  """
  scaled_sigmoid = (self * 1.702).sigmoid()
  return self * scaled_sigmoid
def leaky_relu(self, neg_slope=0.01) -> Tensor:
  """
  Applies the Leaky ReLU function element-wise.

  Elements below 0 are multiplied by `neg_slope`; the rest pass through unchanged.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu(neg_slope=0.42).numpy())
  ```
  """
  below_zero = self<0
  leaked = neg_slope*self
  return below_zero.where(leaked, self)
def mish(self) -> Tensor:
"""
Applies the Mish function element-wise.