mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat(hnp-tracing): add support for ufunc routing to NPTracer
- start tracing numpy.rint and manage dtypes - update BaseTracer to accept iterables as inputs, because NPTracer does not get list givent the way numpy sends arguments to the functions
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""This file holds the code that can be shared between tracers"""
|
||||
|
||||
from abc import ABC
|
||||
from typing import List, Tuple, Type, Union
|
||||
from typing import Iterable, List, Tuple, Type, Union
|
||||
|
||||
from ..data_types import BaseValue
|
||||
from ..data_types.scalars import Scalars
|
||||
@@ -17,17 +17,17 @@ class BaseTracer(ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inputs: List["BaseTracer"],
|
||||
inputs: Iterable["BaseTracer"],
|
||||
traced_computation: ir.IntermediateNode,
|
||||
output_index: int,
|
||||
) -> None:
|
||||
self.inputs = inputs
|
||||
self.inputs = list(inputs)
|
||||
self.traced_computation = traced_computation
|
||||
self.output = traced_computation.outputs[output_index]
|
||||
|
||||
def instantiate_output_tracers(
|
||||
self,
|
||||
inputs: List[Union["BaseTracer", Scalars]],
|
||||
inputs: Iterable[Union["BaseTracer", Scalars]],
|
||||
computation_to_trace: Type[ir.IntermediateNode],
|
||||
) -> Tuple["BaseTracer", ...]:
|
||||
"""Helper functions to instantiate all output BaseTracer for a given computation
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""File to hold code to manage package and numpy dtypes"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import List, Union
|
||||
from typing import List
|
||||
|
||||
import numpy
|
||||
from numpy.typing import DTypeLike
|
||||
@@ -94,14 +94,14 @@ def convert_common_dtype_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.dty
|
||||
|
||||
def get_ufunc_numpy_output_dtype(
|
||||
ufunc: numpy.ufunc,
|
||||
input_dtypes: Union[List[numpy.dtype], List[BaseDataType]],
|
||||
input_dtypes: List[BaseDataType],
|
||||
) -> List[numpy.dtype]:
|
||||
"""Function to record the output dtype of a numpy.ufunc given some input types
|
||||
|
||||
Args:
|
||||
ufunc (numpy.ufunc): The numpy.ufunc whose output types need to be recorded
|
||||
input_dtypes (Union[List[numpy.dtype], List[BaseDataType]]): Either numpy or common dtypes
|
||||
in the same order as they will be used with the ufunc inputs
|
||||
input_dtypes (List[BaseDataType]): Common dtypes in the same order as they will be used with
|
||||
the ufunc inputs
|
||||
|
||||
Returns:
|
||||
List[numpy.dtype]: The ordered numpy dtypes of the ufunc outputs
|
||||
@@ -110,12 +110,7 @@ def get_ufunc_numpy_output_dtype(
|
||||
len(input_dtypes) == ufunc.nin
|
||||
), f"Expected {ufunc.nin} types, got {len(input_dtypes)}: {input_dtypes}"
|
||||
|
||||
input_dtypes = [
|
||||
numpy.dtype(convert_common_dtype_to_numpy_dtype(dtype))
|
||||
if not isinstance(dtype, numpy.dtype)
|
||||
else dtype
|
||||
for dtype in input_dtypes
|
||||
]
|
||||
input_numpy_dtypes = [convert_common_dtype_to_numpy_dtype(dtype) for dtype in input_dtypes]
|
||||
|
||||
# Store numpy old error settings and ignore all errors in this function
|
||||
# We ignore errors as we may call functions with invalid inputs just to get the proper output
|
||||
@@ -123,7 +118,7 @@ def get_ufunc_numpy_output_dtype(
|
||||
old_numpy_err_settings = numpy.seterr(all="ignore")
|
||||
|
||||
dummy_inputs = tuple(
|
||||
dtype.type(1000.0 * numpy.random.random_sample()) for dtype in input_dtypes
|
||||
dtype.type(1000.0 * numpy.random.random_sample()) for dtype in input_numpy_dtypes
|
||||
)
|
||||
|
||||
outputs = ufunc(*dummy_inputs)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""hnumpy tracing utilities"""
|
||||
from typing import Callable, Dict
|
||||
from typing import Callable, Dict, Mapping
|
||||
|
||||
import numpy
|
||||
from numpy.typing import DTypeLike
|
||||
@@ -8,12 +8,28 @@ from ..common.data_types import BaseValue
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.representation import intermediate as ir
|
||||
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
|
||||
from .np_dtypes_helpers import convert_numpy_dtype_to_common_dtype
|
||||
from .np_dtypes_helpers import (
|
||||
convert_numpy_dtype_to_common_dtype,
|
||||
get_ufunc_numpy_output_dtype,
|
||||
)
|
||||
|
||||
|
||||
class NPTracer(BaseTracer):
|
||||
"""Tracer class for numpy operations"""
|
||||
|
||||
def __array_ufunc__(self, ufunc, method, *input_tracers, **kwargs):
|
||||
"""
|
||||
Catch calls to numpy ufunc and routes them to tracing functions if supported
|
||||
read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch
|
||||
"""
|
||||
if method == "__call__":
|
||||
tracing_func = self.get_tracing_func_for_np_ufunc(ufunc)
|
||||
assert (
|
||||
len(kwargs) == 0
|
||||
), f"hnumpy does not support **kwargs currently for numpy ufuncs, ufunc: {ufunc}"
|
||||
return tracing_func(self, *input_tracers, **kwargs)
|
||||
raise NotImplementedError("Only __call__ method is supported currently")
|
||||
|
||||
def astype(self, numpy_dtype: DTypeLike, *args, **kwargs) -> "NPTracer":
|
||||
"""Support numpy astype feature, for now it only accepts a dtype and no additional
|
||||
parameters, *args and **kwargs are accepted for interface compatibility only
|
||||
@@ -36,9 +52,66 @@ class NPTracer(BaseTracer):
|
||||
arbitrary_func=normalized_numpy_dtype.type,
|
||||
output_dtype=output_dtype,
|
||||
)
|
||||
output_tracer = NPTracer([self], traced_computation=traced_computation, output_index=0)
|
||||
output_tracer = self.__class__(
|
||||
[self], traced_computation=traced_computation, output_index=0
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
@staticmethod
|
||||
def get_tracing_func_for_np_ufunc(ufunc: numpy.ufunc) -> Callable:
|
||||
"""Get the tracing function for a numpy ufunc
|
||||
|
||||
Args:
|
||||
ufunc (numpy.ufunc): The numpy ufunc that will be traced
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Raised if the passed ufunc is not supported by NPTracer
|
||||
|
||||
Returns:
|
||||
Callable: the tracing function that needs to be called to trace ufunc
|
||||
"""
|
||||
tracing_func = NPTracer.UFUNC_ROUTING.get(ufunc, None)
|
||||
if tracing_func is None:
|
||||
raise NotImplementedError(
|
||||
f"NPTracer does not yet manage the following ufunc: {ufunc.__name__}"
|
||||
)
|
||||
return tracing_func
|
||||
|
||||
@staticmethod
|
||||
def _manage_dtypes(ufunc: numpy.ufunc, *input_tracers: "NPTracer"):
|
||||
output_dtypes = get_ufunc_numpy_output_dtype(
|
||||
ufunc, [input_tracer.output.data_type for input_tracer in input_tracers]
|
||||
)
|
||||
common_output_dtypes = [
|
||||
convert_numpy_dtype_to_common_dtype(dtype) for dtype in output_dtypes
|
||||
]
|
||||
return common_output_dtypes
|
||||
|
||||
def rint(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.rint
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
assert len(input_tracers) == 1
|
||||
common_output_dtypes = self._manage_dtypes(numpy.rint, *input_tracers)
|
||||
assert len(common_output_dtypes) == 1
|
||||
|
||||
traced_computation = ir.ArbitraryFunction(
|
||||
input_base_value=input_tracers[0].output,
|
||||
arbitrary_func=numpy.rint,
|
||||
output_dtype=common_output_dtypes[0],
|
||||
op_kwargs=kwargs,
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
input_tracers, traced_computation=traced_computation, output_index=0
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
UFUNC_ROUTING: Mapping[numpy.ufunc, Callable] = {
|
||||
numpy.rint: rint,
|
||||
}
|
||||
|
||||
|
||||
def trace_numpy_function(
|
||||
function_to_trace: Callable, function_parameters: Dict[str, BaseValue]
|
||||
|
||||
@@ -201,3 +201,81 @@ def test_tracing_astype(
|
||||
evaluated_output = node_results[output_node]
|
||||
assert isinstance(evaluated_output, type(expected_output))
|
||||
assert expected_output == evaluated_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,inputs,expected_output_node,expected_output_value",
|
||||
[
|
||||
# We cannot call trace_numpy_function on some numpy function as getting the signature for
|
||||
# these functions fails, so we wrap it in a lambda
|
||||
# pylint: disable=unnecessary-lambda
|
||||
pytest.param(
|
||||
lambda x: numpy.rint(x),
|
||||
{"x": EncryptedValue(Integer(7, is_signed=False))},
|
||||
ir.ArbitraryFunction,
|
||||
EncryptedValue(Float(64)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.rint(x),
|
||||
{"x": EncryptedValue(Integer(32, is_signed=True))},
|
||||
ir.ArbitraryFunction,
|
||||
EncryptedValue(Float(64)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.rint(x),
|
||||
{"x": EncryptedValue(Integer(64, is_signed=True))},
|
||||
ir.ArbitraryFunction,
|
||||
EncryptedValue(Float(64)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.rint(x),
|
||||
{"x": EncryptedValue(Integer(128, is_signed=True))},
|
||||
ir.ArbitraryFunction,
|
||||
None,
|
||||
marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.rint(x),
|
||||
{"x": EncryptedValue(Float(64))},
|
||||
ir.ArbitraryFunction,
|
||||
EncryptedValue(Float(64)),
|
||||
),
|
||||
# The next test case is only for coverage purposes, to trigger the unsupported method
|
||||
# exception handling
|
||||
pytest.param(
|
||||
lambda x: numpy.add.reduce(x),
|
||||
{"x": EncryptedValue(Integer(32, is_signed=True))},
|
||||
None,
|
||||
None,
|
||||
marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
|
||||
),
|
||||
# pylint: enable=unnecessary-lambda
|
||||
],
|
||||
)
|
||||
def test_trace_hnumpy_supported_ufuncs(
|
||||
function_to_trace, inputs, expected_output_node, expected_output_value
|
||||
):
|
||||
"""Function to trace supported numpy ufuncs"""
|
||||
op_graph = tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
assert len(op_graph.output_nodes) == 1
|
||||
assert isinstance(op_graph.output_nodes[0], expected_output_node)
|
||||
assert len(op_graph.output_nodes[0].outputs) == 1
|
||||
assert op_graph.output_nodes[0].outputs[0] == expected_output_value
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"np_ufunc,expected_tracing_func",
|
||||
[
|
||||
pytest.param(numpy.rint, tracing.NPTracer.rint),
|
||||
# There is a need to test the case where the function fails, I chose numpy.conjugate which
|
||||
# works on complex types, as we don't talk about complex types for now this looks like a
|
||||
# good long term candidate to check for an unsupported function
|
||||
pytest.param(
|
||||
numpy.conjugate, None, marks=pytest.mark.xfail(strict=True, raises=NotImplementedError)
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_nptracer_get_tracing_func_for_np_ufunc(np_ufunc, expected_tracing_func):
|
||||
"""Test NPTracer get_tracing_func_for_np_ufunc"""
|
||||
assert tracing.NPTracer.get_tracing_func_for_np_ufunc(np_ufunc) == expected_tracing_func
|
||||
|
||||
Reference in New Issue
Block a user