diff --git a/hdk/common/compilation/__init__.py b/hdk/common/compilation/__init__.py index 3dba32973..59e402417 100644 --- a/hdk/common/compilation/__init__.py +++ b/hdk/common/compilation/__init__.py @@ -1,3 +1,4 @@ """Module for compilation related types.""" from .artifacts import CompilationArtifacts +from .configuration import CompilationConfiguration diff --git a/hdk/common/compilation/configuration.py b/hdk/common/compilation/configuration.py new file mode 100644 index 000000000..648a6c9a8 --- /dev/null +++ b/hdk/common/compilation/configuration.py @@ -0,0 +1,13 @@ +"""Module for compilation configuration.""" + + +class CompilationConfiguration: + """Class that allows the compilation process to be customized.""" + + enable_topological_optimizations: bool + + def __init__( + self, + enable_topological_optimizations: bool = True, + ): + self.enable_topological_optimizations = enable_topological_optimizations diff --git a/hdk/hnumpy/compile.py b/hdk/hnumpy/compile.py index 03f899448..1f629b674 100644 --- a/hdk/hnumpy/compile.py +++ b/hdk/hnumpy/compile.py @@ -6,7 +6,7 @@ from zamalang import CompilerEngine from ..common.bounds_measurement.dataset_eval import eval_op_graph_bounds_on_dataset from ..common.common_helpers import check_op_graph_is_integer_program -from ..common.compilation import CompilationArtifacts +from ..common.compilation import CompilationArtifacts, CompilationConfiguration from ..common.data_types import BaseValue from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter from ..common.mlir.utils import ( @@ -23,6 +23,7 @@ def compile_numpy_function_into_op_graph( function_to_trace: Callable, function_parameters: Dict[str, BaseValue], dataset: Iterator[Tuple[Any, ...]], + compilation_configuration: Optional[CompilationConfiguration] = None, compilation_artifacts: Optional[CompilationArtifacts] = None, ) -> OPGraph: """Compile a function into an OPGraph. @@ -34,18 +35,30 @@ def compile_numpy_function_into_op_graph( dataset (Iterator[Tuple[Any, ...]]): The dataset over which op_graph is evaluated. It needs to be an iterator on tuples which are of the same length than the number of parameters in the function, and in the same order than these same parameters + compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use + during compilation compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill during compilation Returns: OPGraph: compiled function into a graph """ + + # Create default configuration if custom configuration is not specified + compilation_configuration = ( + CompilationConfiguration() + if compilation_configuration is None + else compilation_configuration + ) + # Trace op_graph = trace_numpy_function(function_to_trace, function_parameters) - # Fuse float operations to have int to int ArbitraryFunction - if not check_op_graph_is_integer_program(op_graph): - fuse_float_operations(op_graph) + # Apply topological optimizations if they are enabled + if compilation_configuration.enable_topological_optimizations: + # Fuse float operations to have int to int ArbitraryFunction + if not check_op_graph_is_integer_program(op_graph): + fuse_float_operations(op_graph) # TODO: To be removed once we support more than integers offending_non_integer_nodes: List[ir.IntermediateNode] = [] @@ -82,6 +95,7 @@ def compile_numpy_function( function_to_trace: Callable, function_parameters: Dict[str, BaseValue], dataset: Iterator[Tuple[Any, ...]], + compilation_configuration: Optional[CompilationConfiguration] = None, compilation_artifacts: Optional[CompilationArtifacts] = None, ) -> CompilerEngine: """Main API of hnumpy, to be able to compile an homomorphic program. @@ -93,15 +107,29 @@ def compile_numpy_function( dataset (Iterator[Tuple[Any, ...]]): The dataset over which op_graph is evaluated. It needs to be an iterator on tuples which are of the same length than the number of parameters in the function, and in the same order than these same parameters + compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use + during compilation compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill during compilation Returns: CompilerEngine: engine to run and debug the compiled graph """ + + # Create default configuration if custom configuration is not specified + compilation_configuration = ( + CompilationConfiguration() + if compilation_configuration is None + else compilation_configuration + ) + # Compile into an OPGraph op_graph = compile_numpy_function_into_op_graph( - function_to_trace, function_parameters, dataset, compilation_artifacts + function_to_trace, + function_parameters, + dataset, + compilation_configuration, + compilation_artifacts, ) # Convert graph to an MLIR representation diff --git a/tests/common/compilation/test_artifacts.py b/tests/common/compilation/test_artifacts.py index 6fae76174..0e2c75697 100644 --- a/tests/common/compilation/test_artifacts.py +++ b/tests/common/compilation/test_artifacts.py @@ -20,7 +20,7 @@ def test_artifacts_export(): function, {"x": EncryptedValue(Integer(7, True))}, iter([(-2,), (-1,), (0,), (1,), (2,)]), - artifacts, + compilation_artifacts=artifacts, ) with tempfile.TemporaryDirectory() as tmp: diff --git a/tests/common/compilation/test_configuration.py b/tests/common/compilation/test_configuration.py new file mode 100644 index 000000000..0280fe2a7 --- /dev/null +++ b/tests/common/compilation/test_configuration.py @@ -0,0 +1,72 @@ +"""Test file for compilation configuration""" + +from inspect import signature + +import numpy +import pytest + +from hdk.common.compilation import CompilationConfiguration +from hdk.common.data_types.integers import Integer +from hdk.common.data_types.values import EncryptedValue +from hdk.hnumpy.compile import compile_numpy_function_into_op_graph + + +def no_fuse(x): + """No fuse""" + return x + 2 + + +def simple_fuse_not_output(x): + """Simple fuse not output""" + intermediate = x.astype(numpy.float64) + intermediate = intermediate.astype(numpy.uint32) + return intermediate + 2 + + +@pytest.mark.parametrize( + "function_to_trace,fused", + [ + pytest.param( + no_fuse, + False, + id="no_fuse", + ), + pytest.param( + simple_fuse_not_output, + True, + id="simple_fuse_not_output", + marks=pytest.mark.xfail(strict=True), + # fails because it connot be compiled without topological optimizations + ), + ], +) +def test_enable_topological_optimizations(test_helpers, function_to_trace, fused): + """Test function for enable_topological_optimizations flag of compilation configuration""" + + op_graph = compile_numpy_function_into_op_graph( + function_to_trace, + { + param: EncryptedValue(Integer(32, is_signed=False)) + for param in signature(function_to_trace).parameters.keys() + }, + iter([(1,), (2,), (3,)]), + ) + op_graph_not_optimized = compile_numpy_function_into_op_graph( + function_to_trace, + { + param: EncryptedValue(Integer(32, is_signed=False)) + for param in signature(function_to_trace).parameters.keys() + }, + iter([(1,), (2,), (3,)]), + compilation_configuration=CompilationConfiguration(enable_topological_optimizations=False), + ) + + graph = op_graph.graph + not_optimized_graph = op_graph_not_optimized.graph + + if fused: + assert not test_helpers.digraphs_are_equivalent(graph, not_optimized_graph) + assert len(graph) < len(not_optimized_graph) + else: + assert test_helpers.digraphs_are_equivalent(graph, not_optimized_graph) + assert len(graph) == len(not_optimized_graph)