From 574446e6f500574b47d458c1923217b1fd566cfe Mon Sep 17 00:00:00 2001 From: Umut Date: Tue, 17 May 2022 16:08:52 +0200 Subject: [PATCH] feat: add univariate extension to wrap arbitrary functions --- concrete/numpy/__init__.py | 2 +- concrete/numpy/extensions/__init__.py | 1 + concrete/numpy/extensions/univariate.py | 67 +++++++++++++++++++++++++ tests/execution/test_others.py | 45 +++++++++++++++++ 4 files changed, 114 insertions(+), 1 deletion(-) create mode 100644 concrete/numpy/extensions/univariate.py diff --git a/concrete/numpy/__init__.py b/concrete/numpy/__init__.py index b37898cb5..2dd37e127 100644 --- a/concrete/numpy/__init__.py +++ b/concrete/numpy/__init__.py @@ -15,6 +15,6 @@ from .compilation import ( Server, compiler, ) -from .extensions import LookupTable +from .extensions import LookupTable, univariate from .mlir.utils import MAXIMUM_BIT_WIDTH from .representation import Graph diff --git a/concrete/numpy/extensions/__init__.py b/concrete/numpy/extensions/__init__.py index c841c80fd..67981a314 100644 --- a/concrete/numpy/extensions/__init__.py +++ b/concrete/numpy/extensions/__init__.py @@ -3,3 +3,4 @@ Provide additional features that are not present in numpy. """ from .table import LookupTable +from .univariate import univariate diff --git a/concrete/numpy/extensions/univariate.py b/concrete/numpy/extensions/univariate.py new file mode 100644 index 000000000..bbe0df3e6 --- /dev/null +++ b/concrete/numpy/extensions/univariate.py @@ -0,0 +1,67 @@ +""" +Declaration of `univariate` function. +""" + +from typing import Any, Callable, Union + +import numpy as np + +from ..dtypes import Float +from ..representation import Node +from ..tracing import Tracer +from ..values import Value + + +def univariate( + function: Callable[[Any], Any], +) -> Callable[[Union[Tracer, Any]], Union[Tracer, Any]]: + """ + Wrap a univariate function so that it is traced into a single generic node. + + Args: + function (Callable[[Any], Any]): + univariate function to wrap + + Returns: + Callable[[Union[Tracer, Any]], Union[Tracer, Any]]: + another univariate function that can be called with a Tracer as well + """ + + def wrapper(x: Union[Tracer, Any]) -> Union[Tracer, Any]: + """ + Evaluate or trace wrapped univariate function. + + Args: + x (Union[Tracer, Any]): + input of the function + + Returns: + Union[Tracer, Any]: + result of tracing or evaluation + """ + + if isinstance(x, Tracer): + dtype = ( + {64: np.float64, 32: np.float32, 16: np.float16}[x.output.dtype.bit_width] + if isinstance(x.output.dtype, Float) + else np.int64 + ) + + sample = dtype(1) if x.output.shape == () else np.ones(x.output.shape, dtype=dtype) + evaluation = function(sample) + + output_value = Value.of(evaluation, is_encrypted=x.output.is_encrypted) + if output_value.shape != x.output.shape: + raise ValueError(f"Function {function.__name__} cannot be used with cnp.univariate") + + computation = Node.generic( + function.__name__, + [x.output], + output_value, + lambda x: function(x), # pylint: disable=unnecessary-lambda + ) + return Tracer(computation, [x]) + + return function(x) + + return wrapper diff --git a/tests/execution/test_others.py b/tests/execution/test_others.py index 11c6b566e..d688eae61 100644 --- a/tests/execution/test_others.py +++ b/tests/execution/test_others.py @@ -141,6 +141,20 @@ def fusable_with_one_of_the_start_nodes_is_lca_generator(): # pylint: enable=invalid-name,too-many-locals,too-many-statements +def deterministic_unary_function(x): + """ + An example deterministic unary function. + """ + + def per_element(element): + result = 0 + for i in range(element): + result += i + return result + + return np.vectorize(per_element)(x) + + @pytest.mark.parametrize( "function,parameters", [ @@ -461,6 +475,13 @@ def fusable_with_one_of_the_start_nodes_is_lca_generator(): }, id="x + np.zeros_like(x)", ), + pytest.param( + lambda x: cnp.univariate(deterministic_unary_function)(x), + { + "x": {"status": "encrypted", "range": [0, 10]}, + }, + id="cnp.univariate(deterministic_unary_function)(x)", + ), ], ) def test_others(function, parameters, helpers): @@ -615,3 +636,27 @@ return %4 # pylint: enable=line-too-long str(excinfo.value), ) + + +def test_others_bad_univariate(helpers): + """ + Test univariate with bad function. + """ + + configuration = helpers.configuration() + + def bad_univariate(x): + return np.array([x, x, x]) + + @cnp.compiler({"x": "encrypted"}) + def f(x): + return cnp.univariate(bad_univariate)(x) + + with pytest.raises(ValueError) as excinfo: + inputset = range(10) + f.compile(inputset, configuration) + + helpers.check_str( + "Function bad_univariate cannot be used with cnp.univariate", + str(excinfo.value), + )