mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
refactor: refactor ConstantInput to be flexible
- refactor to take a function to generate the propore BaseValue to store in its output - refactor BaseTracer to force inheriting tracers to indicate how to build a ConstantInput tracer - remove "as import" for intermediate in hnumpy/tracing.py - update compile to manage python dtypes
This commit is contained in:
@@ -152,10 +152,13 @@ class ConstantInput(IntermediateNode):
|
||||
def __init__(
|
||||
self,
|
||||
constant_data: Any,
|
||||
get_base_value_for_data_func: Callable[
|
||||
[Any], Callable[..., BaseValue]
|
||||
] = get_base_value_for_python_constant_data,
|
||||
) -> None:
|
||||
super().__init__([])
|
||||
|
||||
base_value_class = get_base_value_for_python_constant_data(constant_data)
|
||||
base_value_class = get_base_value_for_data_func(constant_data)
|
||||
|
||||
self._constant_data = constant_data
|
||||
self.outputs = [base_value_class(is_encrypted=False)]
|
||||
|
||||
@@ -36,6 +36,17 @@ class BaseTracer(ABC):
|
||||
"""
|
||||
return isinstance(other, self.__class__)
|
||||
|
||||
@abstractmethod
|
||||
def _make_const_input_tracer(self, constant_data: Any) -> "BaseTracer":
|
||||
"""Helper function to create a tracer for a constant input.
|
||||
|
||||
Args:
|
||||
constant_data (Any): The constant to store.
|
||||
|
||||
Returns:
|
||||
BaseTracer: The BaseTracer for that constant.
|
||||
"""
|
||||
|
||||
def instantiate_output_tracers(
|
||||
self,
|
||||
inputs: Iterable[Union["BaseTracer", Any]],
|
||||
@@ -55,7 +66,7 @@ class BaseTracer(ABC):
|
||||
# For inputs which are actually constant, first convert into a tracer
|
||||
def sanitize(inp):
|
||||
if not isinstance(inp, BaseTracer):
|
||||
return make_const_input_tracer(self.__class__, inp)
|
||||
return self._make_const_input_tracer(inp)
|
||||
return inp
|
||||
|
||||
sanitized_inputs = [sanitize(inp) for inp in inputs]
|
||||
@@ -128,16 +139,3 @@ class BaseTracer(ABC):
|
||||
# the order, we need to do as in __rmul__, ie mostly a copy of __mul__ +
|
||||
# some changes
|
||||
__rmul__ = __mul__
|
||||
|
||||
|
||||
def make_const_input_tracer(tracer_class: Type[BaseTracer], constant_data: Any) -> BaseTracer:
|
||||
"""Helper function to create a tracer for a constant input.
|
||||
|
||||
Args:
|
||||
tracer_class (Type[BaseTracer]): the class of tracer to create a ConstantInput for
|
||||
constant_data (Any): the constant
|
||||
|
||||
Returns:
|
||||
BaseTracer: The BaseTracer for that constant
|
||||
"""
|
||||
return tracer_class([], ir.ConstantInput(constant_data), 0)
|
||||
|
||||
@@ -17,6 +17,7 @@ from ..common.operator_graph import OPGraph
|
||||
from ..common.optimization.topological import fuse_float_operations
|
||||
from ..common.representation import intermediate as ir
|
||||
from ..hnumpy.tracing import trace_numpy_function
|
||||
from .np_dtypes_helpers import get_base_data_type_for_numpy_or_python_constant_data
|
||||
|
||||
|
||||
def compile_numpy_function_into_op_graph(
|
||||
@@ -74,7 +75,9 @@ def compile_numpy_function_into_op_graph(
|
||||
node_bounds = eval_op_graph_bounds_on_dataset(op_graph, dataset)
|
||||
|
||||
# Update the graph accordingly: after that, we have the compilable graph
|
||||
op_graph.update_values_with_bounds(node_bounds)
|
||||
op_graph.update_values_with_bounds(
|
||||
node_bounds, get_base_data_type_for_numpy_or_python_constant_data
|
||||
)
|
||||
|
||||
# Make sure the graph can be lowered to MLIR
|
||||
if not is_graph_values_compatible_with_mlir(op_graph):
|
||||
|
||||
@@ -1,15 +1,21 @@
|
||||
"""File to hold code to manage package and numpy dtypes."""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
import numpy
|
||||
from numpy.typing import DTypeLike
|
||||
|
||||
from ..common.data_types.base import BaseDataType
|
||||
from ..common.data_types.dtypes_helpers import BASE_DATA_TYPES
|
||||
from ..common.data_types.dtypes_helpers import (
|
||||
BASE_DATA_TYPES,
|
||||
get_base_data_type_for_python_constant_data,
|
||||
get_base_value_for_python_constant_data,
|
||||
)
|
||||
from ..common.data_types.floats import Float
|
||||
from ..common.data_types.integers import Integer
|
||||
from ..common.data_types.values import BaseValue, ScalarValue
|
||||
|
||||
NUMPY_TO_HDK_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = {
|
||||
numpy.dtype(numpy.int32): Integer(32, is_signed=True),
|
||||
@@ -92,6 +98,58 @@ def convert_base_data_type_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.d
|
||||
return type_to_return
|
||||
|
||||
|
||||
def get_base_data_type_for_numpy_or_python_constant_data(constant_data: Any) -> BaseDataType:
|
||||
"""Helper function to determine the BaseDataType to hold the input constant data.
|
||||
|
||||
Args:
|
||||
constant_data (Any): The constant data for which to determine the
|
||||
corresponding BaseDataType.
|
||||
|
||||
Returns:
|
||||
BaseDataType: The corresponding BaseDataType
|
||||
"""
|
||||
base_dtype: BaseDataType
|
||||
assert isinstance(
|
||||
constant_data, (int, float, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
|
||||
), f"Unsupported constant data of type {type(constant_data)}"
|
||||
if isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES):
|
||||
base_dtype = convert_numpy_dtype_to_base_data_type(constant_data)
|
||||
else:
|
||||
base_dtype = get_base_data_type_for_python_constant_data(constant_data)
|
||||
return base_dtype
|
||||
|
||||
|
||||
def get_base_value_for_numpy_or_python_constant_data(
|
||||
constant_data: Any,
|
||||
) -> Callable[..., BaseValue]:
|
||||
"""Helper function to determine the BaseValue and BaseDataType to hold the input constant data.
|
||||
|
||||
This function is able to handle numpy types
|
||||
|
||||
Args:
|
||||
constant_data (Any): The constant data for which to determine the
|
||||
corresponding BaseValue and BaseDataType.
|
||||
|
||||
Raises:
|
||||
AssertionError: If `constant_data` is of an unsupported type.
|
||||
|
||||
Returns:
|
||||
Callable[..., BaseValue]: A partial object that will return the proper BaseValue when called
|
||||
with `encrypted` as keyword argument (forwarded to the BaseValue `__init__` method).
|
||||
"""
|
||||
constant_data_value: Callable[..., BaseValue]
|
||||
assert isinstance(
|
||||
constant_data, (int, float, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
|
||||
), f"Unsupported constant data of type {type(constant_data)}"
|
||||
|
||||
base_dtype = get_base_data_type_for_numpy_or_python_constant_data(constant_data)
|
||||
if isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES):
|
||||
constant_data_value = partial(ScalarValue, data_type=base_dtype)
|
||||
else:
|
||||
constant_data_value = get_base_value_for_python_constant_data(constant_data)
|
||||
return constant_data_value
|
||||
|
||||
|
||||
def get_ufunc_numpy_output_dtype(
|
||||
ufunc: numpy.ufunc,
|
||||
input_dtypes: List[BaseDataType],
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""hnumpy tracing utilities."""
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
import numpy
|
||||
@@ -7,11 +8,12 @@ from numpy.typing import DTypeLike
|
||||
|
||||
from ..common.data_types import BaseValue
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.representation import intermediate as ir
|
||||
from ..common.representation.intermediate import ArbitraryFunction, ConstantInput
|
||||
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
|
||||
from .np_dtypes_helpers import (
|
||||
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES,
|
||||
convert_numpy_dtype_to_base_data_type,
|
||||
get_base_value_for_numpy_or_python_constant_data,
|
||||
get_ufunc_numpy_output_dtype,
|
||||
)
|
||||
|
||||
@@ -19,6 +21,11 @@ SUPPORTED_TYPES_FOR_TRACING = (int, float, numpy.ndarray) + tuple(
|
||||
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES
|
||||
)
|
||||
|
||||
NPConstantInput = partial(
|
||||
ConstantInput,
|
||||
get_base_value_for_data_func=get_base_value_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
|
||||
class NPTracer(BaseTracer):
|
||||
"""Tracer class for numpy operations."""
|
||||
@@ -55,7 +62,7 @@ class NPTracer(BaseTracer):
|
||||
|
||||
normalized_numpy_dtype = numpy.dtype(numpy_dtype)
|
||||
output_dtype = convert_numpy_dtype_to_base_data_type(numpy_dtype)
|
||||
traced_computation = ir.ArbitraryFunction(
|
||||
traced_computation = ArbitraryFunction(
|
||||
input_base_value=self.output,
|
||||
arbitrary_func=normalized_numpy_dtype.type,
|
||||
output_dtype=output_dtype,
|
||||
@@ -91,6 +98,9 @@ class NPTracer(BaseTracer):
|
||||
other, SUPPORTED_TYPES_FOR_TRACING
|
||||
)
|
||||
|
||||
def _make_const_input_tracer(self, constant_data: Any) -> "NPTracer":
|
||||
return self.__class__([], NPConstantInput(constant_data), 0)
|
||||
|
||||
@staticmethod
|
||||
def _manage_dtypes(ufunc: numpy.ufunc, *input_tracers: "NPTracer"):
|
||||
output_dtypes = get_ufunc_numpy_output_dtype(
|
||||
@@ -111,7 +121,7 @@ class NPTracer(BaseTracer):
|
||||
common_output_dtypes = self._manage_dtypes(numpy.rint, *input_tracers)
|
||||
assert len(common_output_dtypes) == 1
|
||||
|
||||
traced_computation = ir.ArbitraryFunction(
|
||||
traced_computation = ArbitraryFunction(
|
||||
input_base_value=input_tracers[0].output,
|
||||
arbitrary_func=numpy.rint,
|
||||
output_dtype=common_output_dtypes[0],
|
||||
@@ -133,7 +143,7 @@ class NPTracer(BaseTracer):
|
||||
common_output_dtypes = self._manage_dtypes(numpy.sin, *input_tracers)
|
||||
assert len(common_output_dtypes) == 1
|
||||
|
||||
traced_computation = ir.ArbitraryFunction(
|
||||
traced_computation = ArbitraryFunction(
|
||||
input_base_value=input_tracers[0].output,
|
||||
arbitrary_func=numpy.sin,
|
||||
output_dtype=common_output_dtypes[0],
|
||||
|
||||
@@ -71,6 +71,7 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
[
|
||||
pytest.param(lambda x: x + 42, ((0, 2),), ["x"]),
|
||||
pytest.param(lambda x: x + numpy.int32(42), ((0, 2),), ["x"]),
|
||||
pytest.param(lambda x: x * 2, ((0, 2),), ["x"]),
|
||||
pytest.param(lambda x: 8 - x, ((0, 2),), ["x"]),
|
||||
pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]),
|
||||
|
||||
Reference in New Issue
Block a user