From 788e94bfa378b7f1795c9ff72987bda575fd9260 Mon Sep 17 00:00:00 2001 From: youben11 Date: Tue, 17 Aug 2021 14:43:01 +0100 Subject: [PATCH] feat: end to end compilation and execution --- .github/workflows/continuous-integration.yaml | 3 ++ benchmarks/test_compilation_and_evaluation.py | 6 +-- examples/QuantizedLinearRegression.ipynb | 4 +- examples/QuantizedLogisticRegression.ipynb | 4 +- hdk/hnumpy/compile.py | 47 +++++++++++++++++-- tests/common/compilation/test_artifacts.py | 4 +- tests/common/mlir/test_mlir_converter.py | 4 +- tests/hnumpy/test_compile.py | 44 +++++++++++++++-- 8 files changed, 96 insertions(+), 20 deletions(-) diff --git a/.github/workflows/continuous-integration.yaml b/.github/workflows/continuous-integration.yaml index 88081520c..f1261fe36 100644 --- a/.github/workflows/continuous-integration.yaml +++ b/.github/workflows/continuous-integration.yaml @@ -60,6 +60,9 @@ jobs: - name: PyTest id: pytest if: ${{ steps.conformance.outcome == 'success' && !cancelled() }} + env: + # TODO: remove this when concrete is statically linked with compiler + LD_PRELOAD: /concrete/target/release/libconcrete_ffi.so run: | make pytest - name: Notebooks diff --git a/benchmarks/test_compilation_and_evaluation.py b/benchmarks/test_compilation_and_evaluation.py index c44e5a396..a7e45d6b3 100644 --- a/benchmarks/test_compilation_and_evaluation.py +++ b/benchmarks/test_compilation_and_evaluation.py @@ -6,7 +6,7 @@ import pytest from hdk.common.data_types.integers import SignedInteger, UnsignedInteger from hdk.common.data_types.values import EncryptedValue -from hdk.hnumpy.compile import compile_numpy_function +from hdk.hnumpy.compile import compile_numpy_function_into_op_graph @pytest.mark.parametrize( @@ -35,7 +35,7 @@ def test_compilation(benchmark, function, parameters, ranges): @benchmark def compilation(): - compile_numpy_function(function, parameters, dataset(ranges)) + compile_numpy_function_into_op_graph(function, parameters, dataset(ranges)) @pytest.mark.parametrize( @@ -72,7 +72,7 @@ def test_evaluation(benchmark, function, parameters, ranges, inputs): for prod in itertools.product(*args): yield prod - graph = compile_numpy_function(function, parameters, dataset(ranges)) + graph = compile_numpy_function_into_op_graph(function, parameters, dataset(ranges)) @benchmark def evaluation(): diff --git a/examples/QuantizedLinearRegression.ipynb b/examples/QuantizedLinearRegression.ipynb index 1b9aae5b3..52e26535c 100644 --- a/examples/QuantizedLinearRegression.ipynb +++ b/examples/QuantizedLinearRegression.ipynb @@ -623,13 +623,13 @@ "source": [ "from hdk.common.data_types.integers import Integer\n", "from hdk.common.data_types.values import EncryptedValue\n", - "from hdk.hnumpy.compile import compile_numpy_function\n", + "from hdk.hnumpy.compile import compile_numpy_function_into_op_graph\n", "\n", "dataset = []\n", "for x_i in x_q:\n", " dataset.append((int(x_i[0]),))\n", "\n", - "homomorphic_model = compile_numpy_function(\n", + "homomorphic_model = compile_numpy_function_into_op_graph(\n", " infer,\n", " {\"x_0\": EncryptedValue(Integer(input_bits, is_signed=False))},\n", " iter(dataset),\n", diff --git a/examples/QuantizedLogisticRegression.ipynb b/examples/QuantizedLogisticRegression.ipynb index 37bac1437..1b1ef297f 100644 --- a/examples/QuantizedLogisticRegression.ipynb +++ b/examples/QuantizedLogisticRegression.ipynb @@ -727,13 +727,13 @@ "source": [ "from hdk.common.data_types.integers import Integer\n", "from hdk.common.data_types.values import EncryptedValue\n", - "from hdk.hnumpy.compile import compile_numpy_function\n", + "from hdk.hnumpy.compile import compile_numpy_function_into_op_graph\n", "\n", "dataset = []\n", "for x_i in x_q:\n", " dataset.append((int(x_i[0]), int(x_i[1])))\n", " \n", - "homomorphic_model = compile_numpy_function(\n", + "homomorphic_model = compile_numpy_function_into_op_graph(\n", " infer,\n", " {\n", " \"x_0\": EncryptedValue(Integer(input_bits, is_signed=False)),\n", diff --git a/hdk/hnumpy/compile.py b/hdk/hnumpy/compile.py index 153cc076c..03f899448 100644 --- a/hdk/hnumpy/compile.py +++ b/hdk/hnumpy/compile.py @@ -2,10 +2,13 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple +from zamalang import CompilerEngine + from ..common.bounds_measurement.dataset_eval import eval_op_graph_bounds_on_dataset from ..common.common_helpers import check_op_graph_is_integer_program from ..common.compilation import CompilationArtifacts from ..common.data_types import BaseValue +from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter from ..common.mlir.utils import ( is_graph_values_compatible_with_mlir, update_bit_width_for_mlir, @@ -16,13 +19,13 @@ from ..common.representation import intermediate as ir from ..hnumpy.tracing import trace_numpy_function -def compile_numpy_function( +def compile_numpy_function_into_op_graph( function_to_trace: Callable, function_parameters: Dict[str, BaseValue], dataset: Iterator[Tuple[Any, ...]], compilation_artifacts: Optional[CompilationArtifacts] = None, ) -> OPGraph: - """Main API of hnumpy, to be able to compile an homomorphic program. + """Compile a function into an OPGraph. Args: function_to_trace (Callable): The function you want to trace @@ -35,8 +38,7 @@ def compile_numpy_function( during compilation Returns: - OPGraph: currently returns a compilable graph, but later, it will return an MLIR compatible - with the compiler, and even later, it will return the result of the compilation + OPGraph: compiled function into a graph """ # Trace op_graph = trace_numpy_function(function_to_trace, function_parameters) @@ -74,3 +76,40 @@ def compile_numpy_function( compilation_artifacts.bounds = node_bounds return op_graph + + +def compile_numpy_function( + function_to_trace: Callable, + function_parameters: Dict[str, BaseValue], + dataset: Iterator[Tuple[Any, ...]], + compilation_artifacts: Optional[CompilationArtifacts] = None, +) -> CompilerEngine: + """Main API of hnumpy, to be able to compile an homomorphic program. + + Args: + function_to_trace (Callable): The function you want to trace + function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the + function is e.g. an EncryptedValue holding a 7bits unsigned Integer + dataset (Iterator[Tuple[Any, ...]]): The dataset over which op_graph is evaluated. It + needs to be an iterator on tuples which are of the same length than the number of + parameters in the function, and in the same order than these same parameters + compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill + during compilation + + Returns: + CompilerEngine: engine to run and debug the compiled graph + """ + # Compile into an OPGraph + op_graph = compile_numpy_function_into_op_graph( + function_to_trace, function_parameters, dataset, compilation_artifacts + ) + + # Convert graph to an MLIR representation + converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) + mlir_result = converter.convert(op_graph) + + # Compile the MLIR representation + engine = CompilerEngine() + engine.compile_fhe(mlir_result) + + return engine diff --git a/tests/common/compilation/test_artifacts.py b/tests/common/compilation/test_artifacts.py index 5c0f6afc1..6fae76174 100644 --- a/tests/common/compilation/test_artifacts.py +++ b/tests/common/compilation/test_artifacts.py @@ -6,7 +6,7 @@ from pathlib import Path from hdk.common.compilation import CompilationArtifacts from hdk.common.data_types.integers import Integer from hdk.common.data_types.values import EncryptedValue -from hdk.hnumpy.compile import compile_numpy_function +from hdk.hnumpy.compile import compile_numpy_function_into_op_graph def test_artifacts_export(): @@ -16,7 +16,7 @@ def test_artifacts_export(): return x + 42 artifacts = CompilationArtifacts() - compile_numpy_function( + compile_numpy_function_into_op_graph( function, {"x": EncryptedValue(Integer(7, True))}, iter([(-2,), (-1,), (0,), (1,), (2,)]), diff --git a/tests/common/mlir/test_mlir_converter.py b/tests/common/mlir/test_mlir_converter.py index 22528a808..617dd1c52 100644 --- a/tests/common/mlir/test_mlir_converter.py +++ b/tests/common/mlir/test_mlir_converter.py @@ -10,7 +10,7 @@ from zamalang.dialects import hlfhe from hdk.common.data_types.integers import Integer from hdk.common.data_types.values import ClearValue, EncryptedValue from hdk.common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter -from hdk.hnumpy.compile import compile_numpy_function +from hdk.hnumpy.compile import compile_numpy_function_into_op_graph def add(x, y): @@ -168,7 +168,7 @@ def datagen(*args): def test_mlir_converter(func, args_dict, args_ranges): """Test the conversion to MLIR by calling the parser from the compiler""" dataset = datagen(*args_ranges) - result_graph = compile_numpy_function(func, args_dict, dataset) + result_graph = compile_numpy_function_into_op_graph(func, args_dict, dataset) converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) mlir_result = converter.convert(result_graph) # testing that this doesn't raise an error diff --git a/tests/hnumpy/test_compile.py b/tests/hnumpy/test_compile.py index e5a077980..f03b0e892 100644 --- a/tests/hnumpy/test_compile.py +++ b/tests/hnumpy/test_compile.py @@ -1,5 +1,6 @@ """Test file for hnumpy compilation functions""" import itertools +import random import numpy import pytest @@ -8,7 +9,10 @@ from hdk.common.data_types.integers import Integer from hdk.common.data_types.values import EncryptedValue from hdk.common.debugging import draw_graph, get_printable_graph from hdk.common.extensions.table import LookupTable -from hdk.hnumpy.compile import compile_numpy_function +from hdk.hnumpy.compile import ( + compile_numpy_function, + compile_numpy_function_into_op_graph, +) def no_fuse_unhandled(x, y): @@ -49,7 +53,7 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n arg_name: EncryptedValue(Integer(64, True)) for arg_name in list_of_arg_names } - op_graph = compile_numpy_function( + op_graph = compile_numpy_function_into_op_graph( function, function_parameters, data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), @@ -63,6 +67,36 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n print(f"\n{str_of_the_graph}\n") +@pytest.mark.parametrize( + "function,input_ranges,list_of_arg_names", + [ + pytest.param(lambda x: x + 42, ((0, 2),), ["x"]), + pytest.param(lambda x: x * 2, ((0, 2),), ["x"]), + pytest.param(lambda x: 8 - x, ((0, 2),), ["x"]), + pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]), + ], +) +def test_compile_and_run_function_multiple_outputs(function, input_ranges, list_of_arg_names): + """Test function compile_numpy_function for a program with multiple outputs""" + + def data_gen(args): + for prod in itertools.product(*args): + yield prod + + function_parameters = { + arg_name: EncryptedValue(Integer(64, False)) for arg_name in list_of_arg_names + } + + compiler_engine = compile_numpy_function( + function, + function_parameters, + data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), + ) + + args = [random.randint(low, high) for (low, high) in input_ranges] + compiler_engine.run(*args) + + def test_compile_function_with_direct_tlu(): """Test compile_numpy_function for a program with direct table lookup""" @@ -71,7 +105,7 @@ def test_compile_function_with_direct_tlu(): def function(x): return x + table[x] - op_graph = compile_numpy_function( + op_graph = compile_numpy_function_into_op_graph( function, {"x": EncryptedValue(Integer(2, is_signed=False))}, iter([(0,), (1,), (2,), (3,)]), @@ -90,7 +124,7 @@ def test_compile_function_with_direct_tlu_overflow(): return table[x] with pytest.raises(ValueError): - compile_numpy_function( + compile_numpy_function_into_op_graph( function, {"x": EncryptedValue(Integer(3, is_signed=False))}, iter([(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,)]), @@ -115,7 +149,7 @@ def test_fail_compile(function, input_ranges, list_of_arg_names): } with pytest.raises(TypeError, match=r"signed integers aren't supported for MLIR lowering"): - compile_numpy_function( + compile_numpy_function_into_op_graph( function, function_parameters, data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),