mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(compilation-configuration): make compilation customizable
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
"""Module for compilation related types."""
|
||||
|
||||
from .artifacts import CompilationArtifacts
|
||||
from .configuration import CompilationConfiguration
|
||||
|
||||
13
hdk/common/compilation/configuration.py
Normal file
13
hdk/common/compilation/configuration.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Module for compilation configuration."""
|
||||
|
||||
|
||||
class CompilationConfiguration:
|
||||
"""Class that allows the compilation process to be customized."""
|
||||
|
||||
enable_topological_optimizations: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enable_topological_optimizations: bool = True,
|
||||
):
|
||||
self.enable_topological_optimizations = enable_topological_optimizations
|
||||
@@ -6,7 +6,7 @@ from zamalang import CompilerEngine
|
||||
|
||||
from ..common.bounds_measurement.dataset_eval import eval_op_graph_bounds_on_dataset
|
||||
from ..common.common_helpers import check_op_graph_is_integer_program
|
||||
from ..common.compilation import CompilationArtifacts
|
||||
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
|
||||
from ..common.data_types import BaseValue
|
||||
from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
|
||||
from ..common.mlir.utils import (
|
||||
@@ -23,6 +23,7 @@ def compile_numpy_function_into_op_graph(
|
||||
function_to_trace: Callable,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
dataset: Iterator[Tuple[Any, ...]],
|
||||
compilation_configuration: Optional[CompilationConfiguration] = None,
|
||||
compilation_artifacts: Optional[CompilationArtifacts] = None,
|
||||
) -> OPGraph:
|
||||
"""Compile a function into an OPGraph.
|
||||
@@ -34,18 +35,30 @@ def compile_numpy_function_into_op_graph(
|
||||
dataset (Iterator[Tuple[Any, ...]]): The dataset over which op_graph is evaluated. It
|
||||
needs to be an iterator on tuples which are of the same length than the number of
|
||||
parameters in the function, and in the same order than these same parameters
|
||||
compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
|
||||
during compilation
|
||||
compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
|
||||
during compilation
|
||||
|
||||
Returns:
|
||||
OPGraph: compiled function into a graph
|
||||
"""
|
||||
|
||||
# Create default configuration if custom configuration is not specified
|
||||
compilation_configuration = (
|
||||
CompilationConfiguration()
|
||||
if compilation_configuration is None
|
||||
else compilation_configuration
|
||||
)
|
||||
|
||||
# Trace
|
||||
op_graph = trace_numpy_function(function_to_trace, function_parameters)
|
||||
|
||||
# Fuse float operations to have int to int ArbitraryFunction
|
||||
if not check_op_graph_is_integer_program(op_graph):
|
||||
fuse_float_operations(op_graph)
|
||||
# Apply topological optimizations if they are enabled
|
||||
if compilation_configuration.enable_topological_optimizations:
|
||||
# Fuse float operations to have int to int ArbitraryFunction
|
||||
if not check_op_graph_is_integer_program(op_graph):
|
||||
fuse_float_operations(op_graph)
|
||||
|
||||
# TODO: To be removed once we support more than integers
|
||||
offending_non_integer_nodes: List[ir.IntermediateNode] = []
|
||||
@@ -82,6 +95,7 @@ def compile_numpy_function(
|
||||
function_to_trace: Callable,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
dataset: Iterator[Tuple[Any, ...]],
|
||||
compilation_configuration: Optional[CompilationConfiguration] = None,
|
||||
compilation_artifacts: Optional[CompilationArtifacts] = None,
|
||||
) -> CompilerEngine:
|
||||
"""Main API of hnumpy, to be able to compile an homomorphic program.
|
||||
@@ -93,15 +107,29 @@ def compile_numpy_function(
|
||||
dataset (Iterator[Tuple[Any, ...]]): The dataset over which op_graph is evaluated. It
|
||||
needs to be an iterator on tuples which are of the same length than the number of
|
||||
parameters in the function, and in the same order than these same parameters
|
||||
compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
|
||||
during compilation
|
||||
compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
|
||||
during compilation
|
||||
|
||||
Returns:
|
||||
CompilerEngine: engine to run and debug the compiled graph
|
||||
"""
|
||||
|
||||
# Create default configuration if custom configuration is not specified
|
||||
compilation_configuration = (
|
||||
CompilationConfiguration()
|
||||
if compilation_configuration is None
|
||||
else compilation_configuration
|
||||
)
|
||||
|
||||
# Compile into an OPGraph
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
function_to_trace, function_parameters, dataset, compilation_artifacts
|
||||
function_to_trace,
|
||||
function_parameters,
|
||||
dataset,
|
||||
compilation_configuration,
|
||||
compilation_artifacts,
|
||||
)
|
||||
|
||||
# Convert graph to an MLIR representation
|
||||
|
||||
@@ -20,7 +20,7 @@ def test_artifacts_export():
|
||||
function,
|
||||
{"x": EncryptedValue(Integer(7, True))},
|
||||
iter([(-2,), (-1,), (0,), (1,), (2,)]),
|
||||
artifacts,
|
||||
compilation_artifacts=artifacts,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
|
||||
72
tests/common/compilation/test_configuration.py
Normal file
72
tests/common/compilation/test_configuration.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Test file for compilation configuration"""
|
||||
|
||||
from inspect import signature
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from hdk.common.compilation import CompilationConfiguration
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import EncryptedValue
|
||||
from hdk.hnumpy.compile import compile_numpy_function_into_op_graph
|
||||
|
||||
|
||||
def no_fuse(x):
|
||||
"""No fuse"""
|
||||
return x + 2
|
||||
|
||||
|
||||
def simple_fuse_not_output(x):
|
||||
"""Simple fuse not output"""
|
||||
intermediate = x.astype(numpy.float64)
|
||||
intermediate = intermediate.astype(numpy.uint32)
|
||||
return intermediate + 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,fused",
|
||||
[
|
||||
pytest.param(
|
||||
no_fuse,
|
||||
False,
|
||||
id="no_fuse",
|
||||
),
|
||||
pytest.param(
|
||||
simple_fuse_not_output,
|
||||
True,
|
||||
id="simple_fuse_not_output",
|
||||
marks=pytest.mark.xfail(strict=True),
|
||||
# fails because it connot be compiled without topological optimizations
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_enable_topological_optimizations(test_helpers, function_to_trace, fused):
|
||||
"""Test function for enable_topological_optimizations flag of compilation configuration"""
|
||||
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
function_to_trace,
|
||||
{
|
||||
param: EncryptedValue(Integer(32, is_signed=False))
|
||||
for param in signature(function_to_trace).parameters.keys()
|
||||
},
|
||||
iter([(1,), (2,), (3,)]),
|
||||
)
|
||||
op_graph_not_optimized = compile_numpy_function_into_op_graph(
|
||||
function_to_trace,
|
||||
{
|
||||
param: EncryptedValue(Integer(32, is_signed=False))
|
||||
for param in signature(function_to_trace).parameters.keys()
|
||||
},
|
||||
iter([(1,), (2,), (3,)]),
|
||||
compilation_configuration=CompilationConfiguration(enable_topological_optimizations=False),
|
||||
)
|
||||
|
||||
graph = op_graph.graph
|
||||
not_optimized_graph = op_graph_not_optimized.graph
|
||||
|
||||
if fused:
|
||||
assert not test_helpers.digraphs_are_equivalent(graph, not_optimized_graph)
|
||||
assert len(graph) < len(not_optimized_graph)
|
||||
else:
|
||||
assert test_helpers.digraphs_are_equivalent(graph, not_optimized_graph)
|
||||
assert len(graph) == len(not_optimized_graph)
|
||||
Reference in New Issue
Block a user