mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat: end to end compilation and execution
This commit is contained in:
@@ -60,6 +60,9 @@ jobs:
|
||||
- name: PyTest
|
||||
id: pytest
|
||||
if: ${{ steps.conformance.outcome == 'success' && !cancelled() }}
|
||||
env:
|
||||
# TODO: remove this when concrete is statically linked with compiler
|
||||
LD_PRELOAD: /concrete/target/release/libconcrete_ffi.so
|
||||
run: |
|
||||
make pytest
|
||||
- name: Notebooks
|
||||
|
||||
@@ -6,7 +6,7 @@ import pytest
|
||||
|
||||
from hdk.common.data_types.integers import SignedInteger, UnsignedInteger
|
||||
from hdk.common.data_types.values import EncryptedValue
|
||||
from hdk.hnumpy.compile import compile_numpy_function
|
||||
from hdk.hnumpy.compile import compile_numpy_function_into_op_graph
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -35,7 +35,7 @@ def test_compilation(benchmark, function, parameters, ranges):
|
||||
|
||||
@benchmark
|
||||
def compilation():
|
||||
compile_numpy_function(function, parameters, dataset(ranges))
|
||||
compile_numpy_function_into_op_graph(function, parameters, dataset(ranges))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -72,7 +72,7 @@ def test_evaluation(benchmark, function, parameters, ranges, inputs):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
graph = compile_numpy_function(function, parameters, dataset(ranges))
|
||||
graph = compile_numpy_function_into_op_graph(function, parameters, dataset(ranges))
|
||||
|
||||
@benchmark
|
||||
def evaluation():
|
||||
|
||||
@@ -623,13 +623,13 @@
|
||||
"source": [
|
||||
"from hdk.common.data_types.integers import Integer\n",
|
||||
"from hdk.common.data_types.values import EncryptedValue\n",
|
||||
"from hdk.hnumpy.compile import compile_numpy_function\n",
|
||||
"from hdk.hnumpy.compile import compile_numpy_function_into_op_graph\n",
|
||||
"\n",
|
||||
"dataset = []\n",
|
||||
"for x_i in x_q:\n",
|
||||
" dataset.append((int(x_i[0]),))\n",
|
||||
"\n",
|
||||
"homomorphic_model = compile_numpy_function(\n",
|
||||
"homomorphic_model = compile_numpy_function_into_op_graph(\n",
|
||||
" infer,\n",
|
||||
" {\"x_0\": EncryptedValue(Integer(input_bits, is_signed=False))},\n",
|
||||
" iter(dataset),\n",
|
||||
|
||||
@@ -727,13 +727,13 @@
|
||||
"source": [
|
||||
"from hdk.common.data_types.integers import Integer\n",
|
||||
"from hdk.common.data_types.values import EncryptedValue\n",
|
||||
"from hdk.hnumpy.compile import compile_numpy_function\n",
|
||||
"from hdk.hnumpy.compile import compile_numpy_function_into_op_graph\n",
|
||||
"\n",
|
||||
"dataset = []\n",
|
||||
"for x_i in x_q:\n",
|
||||
" dataset.append((int(x_i[0]), int(x_i[1])))\n",
|
||||
" \n",
|
||||
"homomorphic_model = compile_numpy_function(\n",
|
||||
"homomorphic_model = compile_numpy_function_into_op_graph(\n",
|
||||
" infer,\n",
|
||||
" {\n",
|
||||
" \"x_0\": EncryptedValue(Integer(input_bits, is_signed=False)),\n",
|
||||
|
||||
@@ -2,10 +2,13 @@
|
||||
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
|
||||
|
||||
from zamalang import CompilerEngine
|
||||
|
||||
from ..common.bounds_measurement.dataset_eval import eval_op_graph_bounds_on_dataset
|
||||
from ..common.common_helpers import check_op_graph_is_integer_program
|
||||
from ..common.compilation import CompilationArtifacts
|
||||
from ..common.data_types import BaseValue
|
||||
from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
|
||||
from ..common.mlir.utils import (
|
||||
is_graph_values_compatible_with_mlir,
|
||||
update_bit_width_for_mlir,
|
||||
@@ -16,13 +19,13 @@ from ..common.representation import intermediate as ir
|
||||
from ..hnumpy.tracing import trace_numpy_function
|
||||
|
||||
|
||||
def compile_numpy_function(
|
||||
def compile_numpy_function_into_op_graph(
|
||||
function_to_trace: Callable,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
dataset: Iterator[Tuple[Any, ...]],
|
||||
compilation_artifacts: Optional[CompilationArtifacts] = None,
|
||||
) -> OPGraph:
|
||||
"""Main API of hnumpy, to be able to compile an homomorphic program.
|
||||
"""Compile a function into an OPGraph.
|
||||
|
||||
Args:
|
||||
function_to_trace (Callable): The function you want to trace
|
||||
@@ -35,8 +38,7 @@ def compile_numpy_function(
|
||||
during compilation
|
||||
|
||||
Returns:
|
||||
OPGraph: currently returns a compilable graph, but later, it will return an MLIR compatible
|
||||
with the compiler, and even later, it will return the result of the compilation
|
||||
OPGraph: compiled function into a graph
|
||||
"""
|
||||
# Trace
|
||||
op_graph = trace_numpy_function(function_to_trace, function_parameters)
|
||||
@@ -74,3 +76,40 @@ def compile_numpy_function(
|
||||
compilation_artifacts.bounds = node_bounds
|
||||
|
||||
return op_graph
|
||||
|
||||
|
||||
def compile_numpy_function(
|
||||
function_to_trace: Callable,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
dataset: Iterator[Tuple[Any, ...]],
|
||||
compilation_artifacts: Optional[CompilationArtifacts] = None,
|
||||
) -> CompilerEngine:
|
||||
"""Main API of hnumpy, to be able to compile an homomorphic program.
|
||||
|
||||
Args:
|
||||
function_to_trace (Callable): The function you want to trace
|
||||
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
|
||||
function is e.g. an EncryptedValue holding a 7bits unsigned Integer
|
||||
dataset (Iterator[Tuple[Any, ...]]): The dataset over which op_graph is evaluated. It
|
||||
needs to be an iterator on tuples which are of the same length than the number of
|
||||
parameters in the function, and in the same order than these same parameters
|
||||
compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
|
||||
during compilation
|
||||
|
||||
Returns:
|
||||
CompilerEngine: engine to run and debug the compiled graph
|
||||
"""
|
||||
# Compile into an OPGraph
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
function_to_trace, function_parameters, dataset, compilation_artifacts
|
||||
)
|
||||
|
||||
# Convert graph to an MLIR representation
|
||||
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
mlir_result = converter.convert(op_graph)
|
||||
|
||||
# Compile the MLIR representation
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(mlir_result)
|
||||
|
||||
return engine
|
||||
|
||||
@@ -6,7 +6,7 @@ from pathlib import Path
|
||||
from hdk.common.compilation import CompilationArtifacts
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import EncryptedValue
|
||||
from hdk.hnumpy.compile import compile_numpy_function
|
||||
from hdk.hnumpy.compile import compile_numpy_function_into_op_graph
|
||||
|
||||
|
||||
def test_artifacts_export():
|
||||
@@ -16,7 +16,7 @@ def test_artifacts_export():
|
||||
return x + 42
|
||||
|
||||
artifacts = CompilationArtifacts()
|
||||
compile_numpy_function(
|
||||
compile_numpy_function_into_op_graph(
|
||||
function,
|
||||
{"x": EncryptedValue(Integer(7, True))},
|
||||
iter([(-2,), (-1,), (0,), (1,), (2,)]),
|
||||
|
||||
@@ -10,7 +10,7 @@ from zamalang.dialects import hlfhe
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import ClearValue, EncryptedValue
|
||||
from hdk.common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
|
||||
from hdk.hnumpy.compile import compile_numpy_function
|
||||
from hdk.hnumpy.compile import compile_numpy_function_into_op_graph
|
||||
|
||||
|
||||
def add(x, y):
|
||||
@@ -168,7 +168,7 @@ def datagen(*args):
|
||||
def test_mlir_converter(func, args_dict, args_ranges):
|
||||
"""Test the conversion to MLIR by calling the parser from the compiler"""
|
||||
dataset = datagen(*args_ranges)
|
||||
result_graph = compile_numpy_function(func, args_dict, dataset)
|
||||
result_graph = compile_numpy_function_into_op_graph(func, args_dict, dataset)
|
||||
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
mlir_result = converter.convert(result_graph)
|
||||
# testing that this doesn't raise an error
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Test file for hnumpy compilation functions"""
|
||||
import itertools
|
||||
import random
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
@@ -8,7 +9,10 @@ from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import EncryptedValue
|
||||
from hdk.common.debugging import draw_graph, get_printable_graph
|
||||
from hdk.common.extensions.table import LookupTable
|
||||
from hdk.hnumpy.compile import compile_numpy_function
|
||||
from hdk.hnumpy.compile import (
|
||||
compile_numpy_function,
|
||||
compile_numpy_function_into_op_graph,
|
||||
)
|
||||
|
||||
|
||||
def no_fuse_unhandled(x, y):
|
||||
@@ -49,7 +53,7 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
|
||||
arg_name: EncryptedValue(Integer(64, True)) for arg_name in list_of_arg_names
|
||||
}
|
||||
|
||||
op_graph = compile_numpy_function(
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
@@ -63,6 +67,36 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
|
||||
print(f"\n{str_of_the_graph}\n")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
[
|
||||
pytest.param(lambda x: x + 42, ((0, 2),), ["x"]),
|
||||
pytest.param(lambda x: x * 2, ((0, 2),), ["x"]),
|
||||
pytest.param(lambda x: 8 - x, ((0, 2),), ["x"]),
|
||||
pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_function_multiple_outputs(function, input_ranges, list_of_arg_names):
|
||||
"""Test function compile_numpy_function for a program with multiple outputs"""
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {
|
||||
arg_name: EncryptedValue(Integer(64, False)) for arg_name in list_of_arg_names
|
||||
}
|
||||
|
||||
compiler_engine = compile_numpy_function(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
)
|
||||
|
||||
args = [random.randint(low, high) for (low, high) in input_ranges]
|
||||
compiler_engine.run(*args)
|
||||
|
||||
|
||||
def test_compile_function_with_direct_tlu():
|
||||
"""Test compile_numpy_function for a program with direct table lookup"""
|
||||
|
||||
@@ -71,7 +105,7 @@ def test_compile_function_with_direct_tlu():
|
||||
def function(x):
|
||||
return x + table[x]
|
||||
|
||||
op_graph = compile_numpy_function(
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
function,
|
||||
{"x": EncryptedValue(Integer(2, is_signed=False))},
|
||||
iter([(0,), (1,), (2,), (3,)]),
|
||||
@@ -90,7 +124,7 @@ def test_compile_function_with_direct_tlu_overflow():
|
||||
return table[x]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
compile_numpy_function(
|
||||
compile_numpy_function_into_op_graph(
|
||||
function,
|
||||
{"x": EncryptedValue(Integer(3, is_signed=False))},
|
||||
iter([(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,)]),
|
||||
@@ -115,7 +149,7 @@ def test_fail_compile(function, input_ranges, list_of_arg_names):
|
||||
}
|
||||
|
||||
with pytest.raises(TypeError, match=r"signed integers aren't supported for MLIR lowering"):
|
||||
compile_numpy_function(
|
||||
compile_numpy_function_into_op_graph(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
|
||||
Reference in New Issue
Block a user