move more to mixins (#14780)

* move more to mixins

* revert

* move some

* do not change

* more

* fix tests

* Revert "more"

This reverts commit d942d59fa4.

* go

* work

* more

* work

* guard

* base
This commit is contained in:
George Hotz
2026-02-16 11:35:00 +08:00
committed by GitHub
parent 8e7c5f5b09
commit 0abcb9aac2
7 changed files with 233 additions and 243 deletions

0
test/amd/__init__.py Normal file
View File

View File

@@ -745,7 +745,7 @@ def _compile_vop3(inst: ir3.VOP3 | ir4.VOP3 | irc.VOP3, ctx: _Ctx) -> UOp:
# VOP3 specific fields
vdst_reg = ctx.inst_field(type(inst).vdst)
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None # type: ignore[union-attr]
abs_bits, neg_bits = getattr(inst, 'abs', 0) or 0, getattr(inst, 'neg', 0) or 0
# VOP3_SDST: v_s_* instructions go to SGPR

View File

@@ -251,8 +251,8 @@ class TestSubstitute(unittest.TestCase):
# this works because there's nothing above the substituted node
def test_sin(self):
a = UOp.variable('a', 0, 10)
b = UOp.variable('b', 0, 10)
a = UOp.variable('a', 0, 10, dtype=dtypes.float)
b = UOp.variable('b', 0, 10, dtype=dtypes.float)
ret = a.sin().sin()
ret = substitute(ret, {a.sin():b})
self.assertIs(ret, b.sin())
@@ -268,14 +268,14 @@ class TestSubstitute(unittest.TestCase):
ret = substitute(ret, {n1:n1.sqrt()})
def test_sin_to_sqrt(self):
a = UOp.variable('a', 0, 10)
a = UOp.variable('a', 0, 10, dtype=dtypes.float)
n1 = a.sin()
ret = n1.sin()
ret = substitute(ret, {a.sin():a.sqrt()})
self.assertIs(ret, a.sqrt().sin())
def test_double_sin_to_sqrt(self):
a = UOp.variable('a', 0, 10)
a = UOp.variable('a', 0, 10, dtype=dtypes.float)
n1 = a.sin()
ret = n1.sin()
# NOTE: this would work if it had gone in the opposite order

View File

@@ -28,7 +28,7 @@ class DTypeMixin:
print(t.is_floating_point())
```
"""
return dtypes.is_float(self.dtype)
return dtypes.is_float(self.dtype.base)
def float(self) -> Self:
"""

View File

@@ -1,7 +1,8 @@
import math
from typing import Self
from tinygrad.uop import Ops
from tinygrad.dtype import dtypes, ConstType
from tinygrad.dtype import dtypes, ConstType, least_upper_dtype, least_upper_float
from tinygrad.helpers import polyN
from tinygrad.mixin.dtype import DTypeMixin
@@ -261,8 +262,18 @@ class ElementwiseMixin(DTypeMixin):
def threefry(self, seed: Self) -> Self:
return self.alu(Ops.THREEFRY, seed)
def _ensure_float(self) -> Self:
  """
  Return `self` unchanged when it is already floating point; otherwise return it cast to
  `least_upper_float(self.dtype)`.
  """
  if self.is_floating_point(): return self
  return self.cast(least_upper_float(self.dtype))
def reciprocal(self) -> Self:
return self.alu(Ops.RECIPROCAL)
"""
Computes `1/x` element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([1., 2., 3., 4.]).reciprocal().numpy())
```
"""
return self._ensure_float().alu(Ops.RECIPROCAL)
def trunc(self) -> Self:
"""
@@ -275,16 +286,73 @@ class ElementwiseMixin(DTypeMixin):
return self.alu(Ops.TRUNC)
def sqrt(self) -> Self:
return self.alu(Ops.SQRT)
"""
Computes the square root of the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([1., 2., 3., 4.]).sqrt().numpy())
```
"""
return self._ensure_float().alu(Ops.SQRT)
def sin(self) -> Self:
return self.alu(Ops.SIN)
"""
Computes the sine of the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy())
```
"""
return self._ensure_float().alu(Ops.SIN)
def cos(self) -> Self:
  """
  Computes the cosine of the tensor element-wise, evaluated as `sin(pi/2 - x)`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).cos().numpy())
  ```
  """
  half_pi = math.pi/2
  # non-float inputs go straight through `sin` with no cast back
  if not self.is_floating_point(): return (half_pi - self).sin()
  # float inputs are computed in the least upper dtype of (self.dtype, float32), then cast back to self.dtype
  compute_dtype = least_upper_dtype(self.dtype, dtypes.float32)
  shifted = half_pi - self.cast(compute_dtype)
  return shifted.sin().cast(self.dtype)
def exp(self) -> Self:
  """
  Computes the exponential function element-wise, evaluated as `exp2(x / ln(2))`.

  See: https://en.wikipedia.org/wiki/Exponential_function

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([0., 1., 2., 3.]).exp().numpy())
  ```
  """
  inv_ln2 = 1/math.log(2)
  # non-float inputs are scaled and passed to `exp2` with no cast back
  if not self.is_floating_point(): return self.mul(inv_ln2).exp2()
  # float inputs are computed in the least upper dtype of (self.dtype, float32), then cast back to self.dtype
  compute_dtype = least_upper_dtype(self.dtype, dtypes.float32)
  return self.cast(compute_dtype).mul(inv_ln2).exp2().cast(self.dtype)
def log2(self) -> Self:
return self.alu(Ops.LOG2)
"""
Computes the base-2 logarithm element-wise.
See: https://en.wikipedia.org/wiki/Logarithm
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([1., 2., 4., 8.]).log2().numpy())
```
"""
return self._ensure_float().alu(Ops.LOG2)
def exp2(self) -> Self:
return self.alu(Ops.EXP2)
"""
Computes the base-2 exponential function element-wise.
See: https://en.wikipedia.org/wiki/Exponential_function
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([0., 1., 2., 3.]).exp2().numpy())
```
"""
return self._ensure_float().alu(Ops.EXP2)
def pow(self, x: Self | ConstType) -> Self:
return self.alu(Ops.POW, self.ufix(x))
@@ -590,3 +658,152 @@ class ElementwiseMixin(DTypeMixin):
```
"""
return ((self > 0).eq((b := self.trunc() / 2.0).trunc().eq(b))).where((self - 0.5).ceil(), (self + 0.5).floor())
def sign(self) -> Self:
  """
  Returns the sign of each element: `-1` where the value is below zero, `1` where it is any other
  non-zero value, and `0` where it equals zero.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy())
  ```
  """
  nonzero_sign = (self < 0).where(self.const_like(-1), self.const_like(1))
  selected = self.ne(0).where(nonzero_sign, self.const_like(0))
  # NOTE(review): the `self * 0` term presumably lets NaN inputs propagate to the output - confirm
  return selected + self * 0
def abs(self) -> Self:
  """
  Computes the absolute value element-wise, as the product of each element with its own sign.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).abs().numpy())
  ```
  """
  sign_of_self = self.sign()
  return self * sign_of_self
def tan(self) -> Self:
  """
  Computes the tangent element-wise, as the ratio `sin(x) / cos(x)`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([0., math.pi/4, math.pi/2, 3*math.pi/4, math.pi]).tan().numpy())
  ```
  """
  numerator = self.sin()
  denominator = self.cos()
  return numerator / denominator
def asin(self) -> Self:
  """
  Computes the inverse sine (arcsine) element-wise using a polynomial approximation.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).asin().numpy())
  ```
  """
  # https://personal.math.ubc.ca/~cbm/aands/page_81.htm 4.4.46
  poly_coeffs = [-0.0012624911, 0.0066700901, -0.0170881256, 0.0308918810, -0.0501743046, 0.0889789874, -0.2145988016, 1.5707963050]
  root_term = (1.0 - self.abs()).sqrt()
  poly_term = polyN(self.abs(), poly_coeffs)
  # the approximation is evaluated on |x|; the sign of the input is reapplied afterwards
  magnitude = math.pi / 2 - root_term * poly_term
  return self.sign() * magnitude
def acos(self) -> Self:
  """
  Computes the inverse cosine (arccosine) element-wise, as `pi/2 - asin(x)`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).acos().numpy())
  ```
  """
  half_pi = math.pi / 2
  return half_pi - self.asin()
def atan(self) -> Self:
  """
  Computes the inverse tangent (arctan) element-wise, as `asin(x / sqrt(1 + x*x))`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).atan().numpy())
  ```
  """
  hypotenuse = (1 + self * self).sqrt()
  ratio = self / hypotenuse
  return ratio.asin()
def elu(self, alpha=1.0) -> Self:
  """
  Applies the Exponential Linear Unit (ELU) function element-wise.

  `alpha` scales the `relu(1 - exp(x))` term that is subtracted from `relu(x)`.

  - Paper: https://arxiv.org/abs/1511.07289v5

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).elu().numpy())
  ```
  """
  positive_part = self.relu()
  negative_part = (1-self.exp()).relu()
  return positive_part - alpha*negative_part
def celu(self, alpha=1.0) -> Self:
  """
  Applies the Continuously differentiable Exponential Linear Unit (CELU) function element-wise.

  Computed as `max(x, 0) + min(alpha * (exp(x / alpha) - 1), 0)`.

  - Paper: https://arxiv.org/abs/1704.07483

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).celu().numpy())
  ```
  """
  positive_part = self.maximum(0)
  scaled_exp = alpha * ((self / alpha).exp() - 1)
  return positive_part + scaled_exp.minimum(0)
def sinh(self) -> Self:
  """
  Applies the Hyperbolic Sine (sinh) function element-wise, as `(exp(x) - exp(-x)) / 2`.

  - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Sinh

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sinh().numpy())
  ```
  """
  grow = self.exp()
  decay = self.neg().exp()
  return (grow - decay) / 2
def cosh(self) -> Self:
  """
  Applies the Hyperbolic Cosine (cosh) function element-wise, as `(exp(x) + exp(-x)) / 2`.

  - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Cosh

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).cosh().numpy())
  ```
  """
  grow = self.exp()
  decay = self.neg().exp()
  return (grow + decay) / 2
def erf(self) -> Self:
  """
  Applies the error function element-wise using a polynomial approximation.

  - Described: https://en.wikipedia.org/wiki/Error_function

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).erf().numpy())
  ```
  """
  # https://personal.math.ubc.ca/~cbm/aands/page_299.htm 7.1.26
  t = 1.0 / (1.0 + 0.3275911 * self.abs())
  poly = polyN(t, [1.061405429, -1.453152027, 1.421413741, -0.284496736, 0.254829592])
  gaussian = (-self.square()).exp()
  # the approximation is evaluated on |x|; the sign of the input is reapplied at the end
  magnitude = 1.0 - t * poly * gaussian
  return self.sign() * magnitude
def softsign(self) -> Self:
  """
  Applies the Softsign function element-wise, computed as `x / (1 + |x|)`.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softsign().numpy())
  ```
  """
  denominator = 1 + self.abs()
  return self / denominator
def bitwise_not(self) -> Self:
  """
  Computes the bitwise NOT of `self`.
  Equivalent to `~self`.

  Raises `RuntimeError` when the dtype is neither bool nor an integer type.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([0, 2, 5, 255], dtype="int8").bitwise_not().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([True, False]).bitwise_not().numpy())
  ```
  """
  # bool inputs use logical negation
  if self.dtype == dtypes.bool: return self.logical_not()
  if not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
  # integer inputs are XORed with -1
  return self ^ -1

View File

@@ -7,7 +7,7 @@ if TYPE_CHECKING: import numpy
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ASM_GEMM, ceildiv, fetch, polyN, is_numpy_ndarray, TracingKey, cpu_profile
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ASM_GEMM, ceildiv, fetch, is_numpy_ndarray, TracingKey, cpu_profile
from tinygrad.helpers import suppress_finalizing, disable_gc
from tinygrad.gradient import compute_gradient
from tinygrad.mixin import OpMixin
@@ -189,11 +189,12 @@ class Tensor(OpMixin):
all_tensors[weakref.ref(ret)] = None
return ret
# _binop and alu are used by MathMixin
# _binop, alu, and const_like are used by the mixins
# Binary-op hook: obtains the (lhs, rhs) pair from `_broadcasted(x, reverse)`, then applies `op`
# by calling `alu` on the first argument the lambda receives from `_apply_uop`.
def _binop(self, op, x, reverse):
lhs,rhs = self._broadcasted(x, reverse)
return lhs._apply_uop(lambda *u: u[0].alu(op, *u[1:]), rhs)
# Applies `op` to `self` and `src` by forwarding to `alu` on the arguments that `_apply_uop` passes to the lambda.
def alu(self, op: Ops, *src: Tensor) -> Tensor: return self._apply_uop(lambda *u: u[0].alu(op, *u[1:]), *src)
# Builds a new Tensor from `b` converted with `dtypes.as_const` to `self.dtype`, on `self.device`, with `requires_grad=False`.
def const_like(self, b:ConstType) -> Tensor: return Tensor(dtypes.as_const(b, self.dtype), self.device, self.dtype, requires_grad=False)
def requires_grad_(self, requires_grad=True) -> Tensor:
# make the UOp unique if it's a CONST to prevent gradient accumulation bugs with cached const UOps
@@ -2844,45 +2845,6 @@ class Tensor(OpMixin):
"""
return self._apply_uop(UOp.contiguous_backward)
def log2(self) -> Tensor:
"""
Computes the base-2 logarithm element-wise.
See: https://en.wikipedia.org/wiki/Logarithm
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([1., 2., 4., 8.]).log2().numpy())
```
"""
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.log2)
def exp(self) -> Tensor:
"""
Computes the exponential function element-wise.
See: https://en.wikipedia.org/wiki/Exponential_function
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([0., 1., 2., 3.]).exp().numpy())
```
"""
# TODO: make it generic, and same thing to log and cos
if self.is_floating_point(): return self.cast(least_upper_dtype(self.dtype, dtypes.float32)).mul(1/math.log(2)).exp2().cast(self.dtype)
# TODO: behavior when DEFAULT_FLOAT is bfloat16 and input is int32?
return self.mul(1/math.log(2)).exp2()
def exp2(self) -> Tensor:
"""
Computes the base-2 exponential function element-wise.
See: https://en.wikipedia.org/wiki/Exponential_function
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([0., 1., 2., 3.]).exp2().numpy())
```
"""
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.exp2)
def logsigmoid(self) -> Tensor:
"""
Applies the LogSigmoid function element-wise.
@@ -2895,80 +2857,6 @@ class Tensor(OpMixin):
"""
return -(-self).softplus()
def sqrt(self) -> Tensor:
"""
Computes the square root of the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([1., 2., 3., 4.]).sqrt().numpy())
```
"""
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sqrt)
def sin(self) -> Tensor:
"""
Computes the sine of the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy())
```
"""
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sin)
def cos(self) -> Tensor:
"""
Computes the cosine of the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).cos().numpy())
```
"""
if self.is_floating_point(): return ((math.pi/2)-self.cast(least_upper_dtype(self.dtype, dtypes.float32))).sin().cast(self.dtype)
return ((math.pi/2)-self).sin()
def tan(self) -> Tensor:
"""
Computes the tangent of the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([0., math.pi/4, math.pi/2, 3*math.pi/4, math.pi]).tan().numpy())
```
"""
return self.sin() / self.cos()
def asin(self) -> Tensor:
"""
Computes the inverse sine (arcsine) of the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).asin().numpy())
```
"""
# https://personal.math.ubc.ca/~cbm/aands/page_81.htm 4.4.46
coefficients = [-0.0012624911, 0.0066700901, -0.0170881256, 0.0308918810, -0.0501743046, 0.0889789874, -0.2145988016, 1.5707963050]
x = math.pi / 2 - (1.0 - self.abs()).sqrt() * polyN(self.abs(), coefficients)
return self.sign() * x
def acos(self) -> Tensor:
"""
Computes the inverse cosine (arccosine) of the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).acos().numpy())
```
"""
return math.pi / 2 - self.asin()
def atan(self) -> Tensor:
"""
Computes the inverse tangent (arctan) of the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).atan().numpy())
```
"""
return (self / (1 + self * self).sqrt()).asin()
# ***** math functions *****
def lerp(self, end:Tensor, weight:Tensor|float) -> Tensor:
@@ -2984,62 +2872,8 @@ class Tensor(OpMixin):
return (self+(((end - self).cast(dtypes.int8) * w_i + (1<<W_PREC-1)).cast(dtypes.uint16) >> W_PREC)).cast(dtypes.uint8)
return self + (end - self) * weight
def sign(self) -> Tensor:
"""
Returns the sign of the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy())
```
"""
return self.ne(0).where((self<0).where(self.full_like(-1), self.full_like(1)), self.full_like(0)) + self*0
def abs(self) -> Tensor:
"""
Computes the absolute value of the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).abs().numpy())
```
"""
return self * self.sign()
def reciprocal(self) -> Tensor:
"""
Computes `1/x` element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([1., 2., 3., 4.]).reciprocal().numpy())
```
"""
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.reciprocal)
# ***** activation functions *****
def elu(self, alpha=1.0) -> Tensor:
"""
Applies the Exponential Linear Unit (ELU) function element-wise.
- Paper: https://arxiv.org/abs/1511.07289v5
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).elu().numpy())
```
"""
return self.relu() - alpha*(1-self.exp()).relu()
def celu(self, alpha=1.0) -> Tensor:
"""
Applies the Continuously differentiable Exponential Linear Unit (CELU) function element-wise.
- Paper: https://arxiv.org/abs/1704.07483
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).celu().numpy())
```
"""
return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
def selu(self, alpha=1.67326, gamma=1.0507) -> Tensor:
"""
Applies the Scaled Exponential Linear Unit (SELU) function element-wise.
@@ -3052,44 +2886,6 @@ class Tensor(OpMixin):
"""
return gamma * (self >= 0).detach().where(self, alpha * (self.exp() - 1))
def sinh(self) -> Tensor:
"""
Applies the Hyperbolic Sine (sinh) function element-wise.
- Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Sinh
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sinh().numpy())
```
"""
return (self.exp() - self.neg().exp()) / 2
def cosh(self) -> Tensor:
"""
Applies the Hyperbolic Cosine (cosh) function element-wise.
- Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Cosh
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).cosh().numpy())
```
"""
return (self.exp() + self.neg().exp()) / 2
def erf(self) -> Tensor:
"""
Applies error function element-wise.
- Described: https://en.wikipedia.org/wiki/Error_function
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).erf().numpy())
```
"""
# https://personal.math.ubc.ca/~cbm/aands/page_299.htm 7.1.26
t = 1.0 / (1.0 + 0.3275911 * self.abs())
return self.sign() * (1.0 - t * polyN(t, [1.061405429, -1.453152027, 1.421413741, -0.284496736, 0.254829592]) * (-self.square()).exp())
def mish(self) -> Tensor:
"""
Applies the Mish function element-wise.
@@ -3112,16 +2908,6 @@ class Tensor(OpMixin):
"""
return (1/beta) * (self*beta).logaddexp(0.0)
def softsign(self) -> Tensor:
"""
Applies the Softsign function element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softsign().numpy())
```
"""
return self / (1 + self.abs())
# ***** broadcasted elementwise ops *****
def _broadcasted(self, y:Tensor|ConstType|UOp, reverse:bool=False, match_dtype:bool=True, backward_cast:bool=True) -> tuple[Tensor, Tensor]:
@@ -3216,20 +3002,6 @@ class Tensor(OpMixin):
a, b = self._broadcasted(x, reverse)
return a - a.div(b, rounding_mode="floor") * b
def bitwise_not(self) -> Tensor:
"""
Computes the bitwise NOT of `self`.
Equivalent to `~self`.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([0, 2, 5, 255], dtype="int8").bitwise_not().numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([True, False]).bitwise_not().numpy())
```
"""
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
return self.logical_not() if self.dtype == dtypes.bool else self ^ -1
def lshift(self, x:Tensor|int, reverse=False) -> Tensor:
"""
Computes left arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.

View File

@@ -926,6 +926,7 @@ class UPat(OpMixin):
def dtype(self) -> DType: return self.match_dtype[0] if self.match_dtype is not None else dtypes.void
def _check_dtype(self) -> None: pass
# Returns the pattern unchanged: no cast is inserted for a UPat.
def _ensure_float(self) -> UPat: return self
def __reduce__(self):
return UPat, (self.op, self.match_dtype, self._in_src, self.arg, self.name, not self.strict_length, self.custom_early_reject, self.location)