mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat: add univariate extension to wrap arbitrary functions
This commit is contained in:
@@ -15,6 +15,6 @@ from .compilation import (
|
||||
Server,
|
||||
compiler,
|
||||
)
|
||||
from .extensions import LookupTable
|
||||
from .extensions import LookupTable, univariate
|
||||
from .mlir.utils import MAXIMUM_BIT_WIDTH
|
||||
from .representation import Graph
|
||||
|
||||
@@ -3,3 +3,4 @@ Provide additional features that are not present in numpy.
|
||||
"""
|
||||
|
||||
from .table import LookupTable
|
||||
from .univariate import univariate
|
||||
|
||||
67
concrete/numpy/extensions/univariate.py
Normal file
67
concrete/numpy/extensions/univariate.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
Declaration of `univariate` function.
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..dtypes import Float
|
||||
from ..representation import Node
|
||||
from ..tracing import Tracer
|
||||
from ..values import Value
|
||||
|
||||
|
||||
def univariate(
|
||||
function: Callable[[Any], Any],
|
||||
) -> Callable[[Union[Tracer, Any]], Union[Tracer, Any]]:
|
||||
"""
|
||||
Wrap a univariate function so that it is traced into a single generic node.
|
||||
|
||||
Args:
|
||||
function (Callable[[Any], Any]):
|
||||
univariate function to wrap
|
||||
|
||||
Returns:
|
||||
Callable[[Union[Tracer, Any]], Union[Tracer, Any]]:
|
||||
another univariate function that can be called with a Tracer as well
|
||||
"""
|
||||
|
||||
def wrapper(x: Union[Tracer, Any]) -> Union[Tracer, Any]:
|
||||
"""
|
||||
Evaluate or trace wrapped univariate function.
|
||||
|
||||
Args:
|
||||
x (Union[Tracer, Any]):
|
||||
input of the function
|
||||
|
||||
Returns:
|
||||
Union[Tracer, Any]:
|
||||
result of tracing or evaluation
|
||||
"""
|
||||
|
||||
if isinstance(x, Tracer):
|
||||
dtype = (
|
||||
{64: np.float64, 32: np.float32, 16: np.float16}[x.output.dtype.bit_width]
|
||||
if isinstance(x.output.dtype, Float)
|
||||
else np.int64
|
||||
)
|
||||
|
||||
sample = dtype(1) if x.output.shape == () else np.ones(x.output.shape, dtype=dtype)
|
||||
evaluation = function(sample)
|
||||
|
||||
output_value = Value.of(evaluation, is_encrypted=x.output.is_encrypted)
|
||||
if output_value.shape != x.output.shape:
|
||||
raise ValueError(f"Function {function.__name__} cannot be used with cnp.univariate")
|
||||
|
||||
computation = Node.generic(
|
||||
function.__name__,
|
||||
[x.output],
|
||||
output_value,
|
||||
lambda x: function(x), # pylint: disable=unnecessary-lambda
|
||||
)
|
||||
return Tracer(computation, [x])
|
||||
|
||||
return function(x)
|
||||
|
||||
return wrapper
|
||||
@@ -141,6 +141,20 @@ def fusable_with_one_of_the_start_nodes_is_lca_generator():
|
||||
# pylint: enable=invalid-name,too-many-locals,too-many-statements
|
||||
|
||||
|
||||
def deterministic_unary_function(x):
|
||||
"""
|
||||
An example deterministic unary function.
|
||||
"""
|
||||
|
||||
def per_element(element):
|
||||
result = 0
|
||||
for i in range(element):
|
||||
result += i
|
||||
return result
|
||||
|
||||
return np.vectorize(per_element)(x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters",
|
||||
[
|
||||
@@ -461,6 +475,13 @@ def fusable_with_one_of_the_start_nodes_is_lca_generator():
|
||||
},
|
||||
id="x + np.zeros_like(x)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: cnp.univariate(deterministic_unary_function)(x),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [0, 10]},
|
||||
},
|
||||
id="cnp.univariate(deterministic_unary_function)(x)",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_others(function, parameters, helpers):
|
||||
@@ -615,3 +636,27 @@ return %4
|
||||
# pylint: enable=line-too-long
|
||||
str(excinfo.value),
|
||||
)
|
||||
|
||||
|
||||
def test_others_bad_univariate(helpers):
|
||||
"""
|
||||
Test univariate with bad function.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
def bad_univariate(x):
|
||||
return np.array([x, x, x])
|
||||
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def f(x):
|
||||
return cnp.univariate(bad_univariate)(x)
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
inputset = range(10)
|
||||
f.compile(inputset, configuration)
|
||||
|
||||
helpers.check_str(
|
||||
"Function bad_univariate cannot be used with cnp.univariate",
|
||||
str(excinfo.value),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user