feat: end to end compilation and execution

This commit is contained in:
youben11
2021-08-17 14:43:01 +01:00
committed by Ayoub Benaissa
parent 4e40982f5a
commit 788e94bfa3
8 changed files with 96 additions and 20 deletions

View File

@@ -60,6 +60,9 @@ jobs:
- name: PyTest
id: pytest
if: ${{ steps.conformance.outcome == 'success' && !cancelled() }}
env:
# TODO: remove this when concrete is statically linked with compiler
LD_PRELOAD: /concrete/target/release/libconcrete_ffi.so
run: |
make pytest
- name: Notebooks

View File

@@ -6,7 +6,7 @@ import pytest
from hdk.common.data_types.integers import SignedInteger, UnsignedInteger
from hdk.common.data_types.values import EncryptedValue
from hdk.hnumpy.compile import compile_numpy_function
from hdk.hnumpy.compile import compile_numpy_function_into_op_graph
@pytest.mark.parametrize(
@@ -35,7 +35,7 @@ def test_compilation(benchmark, function, parameters, ranges):
@benchmark
def compilation():
compile_numpy_function(function, parameters, dataset(ranges))
compile_numpy_function_into_op_graph(function, parameters, dataset(ranges))
@pytest.mark.parametrize(
@@ -72,7 +72,7 @@ def test_evaluation(benchmark, function, parameters, ranges, inputs):
for prod in itertools.product(*args):
yield prod
graph = compile_numpy_function(function, parameters, dataset(ranges))
graph = compile_numpy_function_into_op_graph(function, parameters, dataset(ranges))
@benchmark
def evaluation():

View File

@@ -623,13 +623,13 @@
"source": [
"from hdk.common.data_types.integers import Integer\n",
"from hdk.common.data_types.values import EncryptedValue\n",
"from hdk.hnumpy.compile import compile_numpy_function\n",
"from hdk.hnumpy.compile import compile_numpy_function_into_op_graph\n",
"\n",
"dataset = []\n",
"for x_i in x_q:\n",
" dataset.append((int(x_i[0]),))\n",
"\n",
"homomorphic_model = compile_numpy_function(\n",
"homomorphic_model = compile_numpy_function_into_op_graph(\n",
" infer,\n",
" {\"x_0\": EncryptedValue(Integer(input_bits, is_signed=False))},\n",
" iter(dataset),\n",

View File

@@ -727,13 +727,13 @@
"source": [
"from hdk.common.data_types.integers import Integer\n",
"from hdk.common.data_types.values import EncryptedValue\n",
"from hdk.hnumpy.compile import compile_numpy_function\n",
"from hdk.hnumpy.compile import compile_numpy_function_into_op_graph\n",
"\n",
"dataset = []\n",
"for x_i in x_q:\n",
" dataset.append((int(x_i[0]), int(x_i[1])))\n",
" \n",
"homomorphic_model = compile_numpy_function(\n",
"homomorphic_model = compile_numpy_function_into_op_graph(\n",
" infer,\n",
" {\n",
" \"x_0\": EncryptedValue(Integer(input_bits, is_signed=False)),\n",

View File

@@ -2,10 +2,13 @@
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
from zamalang import CompilerEngine
from ..common.bounds_measurement.dataset_eval import eval_op_graph_bounds_on_dataset
from ..common.common_helpers import check_op_graph_is_integer_program
from ..common.compilation import CompilationArtifacts
from ..common.data_types import BaseValue
from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
from ..common.mlir.utils import (
is_graph_values_compatible_with_mlir,
update_bit_width_for_mlir,
@@ -16,13 +19,13 @@ from ..common.representation import intermediate as ir
from ..hnumpy.tracing import trace_numpy_function
def compile_numpy_function(
def compile_numpy_function_into_op_graph(
function_to_trace: Callable,
function_parameters: Dict[str, BaseValue],
dataset: Iterator[Tuple[Any, ...]],
compilation_artifacts: Optional[CompilationArtifacts] = None,
) -> OPGraph:
"""Main API of hnumpy, to be able to compile an homomorphic program.
"""Compile a function into an OPGraph.
Args:
function_to_trace (Callable): The function you want to trace
@@ -35,8 +38,7 @@ def compile_numpy_function(
during compilation
Returns:
OPGraph: currently returns a compilable graph, but later, it will return an MLIR compatible
with the compiler, and even later, it will return the result of the compilation
OPGraph: compiled function into a graph
"""
# Trace
op_graph = trace_numpy_function(function_to_trace, function_parameters)
@@ -74,3 +76,40 @@ def compile_numpy_function(
compilation_artifacts.bounds = node_bounds
return op_graph
def compile_numpy_function(
function_to_trace: Callable,
function_parameters: Dict[str, BaseValue],
dataset: Iterator[Tuple[Any, ...]],
compilation_artifacts: Optional[CompilationArtifacts] = None,
) -> CompilerEngine:
"""Main API of hnumpy, to be able to compile an homomorphic program.
Args:
function_to_trace (Callable): The function you want to trace
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
function is e.g. an EncryptedValue holding a 7bits unsigned Integer
dataset (Iterator[Tuple[Any, ...]]): The dataset over which op_graph is evaluated. It
needs to be an iterator on tuples which are of the same length than the number of
parameters in the function, and in the same order than these same parameters
compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
during compilation
Returns:
CompilerEngine: engine to run and debug the compiled graph
"""
# Compile into an OPGraph
op_graph = compile_numpy_function_into_op_graph(
function_to_trace, function_parameters, dataset, compilation_artifacts
)
# Convert graph to an MLIR representation
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
mlir_result = converter.convert(op_graph)
# Compile the MLIR representation
engine = CompilerEngine()
engine.compile_fhe(mlir_result)
return engine

View File

@@ -6,7 +6,7 @@ from pathlib import Path
from hdk.common.compilation import CompilationArtifacts
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import EncryptedValue
from hdk.hnumpy.compile import compile_numpy_function
from hdk.hnumpy.compile import compile_numpy_function_into_op_graph
def test_artifacts_export():
@@ -16,7 +16,7 @@ def test_artifacts_export():
return x + 42
artifacts = CompilationArtifacts()
compile_numpy_function(
compile_numpy_function_into_op_graph(
function,
{"x": EncryptedValue(Integer(7, True))},
iter([(-2,), (-1,), (0,), (1,), (2,)]),

View File

@@ -10,7 +10,7 @@ from zamalang.dialects import hlfhe
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import ClearValue, EncryptedValue
from hdk.common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
from hdk.hnumpy.compile import compile_numpy_function
from hdk.hnumpy.compile import compile_numpy_function_into_op_graph
def add(x, y):
@@ -168,7 +168,7 @@ def datagen(*args):
def test_mlir_converter(func, args_dict, args_ranges):
"""Test the conversion to MLIR by calling the parser from the compiler"""
dataset = datagen(*args_ranges)
result_graph = compile_numpy_function(func, args_dict, dataset)
result_graph = compile_numpy_function_into_op_graph(func, args_dict, dataset)
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
mlir_result = converter.convert(result_graph)
# testing that this doesn't raise an error

View File

@@ -1,5 +1,6 @@
"""Test file for hnumpy compilation functions"""
import itertools
import random
import numpy
import pytest
@@ -8,7 +9,10 @@ from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import EncryptedValue
from hdk.common.debugging import draw_graph, get_printable_graph
from hdk.common.extensions.table import LookupTable
from hdk.hnumpy.compile import compile_numpy_function
from hdk.hnumpy.compile import (
compile_numpy_function,
compile_numpy_function_into_op_graph,
)
def no_fuse_unhandled(x, y):
@@ -49,7 +53,7 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
arg_name: EncryptedValue(Integer(64, True)) for arg_name in list_of_arg_names
}
op_graph = compile_numpy_function(
op_graph = compile_numpy_function_into_op_graph(
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
@@ -63,6 +67,36 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
print(f"\n{str_of_the_graph}\n")
@pytest.mark.parametrize(
"function,input_ranges,list_of_arg_names",
[
pytest.param(lambda x: x + 42, ((0, 2),), ["x"]),
pytest.param(lambda x: x * 2, ((0, 2),), ["x"]),
pytest.param(lambda x: 8 - x, ((0, 2),), ["x"]),
pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]),
],
)
def test_compile_and_run_function_multiple_outputs(function, input_ranges, list_of_arg_names):
"""Test function compile_numpy_function for a program with multiple outputs"""
def data_gen(args):
for prod in itertools.product(*args):
yield prod
function_parameters = {
arg_name: EncryptedValue(Integer(64, False)) for arg_name in list_of_arg_names
}
compiler_engine = compile_numpy_function(
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
)
args = [random.randint(low, high) for (low, high) in input_ranges]
compiler_engine.run(*args)
def test_compile_function_with_direct_tlu():
"""Test compile_numpy_function for a program with direct table lookup"""
@@ -71,7 +105,7 @@ def test_compile_function_with_direct_tlu():
def function(x):
return x + table[x]
op_graph = compile_numpy_function(
op_graph = compile_numpy_function_into_op_graph(
function,
{"x": EncryptedValue(Integer(2, is_signed=False))},
iter([(0,), (1,), (2,), (3,)]),
@@ -90,7 +124,7 @@ def test_compile_function_with_direct_tlu_overflow():
return table[x]
with pytest.raises(ValueError):
compile_numpy_function(
compile_numpy_function_into_op_graph(
function,
{"x": EncryptedValue(Integer(3, is_signed=False))},
iter([(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,)]),
@@ -115,7 +149,7 @@ def test_fail_compile(function, input_ranges, list_of_arg_names):
}
with pytest.raises(TypeError, match=r"signed integers aren't supported for MLIR lowering"):
compile_numpy_function(
compile_numpy_function_into_op_graph(
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),