From c19ef0fcce5645df71e6ae812099ae6792362bb7 Mon Sep 17 00:00:00 2001
From: Diogo
Date: Thu, 25 May 2023 12:04:56 -0400
Subject: [PATCH] Add sin/cos/tan (#794)

* added sin/cos/tan

* fix lint

* added onnx ops support
---
 extra/onnx_ops.py                            |  9 +++++++++
 test/external/external_test_onnx_backend.py | 10 +++++-----
 test/test_ops.py                             |  8 ++++++++
 tinygrad/codegen/cstyle.py                   |  1 +
 tinygrad/codegen/llvmir.py                   |  1 +
 tinygrad/mlops.py                            |  7 +++++++
 tinygrad/ops.py                              |  2 +-
 tinygrad/runtime/ops_cpu.py                  |  2 +-
 tinygrad/runtime/ops_torch.py                |  2 +-
 tinygrad/tensor.py                           |  4 +++-
 10 files changed, 37 insertions(+), 9 deletions(-)

diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py
index 115c30b312..0ac558e386 100644
--- a/extra/onnx_ops.py
+++ b/extra/onnx_ops.py
@@ -143,6 +143,15 @@ Softmax = {1: Softmax_1, 13: Softmax_13} # Softmax default axis changed
 def LogSoftmax(input, axis=-1): return input.log_softmax(axis)
 def Clip(input, min=-3.4e38, max=3.4e38): return input.clip(min, max)
 
+import math
+
+def Sin(x): return x.sin()
+def Cos(x): return x.cos()
+def Tan(x): return x.tan()
+def Cosh(x): return (math.e ** x + math.e ** -x) / 2
+def Sinh(x): return (math.e ** x - math.e ** -x) / 2
+def Tanh(x): return Sinh(x) / Cosh(x)
+
 def Less(x, y): return (x<y).numpy().astype(bool)

diff --git a/test/external/external_test_onnx_backend.py b/test/external/external_test_onnx_backend.py
index 4cd2144036..48f42a04c3 100644
--- a/test/external/external_test_onnx_backend.py
+++ b/test/external/external_test_onnx_backend.py
@@ -87,11 +87,11 @@ backend_test.exclude('test_asin_*')
 backend_test.exclude('test_asinh_*')
 backend_test.exclude('test_atan_*')
 backend_test.exclude('test_atanh_*')
-backend_test.exclude('test_cos_*')
-backend_test.exclude('test_cosh_*')
-backend_test.exclude('test_sin_*')
-backend_test.exclude('test_sinh_*')
-backend_test.exclude('test_tan_*')
+# backend_test.include('test_cos_*')
+# backend_test.include('test_cosh_*')
+# backend_test.exclude('test_sin_*')
+# backend_test.include('test_sinh_*')
+# backend_test.include('test_tanh_*')
 
 # no boolean ops (2d, 3d, 4d)
 backend_test.exclude('test_and*')

diff --git a/test/test_ops.py b/test/test_ops.py
index 90a46dbb47..e5e0711f17 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -131,6 +131,14 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65)], lambda x: 2.0**x, lambda x: 2.0**x)
   def test_sqrt(self):
     helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, a=0)
+
+  def test_sin(self):
+    helper_test_op([(45,65)], lambda x: x.sin(), Tensor.sin, a=0)
+  def test_cos(self):
+    helper_test_op([(45,65)], lambda x: x.cos(), Tensor.cos, a=0)
+  def test_tan(self):
+    helper_test_op([(45,65)], lambda x: x.tan(), Tensor.tan, a=0)
+
   def test_relu(self):
     helper_test_op([(64,64)], lambda x: x.relu(), Tensor.relu)
   def test_relu_exact(self):

diff --git a/tinygrad/codegen/cstyle.py b/tinygrad/codegen/cstyle.py
index d59050a632..a26fee0b84 100644
--- a/tinygrad/codegen/cstyle.py
+++ b/tinygrad/codegen/cstyle.py
@@ -50,6 +50,7 @@ def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=F
 code_for_op: Final[Dict[Op, Callable]] = {
   UnaryOps.EXP: lambda x: f"native_exp({x})" if NATIVE_EXPLOG else f"exp({x})",
   UnaryOps.LOG: lambda x: f"native_log({x})" if NATIVE_EXPLOG else f"log({x})",
+  UnaryOps.SIN: lambda x: f"sin({x})",
   BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})", BinaryOps.MUL: lambda a,b: f"({a}*{b})",
   BinaryOps.DIV: lambda a,b: f"({a}/{b})", BinaryOps.POW: lambda a,b: f"pow({a},{b})", BinaryOps.MAX: lambda a,b: f"max({a},{b})",

diff --git a/tinygrad/codegen/llvmir.py b/tinygrad/codegen/llvmir.py
index 1493e30334..da7ed665f9 100644
--- a/tinygrad/codegen/llvmir.py
+++ b/tinygrad/codegen/llvmir.py
@@ -22,6 +22,7 @@ render_llvm = {
 code_for_op: Final[Dict[Op, Callable]] = {
   UnaryOps.EXP: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp', [ir.FloatType()]), [x], fastmath=('fast',)),
   UnaryOps.LOG: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log', [ir.FloatType()]), [x], fastmath=('fast',)),
+  UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [ir.FloatType()]), [x], fastmath=('fast',)),
   BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)),
   BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)),
   BinaryOps.MUL: lambda builder,x,y: builder.fmul(x,y, flags=('fast',)),

diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index 7460b35b42..57f10b4890 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -3,6 +3,7 @@ from tinygrad.helpers import argsort, ShapeType
 from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps
 from tinygrad.tensor import Function
 from tinygrad.lazy import LazyBuffer
+import math
 
 class Contiguous(Function):
   def forward(self, x): return x.contiguous()
@@ -17,6 +18,12 @@ class Cast(Function):
 
 # ************* unary ops *************
 
+class Sin(Function):
+  def forward(self, x: LazyBuffer) -> LazyBuffer:
+    self.x = x
+    return x.unary_op(UnaryOps.SIN)
+  def backward(self, grad: LazyBuffer) -> LazyBuffer:
+    return self.x.const_like(math.pi / 2).binary_op(BinaryOps.SUB, self.x).unary_op(UnaryOps.SIN).binary_op(BinaryOps.MUL, grad)
 # NOTE: maximum(x, 0) behaves differently where x=0
 class Relu(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:

diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 4199c99f11..477f7f47ba 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -8,7 +8,7 @@ from tinygrad.runtime.lib import RawBuffer, RawConst
 
 # these are the llops your accelerator must implement, along with toCpu
 # the Enum class doesn't work with mypy, this is static. sorry it's ugly
-class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); CAST = auto() # noqa: E702
+class UnaryOps(Enum): NOOP = auto(); EXP = auto(); LOG = auto(); CAST = auto(); SIN = auto() # noqa: E702
 class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class FusedOps(Enum): MULACC = auto() # noqa: E702

diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py
index ba837b1bfb..c76883e0cc 100644
--- a/tinygrad/runtime/ops_cpu.py
+++ b/tinygrad/runtime/ops_cpu.py
@@ -27,7 +27,7 @@ def einsum_mulacc(einsum, get_strides, expand):
   return mulacc
 
 numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
-  UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP: np.exp, UnaryOps.LOG: np.log, UnaryOps.CAST: lambda x,y: x.astype(y.np),
+  UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP: np.exp, UnaryOps.LOG: np.log, UnaryOps.CAST: lambda x,y: x.astype(y.np), UnaryOps.SIN: np.sin,
   BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.float32),
   MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
   MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)],

diff --git a/tinygrad/runtime/ops_torch.py b/tinygrad/runtime/ops_torch.py
index addb063468..74b183c8d5 100644
--- a/tinygrad/runtime/ops_torch.py
+++ b/tinygrad/runtime/ops_torch.py
@@ -9,7 +9,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if geten
 type_map = {torch.float16: dtypes.float16, torch.float32: dtypes.float32}
 
 torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
-  UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), UnaryOps.CAST: lambda x,y: x.type(next(k for k,v in type_map.items() if v==y)),
+  UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), UnaryOps.CAST: lambda x,y: x.type(next(k for k,v in type_map.items() if v==y)), UnaryOps.SIN: torch.sin,
   BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
   MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
   FusedOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(torch.promote_types(a.dtype, b.dtype)), lambda x: x.stride(), lambda x,s: x.expand(s)),

diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index d49c1ab06d..b824b0ed02 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -396,7 +396,9 @@ class Tensor:
   def log(self): return mlops.Log.apply(self)
   def exp(self): return mlops.Exp.apply(self)
   def relu(self): return mlops.Relu.apply(self)
-
+  def sin(self): return mlops.Sin.apply(self)
+  def cos(self): return ((math.pi/2)-self).sin()
+  def tan(self): return self.sin() / self.cos()
   # ***** math functions (unary) *****
 
   def __neg__(self): return 0.0-self