feat(compilation-configuration): make compilation customizable

This commit is contained in:
Umut
2021-08-16 16:57:15 +03:00
parent 1b33cd7307
commit a367d68c6e
5 changed files with 120 additions and 6 deletions

View File

@@ -1,3 +1,4 @@
"""Module for compilation related types."""
from .artifacts import CompilationArtifacts
from .configuration import CompilationConfiguration

View File

@@ -0,0 +1,13 @@
"""Module for compilation configuration."""
class CompilationConfiguration:
"""Class that allows the compilation process to be customized."""
enable_topological_optimizations: bool
def __init__(
self,
enable_topological_optimizations: bool = True,
):
self.enable_topological_optimizations = enable_topological_optimizations

View File

@@ -6,7 +6,7 @@ from zamalang import CompilerEngine
from ..common.bounds_measurement.dataset_eval import eval_op_graph_bounds_on_dataset
from ..common.common_helpers import check_op_graph_is_integer_program
from ..common.compilation import CompilationArtifacts
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
from ..common.data_types import BaseValue
from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
from ..common.mlir.utils import (
@@ -23,6 +23,7 @@ def compile_numpy_function_into_op_graph(
function_to_trace: Callable,
function_parameters: Dict[str, BaseValue],
dataset: Iterator[Tuple[Any, ...]],
compilation_configuration: Optional[CompilationConfiguration] = None,
compilation_artifacts: Optional[CompilationArtifacts] = None,
) -> OPGraph:
"""Compile a function into an OPGraph.
@@ -34,18 +35,30 @@ def compile_numpy_function_into_op_graph(
dataset (Iterator[Tuple[Any, ...]]): The dataset over which op_graph is evaluated. It
needs to be an iterator on tuples which are of the same length than the number of
parameters in the function, and in the same order than these same parameters
compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
during compilation
compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
during compilation
Returns:
OPGraph: compiled function into a graph
"""
# Create default configuration if custom configuration is not specified
compilation_configuration = (
CompilationConfiguration()
if compilation_configuration is None
else compilation_configuration
)
# Trace
op_graph = trace_numpy_function(function_to_trace, function_parameters)
# Fuse float operations to have int to int ArbitraryFunction
if not check_op_graph_is_integer_program(op_graph):
fuse_float_operations(op_graph)
# Apply topological optimizations if they are enabled
if compilation_configuration.enable_topological_optimizations:
# Fuse float operations to have int to int ArbitraryFunction
if not check_op_graph_is_integer_program(op_graph):
fuse_float_operations(op_graph)
# TODO: To be removed once we support more than integers
offending_non_integer_nodes: List[ir.IntermediateNode] = []
@@ -82,6 +95,7 @@ def compile_numpy_function(
function_to_trace: Callable,
function_parameters: Dict[str, BaseValue],
dataset: Iterator[Tuple[Any, ...]],
compilation_configuration: Optional[CompilationConfiguration] = None,
compilation_artifacts: Optional[CompilationArtifacts] = None,
) -> CompilerEngine:
"""Main API of hnumpy, to be able to compile an homomorphic program.
@@ -93,15 +107,29 @@ def compile_numpy_function(
dataset (Iterator[Tuple[Any, ...]]): The dataset over which op_graph is evaluated. It
needs to be an iterator on tuples which are of the same length than the number of
parameters in the function, and in the same order than these same parameters
compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
during compilation
compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
during compilation
Returns:
CompilerEngine: engine to run and debug the compiled graph
"""
# Create default configuration if custom configuration is not specified
compilation_configuration = (
CompilationConfiguration()
if compilation_configuration is None
else compilation_configuration
)
# Compile into an OPGraph
op_graph = compile_numpy_function_into_op_graph(
function_to_trace, function_parameters, dataset, compilation_artifacts
function_to_trace,
function_parameters,
dataset,
compilation_configuration,
compilation_artifacts,
)
# Convert graph to an MLIR representation

View File

@@ -20,7 +20,7 @@ def test_artifacts_export():
function,
{"x": EncryptedValue(Integer(7, True))},
iter([(-2,), (-1,), (0,), (1,), (2,)]),
artifacts,
compilation_artifacts=artifacts,
)
with tempfile.TemporaryDirectory() as tmp:

View File

@@ -0,0 +1,72 @@
"""Test file for compilation configuration"""
from inspect import signature
import numpy
import pytest
from hdk.common.compilation import CompilationConfiguration
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import EncryptedValue
from hdk.hnumpy.compile import compile_numpy_function_into_op_graph
def no_fuse(x):
"""No fuse"""
return x + 2
def simple_fuse_not_output(x):
"""Simple fuse not output"""
intermediate = x.astype(numpy.float64)
intermediate = intermediate.astype(numpy.uint32)
return intermediate + 2
@pytest.mark.parametrize(
"function_to_trace,fused",
[
pytest.param(
no_fuse,
False,
id="no_fuse",
),
pytest.param(
simple_fuse_not_output,
True,
id="simple_fuse_not_output",
marks=pytest.mark.xfail(strict=True),
# fails because it connot be compiled without topological optimizations
),
],
)
def test_enable_topological_optimizations(test_helpers, function_to_trace, fused):
"""Test function for enable_topological_optimizations flag of compilation configuration"""
op_graph = compile_numpy_function_into_op_graph(
function_to_trace,
{
param: EncryptedValue(Integer(32, is_signed=False))
for param in signature(function_to_trace).parameters.keys()
},
iter([(1,), (2,), (3,)]),
)
op_graph_not_optimized = compile_numpy_function_into_op_graph(
function_to_trace,
{
param: EncryptedValue(Integer(32, is_signed=False))
for param in signature(function_to_trace).parameters.keys()
},
iter([(1,), (2,), (3,)]),
compilation_configuration=CompilationConfiguration(enable_topological_optimizations=False),
)
graph = op_graph.graph
not_optimized_graph = op_graph_not_optimized.graph
if fused:
assert not test_helpers.digraphs_are_equivalent(graph, not_optimized_graph)
assert len(graph) < len(not_optimized_graph)
else:
assert test_helpers.digraphs_are_equivalent(graph, not_optimized_graph)
assert len(graph) == len(not_optimized_graph)