diff --git a/tinygrad/mixin/math.py b/tinygrad/mixin/math.py index ca2f761e53..ef30d883d7 100644 --- a/tinygrad/mixin/math.py +++ b/tinygrad/mixin/math.py @@ -1,3 +1,4 @@ +import math from typing import Self from tinygrad.uop import Ops from tinygrad.dtype import dtypes, ConstType @@ -258,6 +259,13 @@ class MathMixin: return self.alu(Ops.RECIPROCAL) def trunc(self): + """ + Truncates the tensor element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).trunc().numpy()) + ``` + """ return self.alu(Ops.TRUNC) def sqrt(self): @@ -277,3 +285,232 @@ class MathMixin: def __pow__(self, x: Self | ConstType): return self.pow(x) + + def square(self): + """ + Squares the tensor element-wise. + Equivalent to `self*self`. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).square().numpy()) + ``` + """ + return self * self + + def clamp(self, min_=None, max_=None): + """ + Clips (clamps) the values in the tensor between `min_` and `max_` element-wise. + If `min_` is `None`, there is no lower bound. If `max_` is None, there is no upper bound. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).clip(-1, 1).numpy()) + ``` + """ + if min_ is None and max_ is None: raise RuntimeError("at least one of 'min_' or 'max_' must not be None") + ret = (self < min_).where(min_, self) if min_ is not None else self + return (ret > max_).where(max_, ret) if max_ is not None else ret + + def clip(self, min_=None, max_=None): + """Alias for `Tensor.clamp`.""" + return self.clamp(min_, max_) + + def isnan(self): + """ + Checks the tensor element-wise to return True where the element is NaN, otherwise returns False + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isnan().numpy()) + ``` + """ + return self != self + + def isinf(self, detect_positive: bool = True, detect_negative: bool = True): + """ + Checks the tensor element-wise to return True where the element is infinity, otherwise returns False + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isinf().numpy()) + ``` + """ + return self.eq(float("inf")) * detect_positive + self.eq(float("-inf")) * detect_negative + + def isfinite(self): + """ + Checks the tensor element-wise to return True where the element is finite, otherwise returns False + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isfinite().numpy()) + ``` + """ + return (self.isinf() | self.isnan()).logical_not() + + def ceil(self): + """ + Rounds the tensor element-wise towards positive infinity. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).ceil().numpy()) + ``` + """ + return (self > (b := self.trunc())).where(b+1, b) + + def floor(self): + """ + Rounds the tensor element-wise towards negative infinity. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).floor().numpy()) + ``` + """ + return (self < (b := self.trunc())).where(b-1, b) + + def relu(self): + """ + Applies the Rectified Linear Unit (ReLU) function element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy()) + ``` + """ + # NOTE: if you write this as self.maximum(0) the gradient is wrong, passing through half when self is 0 + return (self > 0).where(self, 0) + + def sigmoid(self): + """ + Applies the Sigmoid function element-wise. + + - Described: https://en.wikipedia.org/wiki/Sigmoid_function + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sigmoid().numpy()) + ``` + """ + return (1 + (self * (-1/math.log(2))).exp2()).reciprocal() + + def relu6(self): + """ + Applies the ReLU6 function element-wise. + + - Paper: https://arxiv.org/abs/1704.04861v1 + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-9., -6., -3., 0., 3., 6., 9.]).relu6().numpy()) + ``` + """ + return self.relu() - (self-6).relu() + + def hardswish(self): + """ + Applies the Hardswish function element-wise. + + - Paper: https://arxiv.org/abs/1905.02244v5 + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).hardswish().numpy()) + ``` + """ + return self * (self+3).relu6() * (1/6) + + def hardsigmoid(self, alpha: float = 1/6, beta: float = 0.5): + """ + Applies the Hardsigmoid function element-wise. + NOTE: default `alpha` and `beta` values are taken from torch + + - See: https://pytorch.org/docs/stable/generated/torch.nn.functional.hardsigmoid.html + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).hardsigmoid().numpy()) + ``` + """ + return (alpha * self + beta).relu() - (alpha * self + beta - 1).relu() + + def hardtanh(self, min_val=-1, max_val=1): + """ + Applies the Hardtanh function element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).hardtanh().numpy()) + ``` + """ + return self.clip(min_val, max_val) + + def leaky_relu(self, neg_slope=0.01): + """ + Applies the Leaky ReLU function element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu().numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu(neg_slope=0.42).numpy()) + ``` + """ + return (self < 0).where(neg_slope*self, self) + + def tanh(self): + """ + Applies the Hyperbolic Tangent (tanh) function element-wise. + + - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Tanh + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).tanh().numpy()) + ``` + """ + return 2.0 * ((2.0 * self).sigmoid()) - 1.0 + + def quick_gelu(self): + """ + Applies the Sigmoid GELU approximation element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).quick_gelu().numpy()) + ``` + """ + return self * (self * 1.702).sigmoid() + + def gelu(self): + """ + Applies the Gaussian Error Linear Unit (GELU) function element-wise. + + - Paper: https://arxiv.org/abs/1606.08415v5 + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).gelu().numpy()) + ``` + """ + return 0.5 * self * (1 + (math.sqrt(2 / math.pi) * (self + 0.044715 * self ** 3)).tanh()) + + def swish(self): + """ + See `.silu()` + + - Paper: https://arxiv.org/abs/1710.05941v1 + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).swish().numpy()) + ``` + """ + return self * self.sigmoid() + + def silu(self): + """ + Applies the Sigmoid Linear Unit (SiLU) function element-wise. + + - Paper: https://arxiv.org/abs/1606.08415 + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).silu().numpy()) + ``` + """ + return self.swish() # The SiLU function is also known as the swish function. + + def rsqrt(self): + """ + Computes the reciprocal of the square root of the tensor element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([1., 2., 3., 4.]).rsqrt().numpy()) + ``` + """ + return self.sqrt().reciprocal() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 1a5c25959a..d6b68b683b 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -190,8 +190,9 @@ class Tensor(OpMixin): lhs,rhs = self._broadcasted(x, reverse) return lhs._apply_uop(fxn, rhs) - # _binop is used by MathTrait + # _binop and alu are used by MathMixin def _binop(self, op, x, reverse): return self._apply_broadcasted_uop(lambda *u: UOp.alu(u[0], op, *u[1:]), x, reverse) + def alu(self, op: Ops, *src: Tensor) -> Tensor: return self._apply_uop(lambda *u: u[0].alu(op, *u[1:]), *src) def requires_grad_(self, requires_grad=True) -> Tensor: self.requires_grad = requires_grad @@ -2790,29 +2791,6 @@ class Tensor(OpMixin): """ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.exp2) - def relu(self) -> Tensor: - """ - Applies the Rectified Linear Unit (ReLU) function element-wise. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy()) - ``` - """ - # NOTE: if you write this as self.maximum(0) the gradient is wrong, passing through half when self is 0 - return (self>0).where(self, 0) - - def sigmoid(self) -> Tensor: - """ - Applies the Sigmoid function element-wise. - - - Described: https://en.wikipedia.org/wiki/Sigmoid_function - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sigmoid().numpy()) - ``` - """ - return (1 + (self * (-1/math.log(2))).exp2()).reciprocal() - def logsigmoid(self) -> Tensor: """ Applies the LogSigmoid function element-wise. @@ -2825,19 +2803,6 @@ class Tensor(OpMixin): """ return -(-self).softplus() - def hardsigmoid(self, alpha:float=1/6, beta:float=0.5) -> Tensor: - """ - Applies the Hardsigmoid function element-wise. - NOTE: default `alpha` and `beta` values are taken from torch - - - See: https://pytorch.org/docs/stable/generated/torch.nn.functional.hardsigmoid.html - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).hardsigmoid().numpy()) - ``` - """ - return (alpha * self + beta).relu() - (alpha * self + beta - 1).relu() - def sqrt(self) -> Tensor: """ Computes the square root of the tensor element-wise. @@ -2848,16 +2813,6 @@ class Tensor(OpMixin): """ return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sqrt) - def rsqrt(self) -> Tensor: - """ - Computes the reciprocal of the square root of the tensor element-wise. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([1., 2., 3., 4.]).rsqrt().numpy()) - ``` - """ - return self.sqrt().reciprocal() - def sin(self) -> Tensor: """ Computes the sine of the tensor element-wise. @@ -2924,36 +2879,6 @@ class Tensor(OpMixin): # ***** math functions ***** - def trunc(self: Tensor) -> Tensor: - """ - Truncates the tensor element-wise. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).trunc().numpy()) - ``` - """ - return self._apply_uop(UOp.trunc) - - def ceil(self: Tensor) -> Tensor: - """ - Rounds the tensor element-wise towards positive infinity. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).ceil().numpy()) - ``` - """ - return (self > (b := self.trunc())).where(b+1, b) - - def floor(self: Tensor) -> Tensor: - """ - Rounds the tensor element-wise towards negative infinity. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).floor().numpy()) - ``` - """ - return (self < (b := self.trunc())).where(b-1, b) - def round(self: Tensor) -> Tensor: """ Rounds the tensor element-wise with rounding half to even. @@ -2964,36 +2889,6 @@ class Tensor(OpMixin): """ return ((self > 0) == ((b := self.trunc() / 2.0).trunc() == b)).where((self - 0.5).ceil(), (self + 0.5).floor()) - def isinf(self:Tensor, detect_positive:bool=True, detect_negative:bool=True) -> Tensor: - """ - Checks the tensor element-wise to return True where the element is infinity, otherwise returns False - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isinf().numpy()) - ``` - """ - return (self == float("inf")) * detect_positive + (self == float("-inf")) * detect_negative - - def isnan(self:Tensor) -> Tensor: - """ - Checks the tensor element-wise to return True where the element is NaN, otherwise returns False - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isnan().numpy()) - ``` - """ - return self != self - - def isfinite(self:Tensor) -> Tensor: - """ - Checks the tensor element-wise to return True where the element is finite, otherwise returns False - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isfinite().numpy()) - ``` - """ - return (self.isinf()|self.isnan()).logical_not() - def lerp(self, end:Tensor, weight:Tensor|float) -> Tensor: """ Linearly interpolates between `self` and `end` by `weight`. @@ -3007,36 +2902,6 @@ class Tensor(OpMixin): return (self+(((end - self).cast(dtypes.int8) * w_i + (1<> W_PREC)).cast(dtypes.uint8) return self + (end - self) * weight - def square(self) -> Tensor: - """ - Squares the tensor element-wise. - Equivalent to `self*self`. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).square().numpy()) - ``` - """ - return self*self - - def clamp(self, min_=None, max_=None) -> Tensor: - """ - Clips (clamps) the values in the tensor between `min_` and `max_` element-wise. - If `min_` is `None`, there is no lower bound. If `max_` is None, there is no upper bound. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).clip(-1, 1).numpy()) - ``` - """ - if min_ is None and max_ is None: raise RuntimeError("at least one of 'min_' or 'max_' must not be None") - ret = (self < min_).where(min_, self) if min_ is not None else self - return (ret > max_).where(max_, ret) if max_ is not None else ret - - def clip(self, min_=None, max_=None) -> Tensor: - """ - Alias for `Tensor.clamp`. - """ - return self.clamp(min_, max_) - def sign(self) -> Tensor: """ Returns the sign of the tensor element-wise. @@ -3105,66 +2970,6 @@ class Tensor(OpMixin): """ return gamma * (self >= 0).detach().where(self, alpha * (self.exp() - 1)) - def swish(self) -> Tensor: - """ - See `.silu()` - - - Paper: https://arxiv.org/abs/1710.05941v1 - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).swish().numpy()) - ``` - """ - return self * self.sigmoid() - - def silu(self) -> Tensor: - """ - Applies the Sigmoid Linear Unit (SiLU) function element-wise. - - - Paper: https://arxiv.org/abs/1606.08415 - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).silu().numpy()) - ``` - """ - return self.swish() # The SiLU function is also known as the swish function. - - def relu6(self) -> Tensor: - """ - Applies the ReLU6 function element-wise. - - - Paper: https://arxiv.org/abs/1704.04861v1 - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-9., -6., -3., 0., 3., 6., 9.]).relu6().numpy()) - ``` - """ - return self.relu() - (self-6).relu() - - def hardswish(self) -> Tensor: - """ - Applies the Hardswish function element-wise. - - - Paper: https://arxiv.org/abs/1905.02244v5 - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).hardswish().numpy()) - ``` - """ - return self * (self+3).relu6() * (1/6) - - def tanh(self) -> Tensor: - """ - Applies the Hyperbolic Tangent (tanh) function element-wise. - - - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Tanh - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).tanh().numpy()) - ``` - """ - return 2.0 * ((2.0 * self).sigmoid()) - 1.0 - def sinh(self) -> Tensor: """ Applies the Hyperbolic Sine (sinh) function element-wise. @@ -3225,16 +3030,6 @@ class Tensor(OpMixin): """ return (self + (self.square() - 1).sqrt()).log() - def hardtanh(self, min_val=-1, max_val=1) -> Tensor: - """ - Applies the Hardtanh function element-wise. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).hardtanh().numpy()) - ``` - """ - return self.clip(min_val, max_val) - def erf(self) -> Tensor: """ Applies error function element-wise. @@ -3249,41 +3044,6 @@ class Tensor(OpMixin): t = 1.0 / (1.0 + 0.3275911 * self.abs()) return self.sign() * (1.0 - t * polyN(t, [1.061405429, -1.453152027, 1.421413741, -0.284496736, 0.254829592]) * (-self.square()).exp()) - def gelu(self) -> Tensor: - """ - Applies the Gaussian Error Linear Unit (GELU) function element-wise. - - - Paper: https://arxiv.org/abs/1606.08415v5 - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).gelu().numpy()) - ``` - """ - return 0.5 * self * (1 + (math.sqrt(2 / math.pi) * (self + 0.044715 * self ** 3)).tanh()) - - def quick_gelu(self) -> Tensor: - """ - Applies the Sigmoid GELU approximation element-wise. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).quick_gelu().numpy()) - ``` - """ - return self * (self * 1.702).sigmoid() - - def leaky_relu(self, neg_slope=0.01) -> Tensor: - """ - Applies the Leaky ReLU function element-wise. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu().numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu(neg_slope=0.42).numpy()) - ``` - """ - return (self<0).where(neg_slope*self, self) - def mish(self) -> Tensor: """ Applies the Mish function element-wise.