From 0abcb9aac2c57fd9ef8fc4fe22ecf8a5f9e53055 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 16 Feb 2026 11:35:00 +0800 Subject: [PATCH] move more to mixins (#14780) * move more to mixins * revert * move some * do not change * more * fix tests * Revert "more" This reverts commit d942d59fa4fad1d382008074895bf514fd9b63ef. * go * work * more * work * guard * base --- test/amd/__init__.py | 0 test/mockgpu/amd/emu.py | 2 +- test/null/test_graph_rewrite.py | 8 +- tinygrad/mixin/dtype.py | 2 +- tinygrad/mixin/elementwise.py | 229 ++++++++++++++++++++++++++++++- tinygrad/tensor.py | 234 +------------------------------- tinygrad/uop/ops.py | 1 + 7 files changed, 233 insertions(+), 243 deletions(-) create mode 100644 test/amd/__init__.py diff --git a/test/amd/__init__.py b/test/amd/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/mockgpu/amd/emu.py b/test/mockgpu/amd/emu.py index cc1db823e5..ce9fb8b385 100644 --- a/test/mockgpu/amd/emu.py +++ b/test/mockgpu/amd/emu.py @@ -745,7 +745,7 @@ def _compile_vop3(inst: ir3.VOP3 | ir4.VOP3 | irc.VOP3, ctx: _Ctx) -> UOp: # VOP3 specific fields vdst_reg = ctx.inst_field(type(inst).vdst) - literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None + literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None # type: ignore[union-attr] abs_bits, neg_bits = getattr(inst, 'abs', 0) or 0, getattr(inst, 'neg', 0) or 0 # VOP3_SDST: v_s_* instructions goes to SGPR diff --git a/test/null/test_graph_rewrite.py b/test/null/test_graph_rewrite.py index bff8c7d055..b6f9749c1c 100644 --- a/test/null/test_graph_rewrite.py +++ b/test/null/test_graph_rewrite.py @@ -251,8 +251,8 @@ class TestSubstitute(unittest.TestCase): # this works because there's nothing above the substituted node def test_sin(self): - a = UOp.variable('a', 0, 10) - b = UOp.variable('b', 0, 10) + a = UOp.variable('a', 0, 10, dtype=dtypes.float) + b = UOp.variable('b', 0, 10, dtype=dtypes.float) ret = a.sin().sin() ret = substitute(ret, {a.sin():b}) self.assertIs(ret, b.sin()) @@ -268,14 +268,14 @@ class TestSubstitute(unittest.TestCase): ret = substitute(ret, {n1:n1.sqrt()}) def test_sin_to_sqrt(self): - a = UOp.variable('a', 0, 10) + a = UOp.variable('a', 0, 10, dtype=dtypes.float) n1 = a.sin() ret = n1.sin() ret = substitute(ret, {a.sin():a.sqrt()}) self.assertIs(ret, a.sqrt().sin()) def test_double_sin_to_sqrt(self): - a = UOp.variable('a', 0, 10) + a = UOp.variable('a', 0, 10, dtype=dtypes.float) n1 = a.sin() ret = n1.sin() # NOTE: this would work if it had gone in the opposite order diff --git a/tinygrad/mixin/dtype.py b/tinygrad/mixin/dtype.py index 9094925207..90e9f4dab5 100644 --- a/tinygrad/mixin/dtype.py +++ b/tinygrad/mixin/dtype.py @@ -28,7 +28,7 @@ class DTypeMixin: print(t.is_floating_point()) ``` """ - return dtypes.is_float(self.dtype) + return dtypes.is_float(self.dtype.base) def float(self) -> Self: """ diff --git a/tinygrad/mixin/elementwise.py b/tinygrad/mixin/elementwise.py index ed1f2b06d0..6aba802970 100644 --- a/tinygrad/mixin/elementwise.py +++ b/tinygrad/mixin/elementwise.py @@ -1,7 +1,8 @@ import math from typing import Self from tinygrad.uop import Ops -from tinygrad.dtype import dtypes, ConstType +from tinygrad.dtype import dtypes, ConstType, least_upper_dtype, least_upper_float +from tinygrad.helpers import polyN from tinygrad.mixin.dtype import DTypeMixin @@ -261,8 +262,18 @@ class ElementwiseMixin(DTypeMixin): def threefry(self, seed: Self) -> Self: return self.alu(Ops.THREEFRY, seed) + def _ensure_float(self) -> Self: + return self if self.is_floating_point() else self.cast(least_upper_float(self.dtype)) + def reciprocal(self) -> Self: - return self.alu(Ops.RECIPROCAL) + """ + Computes `1/x` element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([1., 2., 3., 4.]).reciprocal().numpy()) + ``` + """ + return self._ensure_float().alu(Ops.RECIPROCAL) def trunc(self) -> Self: """ @@ -275,16 +286,73 @@ class ElementwiseMixin(DTypeMixin): return self.alu(Ops.TRUNC) def sqrt(self) -> Self: - return self.alu(Ops.SQRT) + """ + Computes the square root of the tensor element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([1., 2., 3., 4.]).sqrt().numpy()) + ``` + """ + return self._ensure_float().alu(Ops.SQRT) def sin(self) -> Self: - return self.alu(Ops.SIN) + """ + Computes the sine of the tensor element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy()) + ``` + """ + return self._ensure_float().alu(Ops.SIN) + + def cos(self) -> Self: + """ + Computes the cosine of the tensor element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).cos().numpy()) + ``` + """ + if self.is_floating_point(): return ((math.pi/2)-self.cast(least_upper_dtype(self.dtype, dtypes.float32))).sin().cast(self.dtype) + return ((math.pi/2)-self).sin() + + def exp(self) -> Self: + """ + Computes the exponential function element-wise. + + See: https://en.wikipedia.org/wiki/Exponential_function + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([0., 1., 2., 3.]).exp().numpy()) + ``` + """ + if self.is_floating_point(): + return self.cast(least_upper_dtype(self.dtype, dtypes.float32)).mul(1/math.log(2)).exp2().cast(self.dtype) + return self.mul(1/math.log(2)).exp2() def log2(self) -> Self: - return self.alu(Ops.LOG2) + """ + Computes the base-2 logarithm element-wise. + + See: https://en.wikipedia.org/wiki/Logarithm + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([1., 2., 4., 8.]).log2().numpy()) + ``` + """ + return self._ensure_float().alu(Ops.LOG2) def exp2(self) -> Self: - return self.alu(Ops.EXP2) + """ + Computes the base-2 exponential function element-wise. + + See: https://en.wikipedia.org/wiki/Exponential_function + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([0., 1., 2., 3.]).exp2().numpy()) + ``` + """ + return self._ensure_float().alu(Ops.EXP2) def pow(self, x: Self | ConstType) -> Self: return self.alu(Ops.POW, self.ufix(x)) @@ -590,3 +658,152 @@ class ElementwiseMixin(DTypeMixin): ``` """ return ((self > 0).eq((b := self.trunc() / 2.0).trunc().eq(b))).where((self - 0.5).ceil(), (self + 0.5).floor()) + + def sign(self) -> Self: + """ + Returns the sign of the tensor element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy()) + ``` + """ + return self.ne(0).where((self < 0).where(self.const_like(-1), self.const_like(1)), self.const_like(0)) + self * 0 + + def abs(self) -> Self: + """ + Computes the absolute value of the tensor element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).abs().numpy()) + ``` + """ + return self * self.sign() + + def tan(self) -> Self: + """ + Computes the tangent of the tensor element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([0., math.pi/4, math.pi/2, 3*math.pi/4, math.pi]).tan().numpy()) + ``` + """ + return self.sin() / self.cos() + + def asin(self) -> Self: + """ + Computes the inverse sine (arcsine) of the tensor element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).asin().numpy()) + ``` + """ + # https://personal.math.ubc.ca/~cbm/aands/page_81.htm 4.4.46 + coefficients = [-0.0012624911, 0.0066700901, -0.0170881256, 0.0308918810, -0.0501743046, 0.0889789874, -0.2145988016, 1.5707963050] + x = math.pi / 2 - (1.0 - self.abs()).sqrt() * polyN(self.abs(), coefficients) + return self.sign() * x + + def acos(self) -> Self: + """ + Computes the inverse cosine (arccosine) of the tensor element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).acos().numpy()) + ``` + """ + return math.pi / 2 - self.asin() + + def atan(self) -> Self: + """ + Computes the inverse tangent (arctan) of the tensor element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).atan().numpy()) + ``` + """ + return (self / (1 + self * self).sqrt()).asin() + + def elu(self, alpha=1.0) -> Self: + """ + Applies the Exponential Linear Unit (ELU) function element-wise. + + - Paper: https://arxiv.org/abs/1511.07289v5 + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).elu().numpy()) + ``` + """ + return self.relu() - alpha*(1-self.exp()).relu() + + def celu(self, alpha=1.0) -> Self: + """ + Applies the Continuously differentiable Exponential Linear Unit (CELU) function element-wise. + + - Paper: https://arxiv.org/abs/1704.07483 + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).celu().numpy()) + ``` + """ + return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0) + + def sinh(self) -> Self: + """ + Applies the Hyperbolic Sine (sinh) function element-wise. + + - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Sinh + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sinh().numpy()) + ``` + """ + return (self.exp() - self.neg().exp()) / 2 + + def cosh(self) -> Self: + """ + Applies the Hyperbolic Cosine (cosh) function element-wise. + + - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Cosh + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).cosh().numpy()) + ``` + """ + return (self.exp() + self.neg().exp()) / 2 + + def erf(self) -> Self: + """ + Applies error function element-wise. + + - Described: https://en.wikipedia.org/wiki/Error_function + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).erf().numpy()) + ``` + """ + # https://personal.math.ubc.ca/~cbm/aands/page_299.htm 7.1.26 + t = 1.0 / (1.0 + 0.3275911 * self.abs()) + return self.sign() * (1.0 - t * polyN(t, [1.061405429, -1.453152027, 1.421413741, -0.284496736, 0.254829592]) * (-self.square()).exp()) + + def softsign(self) -> Self: + """ + Applies the Softsign function element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softsign().numpy()) + ``` + """ + return self / (1 + self.abs()) + + def bitwise_not(self) -> Self: + """ + Computes the bitwise NOT of `self`. + Equivalent to `~self`. + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([0, 2, 5, 255], dtype="int8").bitwise_not().numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([True, False]).bitwise_not().numpy()) + ``` + """ + if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported") + return self.logical_not() if self.dtype == dtypes.bool else self ^ -1 diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 1ed16b48a5..275fb56e7d 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: import numpy from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten -from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ASM_GEMM, ceildiv, fetch, polyN, is_numpy_ndarray, TracingKey, cpu_profile +from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ASM_GEMM, ceildiv, fetch, is_numpy_ndarray, TracingKey, cpu_profile from tinygrad.helpers import suppress_finalizing, disable_gc from tinygrad.gradient import compute_gradient from tinygrad.mixin import OpMixin @@ -189,11 +189,12 @@ class Tensor(OpMixin): all_tensors[weakref.ref(ret)] = None return ret - # _binop and alu are used by MathMixin + # _binop, alu, and const_like are used by the mixins def _binop(self, op, x, reverse): lhs,rhs = self._broadcasted(x, reverse) return lhs._apply_uop(lambda *u: u[0].alu(op, *u[1:]), rhs) def alu(self, op: Ops, *src: Tensor) -> Tensor: return self._apply_uop(lambda *u: u[0].alu(op, *u[1:]), *src) + def const_like(self, b:ConstType) -> Tensor: return Tensor(dtypes.as_const(b, self.dtype), self.device, self.dtype, requires_grad=False) def requires_grad_(self, requires_grad=True) -> Tensor: # make the UOp unique if it's a CONST to prevent gradient accumulation bugs with cached const UOps @@ -2844,45 +2845,6 @@ class Tensor(OpMixin): """ return self._apply_uop(UOp.contiguous_backward) - def log2(self) -> Tensor: - """ - Computes the base-2 logarithm element-wise. - - See: https://en.wikipedia.org/wiki/Logarithm - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([1., 2., 4., 8.]).log2().numpy()) - ``` - """ - return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.log2) - - def exp(self) -> Tensor: - """ - Computes the exponential function element-wise. - - See: https://en.wikipedia.org/wiki/Exponential_function - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([0., 1., 2., 3.]).exp().numpy()) - ``` - """ - # TODO: make it generic, and same thing to log and cos - if self.is_floating_point(): return self.cast(least_upper_dtype(self.dtype, dtypes.float32)).mul(1/math.log(2)).exp2().cast(self.dtype) - # TODO: behavior when DEFAULT_FLOAT is bfloat16 and input is int32? - return self.mul(1/math.log(2)).exp2() - - def exp2(self) -> Tensor: - """ - Computes the base-2 exponential function element-wise. - - See: https://en.wikipedia.org/wiki/Exponential_function - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([0., 1., 2., 3.]).exp2().numpy()) - ``` - """ - return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.exp2) - def logsigmoid(self) -> Tensor: """ Applies the LogSigmoid function element-wise. @@ -2895,80 +2857,6 @@ class Tensor(OpMixin): """ return -(-self).softplus() - def sqrt(self) -> Tensor: - """ - Computes the square root of the tensor element-wise. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([1., 2., 3., 4.]).sqrt().numpy()) - ``` - """ - return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sqrt) - - def sin(self) -> Tensor: - """ - Computes the sine of the tensor element-wise. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy()) - ``` - """ - return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sin) - - def cos(self) -> Tensor: - """ - Computes the cosine of the tensor element-wise. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).cos().numpy()) - ``` - """ - if self.is_floating_point(): return ((math.pi/2)-self.cast(least_upper_dtype(self.dtype, dtypes.float32))).sin().cast(self.dtype) - return ((math.pi/2)-self).sin() - - def tan(self) -> Tensor: - """ - Computes the tangent of the tensor element-wise. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([0., math.pi/4, math.pi/2, 3*math.pi/4, math.pi]).tan().numpy()) - ``` - """ - return self.sin() / self.cos() - - def asin(self) -> Tensor: - """ - Computes the inverse sine (arcsine) of the tensor element-wise. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).asin().numpy()) - ``` - """ - # https://personal.math.ubc.ca/~cbm/aands/page_81.htm 4.4.46 - coefficients = [-0.0012624911, 0.0066700901, -0.0170881256, 0.0308918810, -0.0501743046, 0.0889789874, -0.2145988016, 1.5707963050] - x = math.pi / 2 - (1.0 - self.abs()).sqrt() * polyN(self.abs(), coefficients) - return self.sign() * x - - def acos(self) -> Tensor: - """ - Computes the inverse cosine (arccosine) of the tensor element-wise. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).acos().numpy()) - ``` - """ - return math.pi / 2 - self.asin() - - def atan(self) -> Tensor: - """ - Computes the inverse tangent (arctan) of the tensor element-wise. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).atan().numpy()) - ``` - """ - return (self / (1 + self * self).sqrt()).asin() - # ***** math functions ***** def lerp(self, end:Tensor, weight:Tensor|float) -> Tensor: @@ -2984,62 +2872,8 @@ class Tensor(OpMixin): return (self+(((end - self).cast(dtypes.int8) * w_i + (1<> W_PREC)).cast(dtypes.uint8) return self + (end - self) * weight - def sign(self) -> Tensor: - """ - Returns the sign of the tensor element-wise. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy()) - ``` - """ - return self.ne(0).where((self<0).where(self.full_like(-1), self.full_like(1)), self.full_like(0)) + self*0 - - def abs(self) -> Tensor: - """ - Computes the absolute value of the tensor element-wise. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).abs().numpy()) - ``` - """ - return self * self.sign() - - def reciprocal(self) -> Tensor: - """ - Computes `1/x` element-wise. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([1., 2., 3., 4.]).reciprocal().numpy()) - ``` - """ - return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.reciprocal) - # ***** activation functions ***** - def elu(self, alpha=1.0) -> Tensor: - """ - Applies the Exponential Linear Unit (ELU) function element-wise. - - - Paper: https://arxiv.org/abs/1511.07289v5 - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).elu().numpy()) - ``` - """ - return self.relu() - alpha*(1-self.exp()).relu() - - def celu(self, alpha=1.0) -> Tensor: - """ - Applies the Continuously differentiable Exponential Linear Unit (CELU) function element-wise. - - - Paper: https://arxiv.org/abs/1704.07483 - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).celu().numpy()) - ``` - """ - return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0) - def selu(self, alpha=1.67326, gamma=1.0507) -> Tensor: """ Applies the Scaled Exponential Linear Unit (SELU) function element-wise. @@ -3052,44 +2886,6 @@ class Tensor(OpMixin): """ return gamma * (self >= 0).detach().where(self, alpha * (self.exp() - 1)) - def sinh(self) -> Tensor: - """ - Applies the Hyperbolic Sine (sinh) function element-wise. - - - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Sinh - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sinh().numpy()) - ``` - """ - return (self.exp() - self.neg().exp()) / 2 - - def cosh(self) -> Tensor: - """ - Applies the Hyperbolic Cosine (cosh) function element-wise. - - - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Cosh - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).cosh().numpy()) - ``` - """ - return (self.exp() + self.neg().exp()) / 2 - - def erf(self) -> Tensor: - """ - Applies error function element-wise. - - - Described: https://en.wikipedia.org/wiki/Error_function - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).erf().numpy()) - ``` - """ - # https://personal.math.ubc.ca/~cbm/aands/page_299.htm 7.1.26 - t = 1.0 / (1.0 + 0.3275911 * self.abs()) - return self.sign() * (1.0 - t * polyN(t, [1.061405429, -1.453152027, 1.421413741, -0.284496736, 0.254829592]) * (-self.square()).exp()) - def mish(self) -> Tensor: """ Applies the Mish function element-wise. @@ -3112,16 +2908,6 @@ class Tensor(OpMixin): """ return (1/beta) * (self*beta).logaddexp(0.0) - def softsign(self) -> Tensor: - """ - Applies the Softsign function element-wise. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softsign().numpy()) - ``` - """ - return self / (1 + self.abs()) - # ***** broadcasted elementwise ops ***** def _broadcasted(self, y:Tensor|ConstType|UOp, reverse:bool=False, match_dtype:bool=True, backward_cast:bool=True) -> tuple[Tensor, Tensor]: @@ -3216,20 +3002,6 @@ class Tensor(OpMixin): a, b = self._broadcasted(x, reverse) return a - a.div(b, rounding_mode="floor") * b - def bitwise_not(self) -> Tensor: - """ - Computes the bitwise NOT of `self`. - Equivalent to `~self`. - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([0, 2, 5, 255], dtype="int8").bitwise_not().numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([True, False]).bitwise_not().numpy()) - ``` - """ - if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported") - return self.logical_not() if self.dtype == dtypes.bool else self ^ -1 - def lshift(self, x:Tensor|int, reverse=False) -> Tensor: """ Computes left arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype. diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 0039d54060..9f187e8a82 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -926,6 +926,7 @@ class UPat(OpMixin): def dtype(self) -> DType: return self.match_dtype[0] if self.match_dtype is not None else dtypes.void def _check_dtype(self) -> None: pass + def _ensure_float(self) -> UPat: return self def __reduce__(self): return UPat, (self.op, self.match_dtype, self._in_src, self.arg, self.name, not self.strict_length, self.custom_early_reject, self.location)