From fb694a63ebba46542da75d09db4b35fdfd9621df Mon Sep 17 00:00:00 2001
From: chenyu
Date: Wed, 30 Oct 2024 18:12:28 -0400
Subject: [PATCH] Tensor.erf (#7419)

the same approximation that was used in onnx and in bert.
---
 docs/tensor/elementwise.md |  1 +
 extra/models/bert.py       |  7 +------
 extra/onnx_ops.py          | 15 ++-------------
 test/test_ops.py           |  6 ++++++
 tinygrad/tensor.py         | 16 +++++++++++++++-
 5 files changed, 25 insertions(+), 20 deletions(-)

diff --git a/docs/tensor/elementwise.md b/docs/tensor/elementwise.md
index 08b95d8300..91eccb5651 100644
--- a/docs/tensor/elementwise.md
+++ b/docs/tensor/elementwise.md
@@ -42,6 +42,7 @@ Elementwise ops operate on a per element basis. They don't change the shape of t
 ::: tinygrad.Tensor.asinh
 ::: tinygrad.Tensor.acosh
 ::: tinygrad.Tensor.hardtanh
+::: tinygrad.Tensor.erf
 ::: tinygrad.Tensor.gelu
 ::: tinygrad.Tensor.quick_gelu
 ::: tinygrad.Tensor.leakyrelu
diff --git a/extra/models/bert.py b/extra/models/bert.py
index 8c91e27a8d..c1eb33f85c 100644
--- a/extra/models/bert.py
+++ b/extra/models/bert.py
@@ -231,12 +231,7 @@ class BertOutput:
     return hidden_states
 
 def gelu(x):
-  return x * 0.5 * (1.0 + erf(x / 1.41421))
-
-# approximation of the error function
-def erf(x):
-  t = (1 + 0.3275911 * x.abs()).reciprocal()
-  return x.sign() * (1 - ((((1.061405429 * t + -1.453152027) * t + 1.421413741) * t + -0.284496736) * t + 0.254829592) * t * (-(x.square())).exp())
+  return x * 0.5 * (1.0 + (x / 1.41421).erf())
 
 class BertIntermediate:
   def __init__(self, hidden_size, intermediate_size):
diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py
index f61c4d2dbd..252d10ba08 100644
--- a/extra/onnx_ops.py
+++ b/extra/onnx_ops.py
@@ -8,7 +8,7 @@ import numpy as np
 
 tensor_methods = {"Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Relu", "Sigmoid",
                   "MatMul", "Floor", "Ceil", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", "Softsign", "Asinh", "Acosh", "Atanh",
-                  "Elu", "Celu", "Xor", "Round"}
+                  "Elu", "Celu", "Xor", "Round", "Erf"}
 
 # **************** Free Ops ****************
 
@@ -43,7 +43,7 @@ def Constant(value:Optional[Tensor]=None, value_float=None, value_floats=None, v
   if value_string is not None or value_strings is not None: raise NotImplementedError('value_string or value_strings not implemented for Constant op')
 
 def HardSigmoid(x: Tensor, alpha=0.2, beta=0.5): return (alpha*x + beta).clip(0, 1)
-def Gelu(x:Tensor, approximate=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + Erf(x/math.sqrt(2)))
+def Gelu(x:Tensor, approximate=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())
 def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu())
 def PRelu(X:Tensor, slope:Tensor):
   slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
@@ -505,17 +505,6 @@ def OneHot(indices: Tensor, depth: Tensor, values: Tensor, axis=-1):
   cond = indices[:,None] == Tensor.arange(depth).reshape((1,) * len(ls) + (depth,) + (1,) * len(rs))
   return cond.where(values[1], values[0])
 
-def Erf(x: Tensor):
-  t = 1.0 / (1.0 + 0.3275911 * x.abs())
-  term1 = 0.254829592 * t
-  term2 = -0.284496736 * t ** 2
-  term3 = 1.421413741 * t ** 3
-  term4 = -1.453152027 * t ** 4
-  term5 = 1.061405429 * t ** 5
-  y = (term1 + term2 + term3 + term4 + term5)
-  z = 1.0 - y * (-x * x).exp()
-  return (x > 0).where(z, -z)
-
 def Compress(inp: Tensor, condition: Tensor, axis=None):
   if axis is None:
     inp = inp.flatten()
diff --git a/test/test_ops.py b/test/test_ops.py
index debefc7dfe..1e6303eedf 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -647,6 +647,12 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65)], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6)
     helper_test_op([()], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6)
 
+  def test_erf(self):
+    helper_test_op([(45,65)], torch.erf, Tensor.erf)
+    helper_test_op([(45,65)], torch.erf, Tensor.erf, low=300, high=400)
+    helper_test_op([(45,65)], torch.erf, Tensor.erf, low=-400, high=-300)
+    helper_test_op([()], torch.erf, Tensor.erf)
+
   def test_gelu(self):
     helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu)
     helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, low=300, high=400)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 9125de6537..fbae229603 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -7,7 +7,7 @@ from collections import defaultdict
 
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
-from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch
+from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN
 from tinygrad.multi import MultiLazyBuffer
 from tinygrad.ops import MetaOps, smax, smin, resolve, UOp, UOps, BinaryOps, sint, Variable, SimpleMathTrait
 from tinygrad.device import Device, Buffer, BufferOptions
@@ -2651,6 +2651,20 @@ class Tensor(SimpleMathTrait):  # pylint: disable=abstract-method
     """
     return self.clip(min_val, max_val)
 
+  def erf(self):
+    """
+    Applies the error function element-wise.
+
+    - Described: https://en.wikipedia.org/wiki/Error_function
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).erf().numpy())
+    ```
+    """
+    # https://personal.math.ubc.ca/~cbm/aands/page_299.htm 7.1.26
+    t = 1.0 / (1.0 + 0.3275911 * self.abs())
+    return self.sign() * (1.0 - t * polyN(t, [1.061405429, -1.453152027, 1.421413741, -0.284496736, 0.254829592]) * (-self.square()).exp())
+
   def gelu(self):
     """
     Applies the Gaussian Error Linear Unit (GELU) function element-wise.
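
For reference, below is a minimal standalone sketch (not part of the patch) of the Abramowitz and Stegun 7.1.26 formula that `Tensor.erf` implements, cross-checked against `math.erf`. The `erf_approx` name is hypothetical; the coefficients and the ~1.5e-7 maximum absolute error bound come from the A&S reference cited in the code comment above. Note that `polyN` in tinygrad evaluates the same polynomial Horner-style with coefficients given highest-degree first, which is why the coefficient list in tensor.py appears in reverse order.

```python
# Standalone check of the A&S 7.1.26 approximation used by Tensor.erf:
#   erf(x) ~= sign(x) * (1 - (a1*t + a2*t^2 + ... + a5*t^5) * exp(-x^2)),
#   where t = 1 / (1 + p*|x|), p = 0.3275911, max abs error ~1.5e-7.
import math

def erf_approx(x: float) -> float:  # hypothetical helper, pure Python
  # coefficients a1..a5: the same values polyN receives above, reversed
  a = [0.254829592, -0.284496736, 1.421413741, -1.453152027, 1.061405429]
  t = 1.0 / (1.0 + 0.3275911 * abs(x))
  poly = sum(c * t**(i + 1) for i, c in enumerate(a))
  # copysign applies the odd symmetry erf(-x) = -erf(x)
  return math.copysign(1.0 - poly * math.exp(-x * x), x)

for x in [-1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5]:
  assert abs(erf_approx(x) - math.erf(x)) < 1.5e-7
```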