feat: create torch-like API part deux

- add function to get the FHECircuit

closes #233
This commit is contained in:
Arthur Meyre
2021-11-30 14:09:09 +01:00
parent faa6dd4403
commit 8b0a793cda
3 changed files with 159 additions and 28 deletions

View File

@@ -611,6 +611,79 @@ def prepare_op_graph_for_mlir(op_graph: OPGraph):
hack_offset_negative_inputs_to_lookup_tables(op_graph)
def _compile_op_graph_to_fhe_circuit_internal(
op_graph: OPGraph, show_mlir: bool, compilation_artifacts: CompilationArtifacts
) -> FHECircuit:
"""Compile the OPGraph to an FHECircuit.
Args:
op_graph (OPGraph): the OPGraph to compile.
show_mlir (bool): determine whether we print the mlir string.
compilation_artifacts (CompilationArtifacts): Artifacts object to fill
during compilation
Returns:
FHECircuit: the compiled FHECircuit
"""
prepare_op_graph_for_mlir(op_graph)
# Convert graph to an MLIR representation
converter = NPMLIRConverter()
mlir_result = converter.convert(op_graph)
# Show MLIR representation if requested
if show_mlir:
print(f"MLIR which is going to be compiled: \n{mlir_result}")
# Add MLIR representation as an artifact
compilation_artifacts.add_final_operation_graph_mlir(mlir_result)
# Compile the MLIR representation
engine = CompilerEngine()
engine.compile_fhe(mlir_result)
return FHECircuit(op_graph, engine)
def compile_op_graph_to_fhe_circuit(
op_graph: OPGraph,
show_mlir: bool,
compilation_configuration: Optional[CompilationConfiguration] = None,
compilation_artifacts: Optional[CompilationArtifacts] = None,
) -> FHECircuit:
"""Compile the OPGraph to an FHECircuit.
Args:
op_graph (OPGraph): the OPGraph to compile.
show_mlir (bool): determine whether we print the mlir string.
compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
during compilation
compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
during compilation
Returns:
FHECircuit: the compiled circuit and the compiled FHECircuit
"""
(
compilation_configuration,
compilation_artifacts,
) = sanitize_compilation_configuration_and_artifacts(
compilation_configuration, compilation_artifacts
)
def compilation_function():
return _compile_op_graph_to_fhe_circuit_internal(op_graph, show_mlir, compilation_artifacts)
result = run_compilation_function_with_error_management(
compilation_function, compilation_configuration, compilation_artifacts
)
# for mypy
assert isinstance(result, FHECircuit)
return result
def _compile_numpy_function_internal(
function_to_compile: Callable,
function_parameters: Dict[str, BaseValue],
@@ -648,24 +721,11 @@ def _compile_numpy_function_internal(
compilation_artifacts,
)
prepare_op_graph_for_mlir(op_graph)
fhe_circuit = _compile_op_graph_to_fhe_circuit_internal(
op_graph, show_mlir, compilation_artifacts
)
# Convert graph to an MLIR representation
converter = NPMLIRConverter()
mlir_result = converter.convert(op_graph)
# Show MLIR representation if requested
if show_mlir:
print(f"MLIR which is going to be compiled: \n{mlir_result}")
# Add MLIR representation as an artifact
compilation_artifacts.add_final_operation_graph_mlir(mlir_result)
# Compile the MLIR representation
engine = CompilerEngine()
engine.compile_fhe(mlir_result)
return FHECircuit(op_graph, engine)
return fhe_circuit
def compile_numpy_function(

View File

@@ -6,10 +6,15 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
from ..common.data_types import Integer
from ..common.fhe_circuit import FHECircuit
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import IntermediateNode
from ..common.values import BaseValue
from .compile import compile_numpy_function_into_op_graph, measure_op_graph_bounds_and_update
from .compile import (
compile_numpy_function_into_op_graph,
compile_op_graph_to_fhe_circuit,
measure_op_graph_bounds_and_update,
)
from .np_dtypes_helpers import get_base_value_for_numpy_or_python_constant_data
@@ -197,3 +202,32 @@ class NPFHECompiler:
if isinstance(dtype := (input_.dtype), Integer):
dtype.bit_width = 128
dtype.is_signed = True
def get_compiled_fhe_circuit(self, show_mlir: bool = False) -> FHECircuit:
"""Return a compiled FHECircuit if the instance was evaluated on an inputset.
Args:
show_mlir (bool, optional): if set, the MLIR produced by the converter and which is
going to be sent to the compiler backend is shown on the screen, e.g., for debugging
or demo. Defaults to False.
Raises:
RuntimeError: raised if no inputset was passed to the instance.
Returns:
FHECircuit: the compiled FHECircuit
"""
self._eval_on_current_inputset()
if self._op_graph is None:
raise RuntimeError(
"Requested FHECircuit but no OPGraph was compiled. "
f"Did you forget to evaluate {self.__class__.__name__} over an inputset?"
)
return compile_op_graph_to_fhe_circuit(
self._op_graph,
show_mlir,
self.compilation_configuration,
self.compilation_artifacts,
)

View File

@@ -22,17 +22,20 @@ def complicated_topology(x, y):
@pytest.mark.parametrize("input_shape", [(), (3, 1, 2)])
def test_np_fhe_compiler(input_shape, default_compilation_configuration):
def test_np_fhe_compiler_op_graph(input_shape, default_compilation_configuration):
"""Test NPFHECompiler in two subtests."""
subtest_np_fhe_compiler_1_input(input_shape, default_compilation_configuration)
subtest_np_fhe_compiler_2_inputs(input_shape, default_compilation_configuration)
subtest_np_fhe_compiler_1_input_op_graph(input_shape, default_compilation_configuration)
subtest_np_fhe_compiler_2_inputs_op_graph(input_shape, default_compilation_configuration)
def subtest_np_fhe_compiler_1_input(input_shape, default_compilation_configuration):
def subtest_np_fhe_compiler_1_input_op_graph(input_shape, default_compilation_configuration):
"""test for NPFHECompiler on one input function"""
def function_to_compile(x):
return complicated_topology(x, 0)
compiler = NPFHECompiler(
lambda x: complicated_topology(x, 0),
function_to_compile,
{"x": "encrypted"},
default_compilation_configuration,
)
@@ -45,7 +48,7 @@ def subtest_np_fhe_compiler_1_input(input_shape, default_compilation_configurati
for i in numpy.arange(5):
i = numpy.ones(input_shape, dtype=numpy.int64) * i
assert numpy.array_equal(compiler(i), complicated_topology(i, 0))
assert numpy.array_equal(compiler(i), function_to_compile(i))
# For coverage, check that we flush the inputset when we query the OPGraph
current_op_graph = compiler.op_graph
@@ -57,7 +60,7 @@ def subtest_np_fhe_compiler_1_input(input_shape, default_compilation_configurati
# Continue a bit more
for i in numpy.arange(5, 10):
i = numpy.ones(input_shape, dtype=numpy.int64) * i
assert numpy.array_equal(compiler(i), complicated_topology(i, 0))
assert numpy.array_equal(compiler(i), function_to_compile(i))
if input_shape == ():
assert (
@@ -103,7 +106,7 @@ def subtest_np_fhe_compiler_1_input(input_shape, default_compilation_configurati
), got
def subtest_np_fhe_compiler_2_inputs(input_shape, default_compilation_configuration):
def subtest_np_fhe_compiler_2_inputs_op_graph(input_shape, default_compilation_configuration):
"""test for NPFHECompiler on two inputs function"""
compiler = NPFHECompiler(
@@ -199,17 +202,51 @@ def test_np_fhe_compiler_auto_flush(
default_compilation_configuration,
):
"""Test the auto flush of NPFHECompiler once the inputset is 128 elements."""
def function_to_compile(x):
return x // 2
compiler = NPFHECompiler(
lambda x: x // 2,
function_to_compile,
{"x": "encrypted"},
default_compilation_configuration,
)
for i in numpy.arange(inputset_len):
assert numpy.array_equal(compiler(i), i // 2)
assert numpy.array_equal(compiler(i), function_to_compile(i))
# Check the inputset was properly flushed
assert (
len(compiler._current_inputset) # pylint: disable=protected-access
== expected_remaining_inputset_len
)
def test_np_fhe_compiler_full_compilation(default_compilation_configuration):
"""Test the case where we generate an FHE circuit."""
def function_to_compile(x):
return x + 42
compiler = NPFHECompiler(
function_to_compile,
{"x": "encrypted"},
default_compilation_configuration,
)
# For coverage
with pytest.raises(RuntimeError) as excinfo:
compiler.get_compiled_fhe_circuit()
assert str(excinfo.value) == (
"Requested FHECircuit but no OPGraph was compiled. "
"Did you forget to evaluate NPFHECompiler over an inputset?"
)
for i in numpy.arange(64):
assert numpy.array_equal(compiler(i), function_to_compile(i))
fhe_circuit = compiler.get_compiled_fhe_circuit()
for i in range(64):
assert fhe_circuit.run(i) == function_to_compile(i)