mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: create torch-like API part deux
- add function to get the FHECircuit closes #233
This commit is contained in:
@@ -611,6 +611,79 @@ def prepare_op_graph_for_mlir(op_graph: OPGraph):
|
||||
hack_offset_negative_inputs_to_lookup_tables(op_graph)
|
||||
|
||||
|
||||
def _compile_op_graph_to_fhe_circuit_internal(
|
||||
op_graph: OPGraph, show_mlir: bool, compilation_artifacts: CompilationArtifacts
|
||||
) -> FHECircuit:
|
||||
"""Compile the OPGraph to an FHECircuit.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): the OPGraph to compile.
|
||||
show_mlir (bool): determine whether we print the mlir string.
|
||||
compilation_artifacts (CompilationArtifacts): Artifacts object to fill
|
||||
during compilation
|
||||
|
||||
Returns:
|
||||
FHECircuit: the compiled FHECircuit
|
||||
"""
|
||||
prepare_op_graph_for_mlir(op_graph)
|
||||
|
||||
# Convert graph to an MLIR representation
|
||||
converter = NPMLIRConverter()
|
||||
mlir_result = converter.convert(op_graph)
|
||||
|
||||
# Show MLIR representation if requested
|
||||
if show_mlir:
|
||||
print(f"MLIR which is going to be compiled: \n{mlir_result}")
|
||||
|
||||
# Add MLIR representation as an artifact
|
||||
compilation_artifacts.add_final_operation_graph_mlir(mlir_result)
|
||||
|
||||
# Compile the MLIR representation
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(mlir_result)
|
||||
|
||||
return FHECircuit(op_graph, engine)
|
||||
|
||||
|
||||
def compile_op_graph_to_fhe_circuit(
|
||||
op_graph: OPGraph,
|
||||
show_mlir: bool,
|
||||
compilation_configuration: Optional[CompilationConfiguration] = None,
|
||||
compilation_artifacts: Optional[CompilationArtifacts] = None,
|
||||
) -> FHECircuit:
|
||||
"""Compile the OPGraph to an FHECircuit.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): the OPGraph to compile.
|
||||
show_mlir (bool): determine whether we print the mlir string.
|
||||
compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use
|
||||
during compilation
|
||||
compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
|
||||
during compilation
|
||||
|
||||
Returns:
|
||||
FHECircuit: the compiled circuit and the compiled FHECircuit
|
||||
"""
|
||||
|
||||
(
|
||||
compilation_configuration,
|
||||
compilation_artifacts,
|
||||
) = sanitize_compilation_configuration_and_artifacts(
|
||||
compilation_configuration, compilation_artifacts
|
||||
)
|
||||
|
||||
def compilation_function():
|
||||
return _compile_op_graph_to_fhe_circuit_internal(op_graph, show_mlir, compilation_artifacts)
|
||||
|
||||
result = run_compilation_function_with_error_management(
|
||||
compilation_function, compilation_configuration, compilation_artifacts
|
||||
)
|
||||
|
||||
# for mypy
|
||||
assert isinstance(result, FHECircuit)
|
||||
return result
|
||||
|
||||
|
||||
def _compile_numpy_function_internal(
|
||||
function_to_compile: Callable,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
@@ -648,24 +721,11 @@ def _compile_numpy_function_internal(
|
||||
compilation_artifacts,
|
||||
)
|
||||
|
||||
prepare_op_graph_for_mlir(op_graph)
|
||||
fhe_circuit = _compile_op_graph_to_fhe_circuit_internal(
|
||||
op_graph, show_mlir, compilation_artifacts
|
||||
)
|
||||
|
||||
# Convert graph to an MLIR representation
|
||||
converter = NPMLIRConverter()
|
||||
mlir_result = converter.convert(op_graph)
|
||||
|
||||
# Show MLIR representation if requested
|
||||
if show_mlir:
|
||||
print(f"MLIR which is going to be compiled: \n{mlir_result}")
|
||||
|
||||
# Add MLIR representation as an artifact
|
||||
compilation_artifacts.add_final_operation_graph_mlir(mlir_result)
|
||||
|
||||
# Compile the MLIR representation
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(mlir_result)
|
||||
|
||||
return FHECircuit(op_graph, engine)
|
||||
return fhe_circuit
|
||||
|
||||
|
||||
def compile_numpy_function(
|
||||
|
||||
@@ -6,10 +6,15 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
|
||||
from ..common.data_types import Integer
|
||||
from ..common.fhe_circuit import FHECircuit
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.representation.intermediate import IntermediateNode
|
||||
from ..common.values import BaseValue
|
||||
from .compile import compile_numpy_function_into_op_graph, measure_op_graph_bounds_and_update
|
||||
from .compile import (
|
||||
compile_numpy_function_into_op_graph,
|
||||
compile_op_graph_to_fhe_circuit,
|
||||
measure_op_graph_bounds_and_update,
|
||||
)
|
||||
from .np_dtypes_helpers import get_base_value_for_numpy_or_python_constant_data
|
||||
|
||||
|
||||
@@ -197,3 +202,32 @@ class NPFHECompiler:
|
||||
if isinstance(dtype := (input_.dtype), Integer):
|
||||
dtype.bit_width = 128
|
||||
dtype.is_signed = True
|
||||
|
||||
def get_compiled_fhe_circuit(self, show_mlir: bool = False) -> FHECircuit:
|
||||
"""Return a compiled FHECircuit if the instance was evaluated on an inputset.
|
||||
|
||||
Args:
|
||||
show_mlir (bool, optional): if set, the MLIR produced by the converter and which is
|
||||
going to be sent to the compiler backend is shown on the screen, e.g., for debugging
|
||||
or demo. Defaults to False.
|
||||
|
||||
Raises:
|
||||
RuntimeError: raised if no inputset was passed to the instance.
|
||||
|
||||
Returns:
|
||||
FHECircuit: the compiled FHECircuit
|
||||
"""
|
||||
self._eval_on_current_inputset()
|
||||
|
||||
if self._op_graph is None:
|
||||
raise RuntimeError(
|
||||
"Requested FHECircuit but no OPGraph was compiled. "
|
||||
f"Did you forget to evaluate {self.__class__.__name__} over an inputset?"
|
||||
)
|
||||
|
||||
return compile_op_graph_to_fhe_circuit(
|
||||
self._op_graph,
|
||||
show_mlir,
|
||||
self.compilation_configuration,
|
||||
self.compilation_artifacts,
|
||||
)
|
||||
|
||||
@@ -22,17 +22,20 @@ def complicated_topology(x, y):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("input_shape", [(), (3, 1, 2)])
|
||||
def test_np_fhe_compiler(input_shape, default_compilation_configuration):
|
||||
def test_np_fhe_compiler_op_graph(input_shape, default_compilation_configuration):
|
||||
"""Test NPFHECompiler in two subtests."""
|
||||
subtest_np_fhe_compiler_1_input(input_shape, default_compilation_configuration)
|
||||
subtest_np_fhe_compiler_2_inputs(input_shape, default_compilation_configuration)
|
||||
subtest_np_fhe_compiler_1_input_op_graph(input_shape, default_compilation_configuration)
|
||||
subtest_np_fhe_compiler_2_inputs_op_graph(input_shape, default_compilation_configuration)
|
||||
|
||||
|
||||
def subtest_np_fhe_compiler_1_input(input_shape, default_compilation_configuration):
|
||||
def subtest_np_fhe_compiler_1_input_op_graph(input_shape, default_compilation_configuration):
|
||||
"""test for NPFHECompiler on one input function"""
|
||||
|
||||
def function_to_compile(x):
|
||||
return complicated_topology(x, 0)
|
||||
|
||||
compiler = NPFHECompiler(
|
||||
lambda x: complicated_topology(x, 0),
|
||||
function_to_compile,
|
||||
{"x": "encrypted"},
|
||||
default_compilation_configuration,
|
||||
)
|
||||
@@ -45,7 +48,7 @@ def subtest_np_fhe_compiler_1_input(input_shape, default_compilation_configurati
|
||||
|
||||
for i in numpy.arange(5):
|
||||
i = numpy.ones(input_shape, dtype=numpy.int64) * i
|
||||
assert numpy.array_equal(compiler(i), complicated_topology(i, 0))
|
||||
assert numpy.array_equal(compiler(i), function_to_compile(i))
|
||||
|
||||
# For coverage, check that we flush the inputset when we query the OPGraph
|
||||
current_op_graph = compiler.op_graph
|
||||
@@ -57,7 +60,7 @@ def subtest_np_fhe_compiler_1_input(input_shape, default_compilation_configurati
|
||||
# Continue a bit more
|
||||
for i in numpy.arange(5, 10):
|
||||
i = numpy.ones(input_shape, dtype=numpy.int64) * i
|
||||
assert numpy.array_equal(compiler(i), complicated_topology(i, 0))
|
||||
assert numpy.array_equal(compiler(i), function_to_compile(i))
|
||||
|
||||
if input_shape == ():
|
||||
assert (
|
||||
@@ -103,7 +106,7 @@ def subtest_np_fhe_compiler_1_input(input_shape, default_compilation_configurati
|
||||
), got
|
||||
|
||||
|
||||
def subtest_np_fhe_compiler_2_inputs(input_shape, default_compilation_configuration):
|
||||
def subtest_np_fhe_compiler_2_inputs_op_graph(input_shape, default_compilation_configuration):
|
||||
"""test for NPFHECompiler on two inputs function"""
|
||||
|
||||
compiler = NPFHECompiler(
|
||||
@@ -199,17 +202,51 @@ def test_np_fhe_compiler_auto_flush(
|
||||
default_compilation_configuration,
|
||||
):
|
||||
"""Test the auto flush of NPFHECompiler once the inputset is 128 elements."""
|
||||
|
||||
def function_to_compile(x):
|
||||
return x // 2
|
||||
|
||||
compiler = NPFHECompiler(
|
||||
lambda x: x // 2,
|
||||
function_to_compile,
|
||||
{"x": "encrypted"},
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
for i in numpy.arange(inputset_len):
|
||||
assert numpy.array_equal(compiler(i), i // 2)
|
||||
assert numpy.array_equal(compiler(i), function_to_compile(i))
|
||||
|
||||
# Check the inputset was properly flushed
|
||||
assert (
|
||||
len(compiler._current_inputset) # pylint: disable=protected-access
|
||||
== expected_remaining_inputset_len
|
||||
)
|
||||
|
||||
|
||||
def test_np_fhe_compiler_full_compilation(default_compilation_configuration):
|
||||
"""Test the case where we generate an FHE circuit."""
|
||||
|
||||
def function_to_compile(x):
|
||||
return x + 42
|
||||
|
||||
compiler = NPFHECompiler(
|
||||
function_to_compile,
|
||||
{"x": "encrypted"},
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
# For coverage
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
compiler.get_compiled_fhe_circuit()
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"Requested FHECircuit but no OPGraph was compiled. "
|
||||
"Did you forget to evaluate NPFHECompiler over an inputset?"
|
||||
)
|
||||
|
||||
for i in numpy.arange(64):
|
||||
assert numpy.array_equal(compiler(i), function_to_compile(i))
|
||||
|
||||
fhe_circuit = compiler.get_compiled_fhe_circuit()
|
||||
|
||||
for i in range(64):
|
||||
assert fhe_circuit.run(i) == function_to_compile(i)
|
||||
|
||||
Reference in New Issue
Block a user