feat: add univariate extension to wrap arbitrary functions

This commit is contained in:
Umut
2022-05-17 16:08:52 +02:00
parent 56085204e6
commit 574446e6f5
4 changed files with 114 additions and 1 deletions

View File

@@ -15,6 +15,6 @@ from .compilation import (
Server,
compiler,
)
from .extensions import LookupTable
from .extensions import LookupTable, univariate
from .mlir.utils import MAXIMUM_BIT_WIDTH
from .representation import Graph

View File

@@ -3,3 +3,4 @@ Provide additional features that are not present in numpy.
"""
from .table import LookupTable
from .univariate import univariate

View File

@@ -0,0 +1,67 @@
"""
Declaration of `univariate` function.
"""
from typing import Any, Callable, Union
import numpy as np
from ..dtypes import Float
from ..representation import Node
from ..tracing import Tracer
from ..values import Value
def univariate(
function: Callable[[Any], Any],
) -> Callable[[Union[Tracer, Any]], Union[Tracer, Any]]:
"""
Wrap a univariate function so that it is traced into a single generic node.
Args:
function (Callable[[Any], Any]):
univariate function to wrap
Returns:
Callable[[Union[Tracer, Any]], Union[Tracer, Any]]:
another univariate function that can be called with a Tracer as well
"""
def wrapper(x: Union[Tracer, Any]) -> Union[Tracer, Any]:
"""
Evaluate or trace wrapped univariate function.
Args:
x (Union[Tracer, Any]):
input of the function
Returns:
Union[Tracer, Any]:
result of tracing or evaluation
"""
if isinstance(x, Tracer):
dtype = (
{64: np.float64, 32: np.float32, 16: np.float16}[x.output.dtype.bit_width]
if isinstance(x.output.dtype, Float)
else np.int64
)
sample = dtype(1) if x.output.shape == () else np.ones(x.output.shape, dtype=dtype)
evaluation = function(sample)
output_value = Value.of(evaluation, is_encrypted=x.output.is_encrypted)
if output_value.shape != x.output.shape:
raise ValueError(f"Function {function.__name__} cannot be used with cnp.univariate")
computation = Node.generic(
function.__name__,
[x.output],
output_value,
lambda x: function(x), # pylint: disable=unnecessary-lambda
)
return Tracer(computation, [x])
return function(x)
return wrapper

View File

@@ -141,6 +141,20 @@ def fusable_with_one_of_the_start_nodes_is_lca_generator():
# pylint: enable=invalid-name,too-many-locals,too-many-statements
def deterministic_unary_function(x):
"""
An example deterministic unary function.
"""
def per_element(element):
result = 0
for i in range(element):
result += i
return result
return np.vectorize(per_element)(x)
@pytest.mark.parametrize(
"function,parameters",
[
@@ -461,6 +475,13 @@ def fusable_with_one_of_the_start_nodes_is_lca_generator():
},
id="x + np.zeros_like(x)",
),
pytest.param(
lambda x: cnp.univariate(deterministic_unary_function)(x),
{
"x": {"status": "encrypted", "range": [0, 10]},
},
id="cnp.univariate(deterministic_unary_function)(x)",
),
],
)
def test_others(function, parameters, helpers):
@@ -615,3 +636,27 @@ return %4
# pylint: enable=line-too-long
str(excinfo.value),
)
def test_others_bad_univariate(helpers):
"""
Test univariate with bad function.
"""
configuration = helpers.configuration()
def bad_univariate(x):
return np.array([x, x, x])
@cnp.compiler({"x": "encrypted"})
def f(x):
return cnp.univariate(bad_univariate)(x)
with pytest.raises(ValueError) as excinfo:
inputset = range(10)
f.compile(inputset, configuration)
helpers.check_str(
"Function bad_univariate cannot be used with cnp.univariate",
str(excinfo.value),
)