feat(python): Add compilation feedback to the python bindings

This commit is contained in:
Quentin Bourgerie
2022-09-09 23:18:48 +02:00
parent f4673e8276
commit dbfde466bc
10 changed files with 160 additions and 26 deletions

View File

@@ -47,6 +47,10 @@ MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
jit_load_client_parameters(JITSupport_C support,
mlir::concretelang::JitCompilationResult &);
MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback
jit_load_compilation_feedback(JITSupport_C support,
mlir::concretelang::JitCompilationResult &);
MLIR_CAPI_EXPORTED std::shared_ptr<mlir::concretelang::JITLambda>
jit_load_server_lambda(JITSupport_C support,
mlir::concretelang::JitCompilationResult &);
@@ -66,7 +70,8 @@ typedef struct LibrarySupport_C LibrarySupport_C;
MLIR_CAPI_EXPORTED LibrarySupport_C
library_support(const char *outputPath, const char *runtimeLibraryPath,
bool generateSharedLib, bool generateStaticLib,
bool generateClientParameters, bool generateCppHeader);
bool generateClientParameters, bool generateCompilationFeedback,
bool generateCppHeader);
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
library_compile(LibrarySupport_C support, const char *module,
@@ -76,6 +81,10 @@ MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
library_load_client_parameters(LibrarySupport_C support,
mlir::concretelang::LibraryCompilationResult &);
MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback
library_load_compilation_feedback(
LibrarySupport_C support, mlir::concretelang::LibraryCompilationResult &);
MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda
library_load_server_lambda(LibrarySupport_C support,
mlir::concretelang::LibraryCompilationResult &);

View File

@@ -37,11 +37,13 @@ public:
LibrarySupport(std::string outputPath, std::string runtimeLibraryPath = "",
bool generateSharedLib = true, bool generateStaticLib = true,
bool generateClientParameters = true,
bool generateCompilationFeedback = true,
bool generateCppHeader = true)
: outputPath(outputPath), runtimeLibraryPath(runtimeLibraryPath),
generateSharedLib(generateSharedLib),
generateStaticLib(generateStaticLib),
generateClientParameters(generateClientParameters),
generateCompilationFeedback(generateCompilationFeedback),
generateCppHeader(generateCppHeader) {}
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
@@ -52,9 +54,10 @@ public:
engine.setCompilationOptions(options);
// Compile to a library
auto library = engine.compile(program, outputPath, runtimeLibraryPath,
generateSharedLib, generateStaticLib,
generateClientParameters, generateCppHeader);
auto library = engine.compile(
program, outputPath, runtimeLibraryPath, generateSharedLib,
generateStaticLib, generateClientParameters,
generateCompilationFeedback, generateCppHeader);
if (auto err = library.takeError()) {
return std::move(err);
}
@@ -136,6 +139,7 @@ private:
bool generateSharedLib;
bool generateStaticLib;
bool generateClientParameters;
bool generateCompilationFeedback;
bool generateCppHeader;
};

View File

@@ -29,6 +29,7 @@ declare_mlir_python_sources(ConcretelangBindingsPythonSources
concrete/compiler/__init__.py
concrete/compiler/client_parameters.py
concrete/compiler/client_support.py
concrete/compiler/compilation_feedback.py
concrete/compiler/compilation_options.py
concrete/compiler/jit_compilation_result.py
concrete/compiler/jit_support.py

View File

@@ -73,7 +73,22 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
});
pybind11::class_<mlir::concretelang::CompilationFeedback>(
m, "CompilationFeedback");
m, "CompilationFeedback")
.def_readonly("complexity",
&mlir::concretelang::CompilationFeedback::complexity)
.def_readonly(
"total_secret_keys_size",
&mlir::concretelang::CompilationFeedback::totalSecretKeysSize)
.def_readonly(
"total_bootstrap_keys_size",
&mlir::concretelang::CompilationFeedback::totalBootstrapKeysSize)
.def_readonly(
"total_keyswitch_keys_size",
&mlir::concretelang::CompilationFeedback::totalKeyswitchKeysSize)
.def_readonly("total_inputs_size",
&mlir::concretelang::CompilationFeedback::totalInputsSize)
.def_readonly("total_output_size",
&mlir::concretelang::CompilationFeedback::totalOutputsSize);
pybind11::class_<mlir::concretelang::JitCompilationResult>(
m, "JITCompilationResult");
@@ -97,7 +112,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.def("load_compilation_feedback",
[](JITSupport_C &support,
mlir::concretelang::JitCompilationResult &result) {
return jit_load_client_parameters(support, result);
return jit_load_compilation_feedback(support, result);
})
.def(
"load_server_lambda",
@@ -127,11 +142,12 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.def(pybind11::init(
[](std::string outputPath, std::string runtimeLibraryPath,
bool generateSharedLib, bool generateStaticLib,
bool generateClientParameters, bool generateCppHeader) {
return library_support(outputPath.c_str(),
runtimeLibraryPath.c_str(),
generateSharedLib, generateStaticLib,
generateClientParameters, generateCppHeader);
bool generateClientParameters, bool generateCompilationFeedback,
bool generateCppHeader) {
return library_support(
outputPath.c_str(), runtimeLibraryPath.c_str(),
generateSharedLib, generateStaticLib, generateClientParameters,
generateCompilationFeedback, generateCppHeader);
}))
.def("compile",
[](LibrarySupport_C &support, std::string mlir_program,

View File

@@ -15,6 +15,7 @@ from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip
from .compilation_options import CompilationOptions
from .key_set_cache import KeySetCache
from .client_parameters import ClientParameters
from .compilation_feedback import CompilationFeedback
from .key_set import KeySet
from .public_result import PublicResult
from .public_arguments import PublicArguments

View File

@@ -0,0 +1,40 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information.
"""Compilation feedback."""
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
CompilationFeedback as _CompilationFeedback,
)
# pylint: enable=no-name-in-module,import-error
from .wrapper import WrapperCpp
class CompilationFeedback(WrapperCpp):
"""CompilationFeedback is a set of hint computed by the compiler engine."""
def __init__(self, compilation_feedback: _CompilationFeedback):
"""Wrap the native Cpp object.
Args:
compilation_feeback (_CompilationFeedback): object to wrap
Raises:
TypeError: if compilation_feedback is not of type _CompilationFeedback
"""
if not isinstance(compilation_feedback, _CompilationFeedback):
raise TypeError(
f"compilation_feedback must be of type _CompilationFeedback, not {type(compilation_feedback)}"
)
self.complexity = compilation_feedback.complexity
self.total_secret_keys_size = compilation_feedback.total_secret_keys_size
self.total_bootstrap_keys_size = compilation_feedback.total_bootstrap_keys_size
self.total_keyswitch_keys_size = compilation_feedback.total_keyswitch_keys_size
self.total_inputs_size = compilation_feedback.total_inputs_size
self.total_output_size = compilation_feedback.total_output_size
super().__init__(compilation_feedback)

View File

@@ -19,6 +19,7 @@ from .utils import lookup_runtime_lib
from .compilation_options import CompilationOptions
from .jit_compilation_result import JITCompilationResult
from .client_parameters import ClientParameters
from .compilation_feedback import CompilationFeedback
from .jit_lambda import JITLambda
from .public_arguments import PublicArguments
from .public_result import PublicResult
@@ -121,6 +122,28 @@ class JITSupport(WrapperCpp):
self.cpp().load_client_parameters(compilation_result.cpp())
)
def load_compilation_feedback(
self, compilation_result: JITCompilationResult
) -> CompilationFeedback:
"""Load the compilation feedback from the JIT compilation result.
Args:
compilation_result (JITCompilationResult): result of the JIT compilation
Raises:
TypeError: if compilation_result is not of type JITCompilationResult
Returns:
CompilationFeedback: the compilation feedback for the compiled program
"""
if not isinstance(compilation_result, JITCompilationResult):
raise TypeError(
f"compilation_result must be of type JITCompilationResult, not {type(compilation_result)}"
)
return CompilationFeedback.wrap(
self.cpp().load_compilation_feedback(compilation_result.cpp())
)
def load_server_lambda(self, compilation_result: JITCompilationResult) -> JITLambda:
"""Load the JITLambda from the JIT compilation result.

View File

@@ -21,6 +21,7 @@ from .public_arguments import PublicArguments
from .library_lambda import LibraryLambda
from .public_result import PublicResult
from .client_parameters import ClientParameters
from .compilation_feedback import CompilationFeedback
from .wrapper import WrapperCpp
from .utils import lookup_runtime_lib
from .evaluation_keys import EvaluationKeys
@@ -71,6 +72,7 @@ class LibrarySupport(WrapperCpp):
generateSharedLib: bool = True,
generateStaticLib: bool = False,
generateClientParameters: bool = True,
generateCompilationFeedback: bool = True,
generateCppHeader: bool = False,
) -> "LibrarySupport":
"""Build a LibrarySupport.
@@ -104,6 +106,7 @@ class LibrarySupport(WrapperCpp):
("generateSharedLib", generateSharedLib),
("generateStaticLib", generateStaticLib),
("generateClientParameters", generateClientParameters),
("generateCompilationFeedback", generateCompilationFeedback),
("generateCppHeader", generateCppHeader),
]:
if not isinstance(value, bool):
@@ -115,6 +118,7 @@ class LibrarySupport(WrapperCpp):
generateSharedLib,
generateStaticLib,
generateClientParameters,
generateCompilationFeedback,
generateCppHeader,
)
)
@@ -188,6 +192,28 @@ class LibrarySupport(WrapperCpp):
self.cpp().load_client_parameters(library_compilation_result.cpp())
)
def load_compilation_feedback(
self, compilation_result: LibraryCompilationResult
) -> CompilationFeedback:
"""Load the compilation feedback from the JIT compilation result.
Args:
compilation_result (JITCompilationResult): result of the JIT compilation
Raises:
TypeError: if compilation_result is not of type JITCompilationResult
Returns:
CompilationFeedback: the compilation feedback for the compiled program
"""
if not isinstance(compilation_result, LibraryCompilationResult):
raise TypeError(
f"compilation_result must be of type JITCompilationResult, not {type(compilation_result)}"
)
return CompilationFeedback.wrap(
self.cpp().load_compilation_feedback(compilation_result.cpp())
)
def load_server_lambda(
self, library_compilation_result: LibraryCompilationResult
) -> LibraryLambda:

View File

@@ -44,6 +44,14 @@ jit_load_client_parameters(JITSupport_C support,
return *clientParameters;
}
MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback
jit_load_compilation_feedback(
JITSupport_C support, mlir::concretelang::JitCompilationResult &result) {
GET_OR_THROW_LLVM_EXPECTED(compilationFeedback,
support.support.loadCompilationFeedback(result));
return *compilationFeedback;
}
MLIR_CAPI_EXPORTED std::shared_ptr<mlir::concretelang::JITLambda>
jit_load_server_lambda(JITSupport_C support,
mlir::concretelang::JitCompilationResult &result) {
@@ -64,10 +72,12 @@ jit_server_call(JITSupport_C support, mlir::concretelang::JITLambda &lambda,
MLIR_CAPI_EXPORTED LibrarySupport_C
library_support(const char *outputPath, const char *runtimeLibraryPath,
bool generateSharedLib, bool generateStaticLib,
bool generateClientParameters, bool generateCppHeader) {
bool generateClientParameters, bool generateCompilationFeedback,
bool generateCppHeader) {
return LibrarySupport_C{mlir::concretelang::LibrarySupport(
outputPath, runtimeLibraryPath, generateSharedLib, generateStaticLib,
generateClientParameters, generateCppHeader)};
generateClientParameters, generateCompilationFeedback,
generateCppHeader)};
}
std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
@@ -87,6 +97,15 @@ library_load_client_parameters(
return *clientParameters;
}
MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback
library_load_compilation_feedback(
LibrarySupport_C support,
mlir::concretelang::LibraryCompilationResult &result) {
GET_OR_THROW_LLVM_EXPECTED(compilationFeedback,
support.support.loadCompilationFeedback(result));
return *compilationFeedback;
}
MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda
library_load_server_lambda(
LibrarySupport_C support,

View File

@@ -7,6 +7,7 @@ from concrete.compiler import (
LibrarySupport,
ClientSupport,
CompilationOptions,
CompilationFeedback,
)
@@ -27,19 +28,16 @@ def run(engine, args, compilation_result, keyset_cache):
Perform required loading, encryption, execution, and decryption."""
# Dev
compilation_feedback = engine.load_compilation_feedback(
compilation_result)
assert(compilation_feedback is not None)
compilation_feedback = engine.load_compilation_feedback(compilation_result)
assert isinstance(compilation_feedback, CompilationFeedback)
# Client
client_parameters = engine.load_client_parameters(compilation_result)
key_set = ClientSupport.key_set(client_parameters, keyset_cache)
public_arguments = ClientSupport.encrypt_arguments(
client_parameters, key_set, args)
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
# Server
server_lambda = engine.load_server_lambda(compilation_result)
evaluation_keys = key_set.get_evaluation_keys()
public_result = engine.server_call(
server_lambda, public_arguments, evaluation_keys)
public_result = engine.server_call(server_lambda, public_arguments, evaluation_keys)
# Client
result = ClientSupport.decrypt_result(key_set, public_result)
return result
@@ -141,10 +139,8 @@ end_to_end_parallel_fixture = [
}
""",
(
np.array([[1, 2, 3, 4], [4, 2, 1, 0], [
2, 3, 1, 5]], dtype=np.uint8),
np.array([[1, 2, 3, 4], [4, 2, 1, 1], [
2, 3, 1, 5]], dtype=np.uint8),
np.array([[1, 2, 3, 4], [4, 2, 1, 0], [2, 3, 1, 5]], dtype=np.uint8),
np.array([[1, 2, 3, 4], [4, 2, 1, 1], [2, 3, 1, 5]], dtype=np.uint8),
),
np.array([[52, 36], [31, 34], [42, 52]]),
id="matmul_eint_int_uint8",
@@ -228,8 +224,7 @@ def test_lib_compile_and_run_p_error(keyset_cache):
options = CompilationOptions.new("main")
options.set_p_error(0.00001)
options.set_display_optimizer_choice(True)
compile_run_assert(engine, mlir_input, args,
expected_result, keyset_cache, options)
compile_run_assert(engine, mlir_input, args, expected_result, keyset_cache, options)
def test_lib_compile_and_run_p_error(keyset_cache):