mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
committed by
Benoit Chevallier
parent
b582e68cd0
commit
e9c5ce27bb
@@ -1,7 +1,7 @@
|
||||
"""hnumpy tracing utilities."""
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy
|
||||
from numpy.typing import DTypeLike
|
||||
@@ -43,7 +43,7 @@ class NPTracer(BaseTracer):
|
||||
assert (
|
||||
len(kwargs) == 0
|
||||
), f"hnumpy does not support **kwargs currently for numpy ufuncs, ufunc: {ufunc}"
|
||||
return tracing_func(self, *input_tracers, **kwargs)
|
||||
return tracing_func(*input_tracers, **kwargs)
|
||||
raise NotImplementedError("Only __call__ method is supported currently")
|
||||
|
||||
def __array_function__(self, func, _types, args, kwargs):
|
||||
@@ -130,8 +130,9 @@ class NPTracer(BaseTracer):
|
||||
]
|
||||
return common_output_dtypes
|
||||
|
||||
@classmethod
|
||||
def _unary_operator(
|
||||
self, unary_operator, unary_operator_string, *input_tracers: "NPTracer", **kwargs
|
||||
cls, unary_operator, unary_operator_string, *input_tracers: "NPTracer", **kwargs
|
||||
) -> "NPTracer":
|
||||
"""Function to trace an unary operator.
|
||||
|
||||
@@ -139,7 +140,7 @@ class NPTracer(BaseTracer):
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
assert len(input_tracers) == 1
|
||||
common_output_dtypes = self._manage_dtypes(unary_operator, *input_tracers)
|
||||
common_output_dtypes = cls._manage_dtypes(unary_operator, *input_tracers)
|
||||
assert len(common_output_dtypes) == 1
|
||||
|
||||
traced_computation = ArbitraryFunction(
|
||||
@@ -149,93 +150,13 @@ class NPTracer(BaseTracer):
|
||||
op_kwargs=deepcopy(kwargs),
|
||||
op_name=unary_operator_string,
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
output_tracer = cls(
|
||||
input_tracers,
|
||||
traced_computation=traced_computation,
|
||||
output_index=0,
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def rint(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.rint.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
return self._unary_operator(numpy.rint, "np.rint", *input_tracers, **kwargs)
|
||||
|
||||
def sin(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.sin.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
return self._unary_operator(numpy.sin, "np.sin", *input_tracers, **kwargs)
|
||||
|
||||
def cos(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.cos.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
return self._unary_operator(numpy.cos, "np.cos", *input_tracers, **kwargs)
|
||||
|
||||
def tan(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.tan.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
return self._unary_operator(numpy.tan, "np.tan", *input_tracers, **kwargs)
|
||||
|
||||
def arcsin(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.arcsin.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
return self._unary_operator(numpy.arcsin, "np.arcsin", *input_tracers, **kwargs)
|
||||
|
||||
def arccos(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.arccos.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
return self._unary_operator(numpy.arccos, "np.arccos", *input_tracers, **kwargs)
|
||||
|
||||
def arctan(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.arctan.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
return self._unary_operator(numpy.arctan, "np.arctan", *input_tracers, **kwargs)
|
||||
|
||||
def exp(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.exp.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
return self._unary_operator(numpy.exp, "np.exp", *input_tracers, **kwargs)
|
||||
|
||||
def expm1(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.expm1.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
return self._unary_operator(numpy.expm1, "np.expm1", *input_tracers, **kwargs)
|
||||
|
||||
def exp2(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.exp2.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
return self._unary_operator(numpy.exp2, "np.exp2", *input_tracers, **kwargs)
|
||||
|
||||
def dot(self, other_tracer: "NPTracer", **_kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.dot.
|
||||
|
||||
@@ -261,24 +182,124 @@ class NPTracer(BaseTracer):
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
UFUNC_ROUTING: Dict[numpy.ufunc, Callable] = {
|
||||
numpy.rint: rint,
|
||||
numpy.sin: sin,
|
||||
numpy.cos: cos,
|
||||
numpy.tan: tan,
|
||||
numpy.arcsin: arcsin,
|
||||
numpy.arccos: arccos,
|
||||
numpy.arctan: arctan,
|
||||
numpy.exp: exp,
|
||||
numpy.expm1: expm1,
|
||||
numpy.exp2: exp2,
|
||||
}
|
||||
LIST_OF_SUPPORTED_UFUNC: List[numpy.ufunc] = [
|
||||
# The commented functions are functions which don't work for the moment, often
|
||||
# if not always because they require more than a single argument
|
||||
# numpy.absolute,
|
||||
# numpy.add,
|
||||
numpy.arccos,
|
||||
numpy.arccosh,
|
||||
numpy.arcsin,
|
||||
numpy.arcsinh,
|
||||
numpy.arctan,
|
||||
# numpy.arctan2,
|
||||
numpy.arctanh,
|
||||
# numpy.bitwise_and,
|
||||
# numpy.bitwise_or,
|
||||
# numpy.bitwise_xor,
|
||||
numpy.cbrt,
|
||||
numpy.ceil,
|
||||
# numpy.conjugate,
|
||||
# numpy.copysign,
|
||||
numpy.cos,
|
||||
numpy.cosh,
|
||||
numpy.deg2rad,
|
||||
numpy.degrees,
|
||||
# numpy.divmod,
|
||||
# numpy.equal,
|
||||
numpy.exp,
|
||||
numpy.exp2,
|
||||
numpy.expm1,
|
||||
numpy.fabs,
|
||||
# numpy.float_power,
|
||||
numpy.floor,
|
||||
# numpy.floor_divide,
|
||||
# numpy.fmax,
|
||||
# numpy.fmin,
|
||||
# numpy.fmod,
|
||||
# numpy.frexp,
|
||||
# numpy.gcd,
|
||||
# numpy.greater,
|
||||
# numpy.greater_equal,
|
||||
# numpy.heaviside,
|
||||
# numpy.hypot,
|
||||
# numpy.invert,
|
||||
# numpy.isfinite,
|
||||
# numpy.isinf,
|
||||
# numpy.isnan,
|
||||
# numpy.isnat,
|
||||
# numpy.lcm,
|
||||
# numpy.ldexp,
|
||||
# numpy.left_shift,
|
||||
# numpy.less,
|
||||
# numpy.less_equal,
|
||||
numpy.log,
|
||||
numpy.log10,
|
||||
numpy.log1p,
|
||||
numpy.log2,
|
||||
# numpy.logaddexp,
|
||||
# numpy.logaddexp2,
|
||||
# numpy.logical_and,
|
||||
# numpy.logical_not,
|
||||
# numpy.logical_or,
|
||||
# numpy.logical_xor,
|
||||
# numpy.matmul,
|
||||
# numpy.maximum,
|
||||
# numpy.minimum,
|
||||
# numpy.modf,
|
||||
# numpy.multiply,
|
||||
# numpy.negative,
|
||||
# numpy.nextafter,
|
||||
# numpy.not_equal,
|
||||
# numpy.positive,
|
||||
# numpy.power,
|
||||
numpy.rad2deg,
|
||||
numpy.radians,
|
||||
# numpy.reciprocal,
|
||||
# numpy.remainder,
|
||||
# numpy.right_shift,
|
||||
numpy.rint,
|
||||
# numpy.sign,
|
||||
# numpy.signbit,
|
||||
numpy.sin,
|
||||
numpy.sinh,
|
||||
numpy.spacing,
|
||||
numpy.sqrt,
|
||||
# numpy.square,
|
||||
# numpy.subtract,
|
||||
numpy.tan,
|
||||
numpy.tanh,
|
||||
# numpy.true_divide,
|
||||
numpy.trunc,
|
||||
]
|
||||
|
||||
# We build UFUNC_ROUTING dynamically after the creation of the class,
|
||||
# because of some limits of python or our unability to do it properly
|
||||
# in the class with techniques which are compatible with the different
|
||||
# coding checks we use
|
||||
UFUNC_ROUTING: Dict[numpy.ufunc, Callable] = {}
|
||||
|
||||
FUNC_ROUTING: Dict[Callable, Callable] = {
|
||||
numpy.dot: dot,
|
||||
}
|
||||
|
||||
|
||||
def _get_fun(function: numpy.ufunc):
|
||||
"""Helper function to wrap _unary_operator in a lambda to populate NPTRACER.UFUNC_ROUTING."""
|
||||
|
||||
# We have to access this method to be able to build NPTracer.UFUNC_ROUTING
|
||||
# dynamically
|
||||
# pylint: disable=protected-access
|
||||
return lambda *input_tracers, **kwargs: NPTracer._unary_operator(
|
||||
function, f"np.{function.__name__}", *input_tracers, **kwargs
|
||||
)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
# We are populating NPTracer.UFUNC_ROUTING dynamically
|
||||
NPTracer.UFUNC_ROUTING = {fun: _get_fun(fun) for fun in NPTracer.LIST_OF_SUPPORTED_UFUNC}
|
||||
|
||||
|
||||
def trace_numpy_function(
|
||||
function_to_trace: Callable, function_parameters: Dict[str, BaseValue]
|
||||
) -> OPGraph:
|
||||
|
||||
@@ -230,31 +230,6 @@ def test_tracing_astype(
|
||||
assert expected_output == evaluated_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace",
|
||||
[
|
||||
# We cannot call trace_numpy_function on some numpy function as getting the signature for
|
||||
# these functions fails, so we wrap it in a lambda
|
||||
# pylint: disable=unnecessary-lambda
|
||||
pytest.param(lambda x: numpy.rint(x)),
|
||||
pytest.param(lambda x: numpy.sin(x)),
|
||||
pytest.param(lambda x: numpy.cos(x)),
|
||||
pytest.param(lambda x: numpy.tan(x)),
|
||||
pytest.param(lambda x: numpy.arcsin(x)),
|
||||
pytest.param(lambda x: numpy.arccos(x)),
|
||||
pytest.param(lambda x: numpy.arctan(x)),
|
||||
pytest.param(lambda x: numpy.exp(x)),
|
||||
pytest.param(lambda x: numpy.expm1(x)),
|
||||
pytest.param(lambda x: numpy.exp2(x)),
|
||||
# The next test case is only for coverage purposes, to trigger the unsupported method
|
||||
# exception handling
|
||||
pytest.param(
|
||||
lambda x: numpy.add.reduce(x),
|
||||
marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
|
||||
),
|
||||
# pylint: enable=unnecessary-lambda
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"inputs,expected_output_node,expected_output_value",
|
||||
[
|
||||
@@ -286,16 +261,40 @@ def test_tracing_astype(
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_trace_hnumpy_supported_ufuncs(
|
||||
function_to_trace, inputs, expected_output_node, expected_output_value
|
||||
):
|
||||
def test_trace_hnumpy_supported_ufuncs(inputs, expected_output_node, expected_output_value):
|
||||
"""Function to trace supported numpy ufuncs"""
|
||||
op_graph = tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
for function_to_trace_def in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC:
|
||||
|
||||
assert len(op_graph.output_nodes) == 1
|
||||
assert isinstance(op_graph.output_nodes[0], expected_output_node)
|
||||
assert len(op_graph.output_nodes[0].outputs) == 1
|
||||
assert op_graph.output_nodes[0].outputs[0] == expected_output_value
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint and flake8 are not happy
|
||||
# with it
|
||||
# pylint: disable=unnecessary-lambda,cell-var-from-loop
|
||||
function_to_trace = lambda x: function_to_trace_def(x) # noqa: E731
|
||||
# pylint: enable=unnecessary-lambda,cell-var-from-loop
|
||||
|
||||
op_graph = tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
assert len(op_graph.output_nodes) == 1
|
||||
assert isinstance(op_graph.output_nodes[0], expected_output_node)
|
||||
assert len(op_graph.output_nodes[0].outputs) == 1
|
||||
assert op_graph.output_nodes[0].outputs[0] == expected_output_value
|
||||
|
||||
|
||||
def test_trace_hnumpy_ufuncs_not_supported():
|
||||
"""Testing a failure case of trace_numpy_function"""
|
||||
inputs = {"x": EncryptedScalar(Integer(128, is_signed=True))}
|
||||
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint and flake8 are not happy
|
||||
# with it
|
||||
# pylint: disable=unnecessary-lambda
|
||||
function_to_trace = lambda x: numpy.add.reduce(x) # noqa: E731
|
||||
# pylint: enable=unnecessary-lambda
|
||||
|
||||
with pytest.raises(NotImplementedError) as excinfo:
|
||||
tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
assert "Only __call__ method is supported currently" in str(excinfo.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -351,31 +350,23 @@ def test_trace_hnumpy_dot(function_to_trace, inputs, expected_output_node, expec
|
||||
assert op_graph.output_nodes[0].outputs[0] == expected_output_value
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"np_function,expected_tracing_func",
|
||||
[
|
||||
pytest.param(numpy.rint, tracing.NPTracer.rint),
|
||||
pytest.param(numpy.sin, tracing.NPTracer.sin),
|
||||
pytest.param(numpy.cos, tracing.NPTracer.cos),
|
||||
pytest.param(numpy.tan, tracing.NPTracer.tan),
|
||||
pytest.param(numpy.arcsin, tracing.NPTracer.arcsin),
|
||||
pytest.param(numpy.arccos, tracing.NPTracer.arccos),
|
||||
pytest.param(numpy.arctan, tracing.NPTracer.arctan),
|
||||
pytest.param(numpy.exp, tracing.NPTracer.exp),
|
||||
pytest.param(numpy.expm1, tracing.NPTracer.expm1),
|
||||
pytest.param(numpy.exp2, tracing.NPTracer.exp2),
|
||||
pytest.param(numpy.dot, tracing.NPTracer.dot),
|
||||
# There is a need to test the case where the function fails, I chose numpy.conjugate which
|
||||
# works on complex types, as we don't talk about complex types for now this looks like a
|
||||
# good long term candidate to check for an unsupported function
|
||||
pytest.param(
|
||||
numpy.conjugate, None, marks=pytest.mark.xfail(strict=True, raises=NotImplementedError)
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_nptracer_get_tracing_func_for_np_functions(np_function, expected_tracing_func):
|
||||
def test_nptracer_get_tracing_func_for_np_functions():
|
||||
"""Test NPTracer get_tracing_func_for_np_function"""
|
||||
assert tracing.NPTracer.get_tracing_func_for_np_function(np_function) == expected_tracing_func
|
||||
|
||||
for np_function in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC:
|
||||
expected_tracing_func = tracing.NPTracer.UFUNC_ROUTING[np_function]
|
||||
|
||||
assert (
|
||||
tracing.NPTracer.get_tracing_func_for_np_function(np_function) == expected_tracing_func
|
||||
)
|
||||
|
||||
|
||||
def test_nptracer_get_tracing_func_for_np_functions_not_implemented():
|
||||
"""Check NPTracer in case of not-implemented function"""
|
||||
with pytest.raises(NotImplementedError) as excinfo:
|
||||
tracing.NPTracer.get_tracing_func_for_np_function(numpy.conjugate)
|
||||
|
||||
assert "NPTracer does not yet manage the following func: conjugate" in str(excinfo.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user