mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(python): Add compilation feedback to the python bindings
This commit is contained in:
@@ -47,6 +47,10 @@ MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
jit_load_client_parameters(JITSupport_C support,
|
||||
mlir::concretelang::JitCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback
|
||||
jit_load_compilation_feedback(JITSupport_C support,
|
||||
mlir::concretelang::JitCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::shared_ptr<mlir::concretelang::JITLambda>
|
||||
jit_load_server_lambda(JITSupport_C support,
|
||||
mlir::concretelang::JitCompilationResult &);
|
||||
@@ -66,7 +70,8 @@ typedef struct LibrarySupport_C LibrarySupport_C;
|
||||
MLIR_CAPI_EXPORTED LibrarySupport_C
|
||||
library_support(const char *outputPath, const char *runtimeLibraryPath,
|
||||
bool generateSharedLib, bool generateStaticLib,
|
||||
bool generateClientParameters, bool generateCppHeader);
|
||||
bool generateClientParameters, bool generateCompilationFeedback,
|
||||
bool generateCppHeader);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
|
||||
library_compile(LibrarySupport_C support, const char *module,
|
||||
@@ -76,6 +81,10 @@ MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
library_load_client_parameters(LibrarySupport_C support,
|
||||
mlir::concretelang::LibraryCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback
|
||||
library_load_compilation_feedback(
|
||||
LibrarySupport_C support, mlir::concretelang::LibraryCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda
|
||||
library_load_server_lambda(LibrarySupport_C support,
|
||||
mlir::concretelang::LibraryCompilationResult &);
|
||||
|
||||
@@ -37,11 +37,13 @@ public:
|
||||
LibrarySupport(std::string outputPath, std::string runtimeLibraryPath = "",
|
||||
bool generateSharedLib = true, bool generateStaticLib = true,
|
||||
bool generateClientParameters = true,
|
||||
bool generateCompilationFeedback = true,
|
||||
bool generateCppHeader = true)
|
||||
: outputPath(outputPath), runtimeLibraryPath(runtimeLibraryPath),
|
||||
generateSharedLib(generateSharedLib),
|
||||
generateStaticLib(generateStaticLib),
|
||||
generateClientParameters(generateClientParameters),
|
||||
generateCompilationFeedback(generateCompilationFeedback),
|
||||
generateCppHeader(generateCppHeader) {}
|
||||
|
||||
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
|
||||
@@ -52,9 +54,10 @@ public:
|
||||
engine.setCompilationOptions(options);
|
||||
|
||||
// Compile to a library
|
||||
auto library = engine.compile(program, outputPath, runtimeLibraryPath,
|
||||
generateSharedLib, generateStaticLib,
|
||||
generateClientParameters, generateCppHeader);
|
||||
auto library = engine.compile(
|
||||
program, outputPath, runtimeLibraryPath, generateSharedLib,
|
||||
generateStaticLib, generateClientParameters,
|
||||
generateCompilationFeedback, generateCppHeader);
|
||||
if (auto err = library.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
@@ -136,6 +139,7 @@ private:
|
||||
bool generateSharedLib;
|
||||
bool generateStaticLib;
|
||||
bool generateClientParameters;
|
||||
bool generateCompilationFeedback;
|
||||
bool generateCppHeader;
|
||||
};
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ declare_mlir_python_sources(ConcretelangBindingsPythonSources
|
||||
concrete/compiler/__init__.py
|
||||
concrete/compiler/client_parameters.py
|
||||
concrete/compiler/client_support.py
|
||||
concrete/compiler/compilation_feedback.py
|
||||
concrete/compiler/compilation_options.py
|
||||
concrete/compiler/jit_compilation_result.py
|
||||
concrete/compiler/jit_support.py
|
||||
|
||||
@@ -73,7 +73,22 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
});
|
||||
|
||||
pybind11::class_<mlir::concretelang::CompilationFeedback>(
|
||||
m, "CompilationFeedback");
|
||||
m, "CompilationFeedback")
|
||||
.def_readonly("complexity",
|
||||
&mlir::concretelang::CompilationFeedback::complexity)
|
||||
.def_readonly(
|
||||
"total_secret_keys_size",
|
||||
&mlir::concretelang::CompilationFeedback::totalSecretKeysSize)
|
||||
.def_readonly(
|
||||
"total_bootstrap_keys_size",
|
||||
&mlir::concretelang::CompilationFeedback::totalBootstrapKeysSize)
|
||||
.def_readonly(
|
||||
"total_keyswitch_keys_size",
|
||||
&mlir::concretelang::CompilationFeedback::totalKeyswitchKeysSize)
|
||||
.def_readonly("total_inputs_size",
|
||||
&mlir::concretelang::CompilationFeedback::totalInputsSize)
|
||||
.def_readonly("total_output_size",
|
||||
&mlir::concretelang::CompilationFeedback::totalOutputsSize);
|
||||
|
||||
pybind11::class_<mlir::concretelang::JitCompilationResult>(
|
||||
m, "JITCompilationResult");
|
||||
@@ -97,7 +112,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
.def("load_compilation_feedback",
|
||||
[](JITSupport_C &support,
|
||||
mlir::concretelang::JitCompilationResult &result) {
|
||||
return jit_load_client_parameters(support, result);
|
||||
return jit_load_compilation_feedback(support, result);
|
||||
})
|
||||
.def(
|
||||
"load_server_lambda",
|
||||
@@ -127,11 +142,12 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
.def(pybind11::init(
|
||||
[](std::string outputPath, std::string runtimeLibraryPath,
|
||||
bool generateSharedLib, bool generateStaticLib,
|
||||
bool generateClientParameters, bool generateCppHeader) {
|
||||
return library_support(outputPath.c_str(),
|
||||
runtimeLibraryPath.c_str(),
|
||||
generateSharedLib, generateStaticLib,
|
||||
generateClientParameters, generateCppHeader);
|
||||
bool generateClientParameters, bool generateCompilationFeedback,
|
||||
bool generateCppHeader) {
|
||||
return library_support(
|
||||
outputPath.c_str(), runtimeLibraryPath.c_str(),
|
||||
generateSharedLib, generateStaticLib, generateClientParameters,
|
||||
generateCompilationFeedback, generateCppHeader);
|
||||
}))
|
||||
.def("compile",
|
||||
[](LibrarySupport_C &support, std::string mlir_program,
|
||||
|
||||
@@ -15,6 +15,7 @@ from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip
|
||||
from .compilation_options import CompilationOptions
|
||||
from .key_set_cache import KeySetCache
|
||||
from .client_parameters import ClientParameters
|
||||
from .compilation_feedback import CompilationFeedback
|
||||
from .key_set import KeySet
|
||||
from .public_result import PublicResult
|
||||
from .public_arguments import PublicArguments
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information.
|
||||
|
||||
"""Compilation feedback."""
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
CompilationFeedback as _CompilationFeedback,
|
||||
)
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
|
||||
class CompilationFeedback(WrapperCpp):
|
||||
"""CompilationFeedback is a set of hint computed by the compiler engine."""
|
||||
|
||||
def __init__(self, compilation_feedback: _CompilationFeedback):
|
||||
"""Wrap the native Cpp object.
|
||||
|
||||
Args:
|
||||
compilation_feeback (_CompilationFeedback): object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError: if compilation_feedback is not of type _CompilationFeedback
|
||||
"""
|
||||
if not isinstance(compilation_feedback, _CompilationFeedback):
|
||||
raise TypeError(
|
||||
f"compilation_feedback must be of type _CompilationFeedback, not {type(compilation_feedback)}"
|
||||
)
|
||||
|
||||
self.complexity = compilation_feedback.complexity
|
||||
self.total_secret_keys_size = compilation_feedback.total_secret_keys_size
|
||||
self.total_bootstrap_keys_size = compilation_feedback.total_bootstrap_keys_size
|
||||
self.total_keyswitch_keys_size = compilation_feedback.total_keyswitch_keys_size
|
||||
self.total_inputs_size = compilation_feedback.total_inputs_size
|
||||
self.total_output_size = compilation_feedback.total_output_size
|
||||
|
||||
super().__init__(compilation_feedback)
|
||||
@@ -19,6 +19,7 @@ from .utils import lookup_runtime_lib
|
||||
from .compilation_options import CompilationOptions
|
||||
from .jit_compilation_result import JITCompilationResult
|
||||
from .client_parameters import ClientParameters
|
||||
from .compilation_feedback import CompilationFeedback
|
||||
from .jit_lambda import JITLambda
|
||||
from .public_arguments import PublicArguments
|
||||
from .public_result import PublicResult
|
||||
@@ -121,6 +122,28 @@ class JITSupport(WrapperCpp):
|
||||
self.cpp().load_client_parameters(compilation_result.cpp())
|
||||
)
|
||||
|
||||
def load_compilation_feedback(
|
||||
self, compilation_result: JITCompilationResult
|
||||
) -> CompilationFeedback:
|
||||
"""Load the compilation feedback from the JIT compilation result.
|
||||
|
||||
Args:
|
||||
compilation_result (JITCompilationResult): result of the JIT compilation
|
||||
|
||||
Raises:
|
||||
TypeError: if compilation_result is not of type JITCompilationResult
|
||||
|
||||
Returns:
|
||||
CompilationFeedback: the compilation feedback for the compiled program
|
||||
"""
|
||||
if not isinstance(compilation_result, JITCompilationResult):
|
||||
raise TypeError(
|
||||
f"compilation_result must be of type JITCompilationResult, not {type(compilation_result)}"
|
||||
)
|
||||
return CompilationFeedback.wrap(
|
||||
self.cpp().load_compilation_feedback(compilation_result.cpp())
|
||||
)
|
||||
|
||||
def load_server_lambda(self, compilation_result: JITCompilationResult) -> JITLambda:
|
||||
"""Load the JITLambda from the JIT compilation result.
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from .public_arguments import PublicArguments
|
||||
from .library_lambda import LibraryLambda
|
||||
from .public_result import PublicResult
|
||||
from .client_parameters import ClientParameters
|
||||
from .compilation_feedback import CompilationFeedback
|
||||
from .wrapper import WrapperCpp
|
||||
from .utils import lookup_runtime_lib
|
||||
from .evaluation_keys import EvaluationKeys
|
||||
@@ -71,6 +72,7 @@ class LibrarySupport(WrapperCpp):
|
||||
generateSharedLib: bool = True,
|
||||
generateStaticLib: bool = False,
|
||||
generateClientParameters: bool = True,
|
||||
generateCompilationFeedback: bool = True,
|
||||
generateCppHeader: bool = False,
|
||||
) -> "LibrarySupport":
|
||||
"""Build a LibrarySupport.
|
||||
@@ -104,6 +106,7 @@ class LibrarySupport(WrapperCpp):
|
||||
("generateSharedLib", generateSharedLib),
|
||||
("generateStaticLib", generateStaticLib),
|
||||
("generateClientParameters", generateClientParameters),
|
||||
("generateCompilationFeedback", generateCompilationFeedback),
|
||||
("generateCppHeader", generateCppHeader),
|
||||
]:
|
||||
if not isinstance(value, bool):
|
||||
@@ -115,6 +118,7 @@ class LibrarySupport(WrapperCpp):
|
||||
generateSharedLib,
|
||||
generateStaticLib,
|
||||
generateClientParameters,
|
||||
generateCompilationFeedback,
|
||||
generateCppHeader,
|
||||
)
|
||||
)
|
||||
@@ -188,6 +192,28 @@ class LibrarySupport(WrapperCpp):
|
||||
self.cpp().load_client_parameters(library_compilation_result.cpp())
|
||||
)
|
||||
|
||||
def load_compilation_feedback(
|
||||
self, compilation_result: LibraryCompilationResult
|
||||
) -> CompilationFeedback:
|
||||
"""Load the compilation feedback from the JIT compilation result.
|
||||
|
||||
Args:
|
||||
compilation_result (JITCompilationResult): result of the JIT compilation
|
||||
|
||||
Raises:
|
||||
TypeError: if compilation_result is not of type JITCompilationResult
|
||||
|
||||
Returns:
|
||||
CompilationFeedback: the compilation feedback for the compiled program
|
||||
"""
|
||||
if not isinstance(compilation_result, LibraryCompilationResult):
|
||||
raise TypeError(
|
||||
f"compilation_result must be of type JITCompilationResult, not {type(compilation_result)}"
|
||||
)
|
||||
return CompilationFeedback.wrap(
|
||||
self.cpp().load_compilation_feedback(compilation_result.cpp())
|
||||
)
|
||||
|
||||
def load_server_lambda(
|
||||
self, library_compilation_result: LibraryCompilationResult
|
||||
) -> LibraryLambda:
|
||||
|
||||
@@ -44,6 +44,14 @@ jit_load_client_parameters(JITSupport_C support,
|
||||
return *clientParameters;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback
|
||||
jit_load_compilation_feedback(
|
||||
JITSupport_C support, mlir::concretelang::JitCompilationResult &result) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(compilationFeedback,
|
||||
support.support.loadCompilationFeedback(result));
|
||||
return *compilationFeedback;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::shared_ptr<mlir::concretelang::JITLambda>
|
||||
jit_load_server_lambda(JITSupport_C support,
|
||||
mlir::concretelang::JitCompilationResult &result) {
|
||||
@@ -64,10 +72,12 @@ jit_server_call(JITSupport_C support, mlir::concretelang::JITLambda &lambda,
|
||||
MLIR_CAPI_EXPORTED LibrarySupport_C
|
||||
library_support(const char *outputPath, const char *runtimeLibraryPath,
|
||||
bool generateSharedLib, bool generateStaticLib,
|
||||
bool generateClientParameters, bool generateCppHeader) {
|
||||
bool generateClientParameters, bool generateCompilationFeedback,
|
||||
bool generateCppHeader) {
|
||||
return LibrarySupport_C{mlir::concretelang::LibrarySupport(
|
||||
outputPath, runtimeLibraryPath, generateSharedLib, generateStaticLib,
|
||||
generateClientParameters, generateCppHeader)};
|
||||
generateClientParameters, generateCompilationFeedback,
|
||||
generateCppHeader)};
|
||||
}
|
||||
|
||||
std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
|
||||
@@ -87,6 +97,15 @@ library_load_client_parameters(
|
||||
return *clientParameters;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback
|
||||
library_load_compilation_feedback(
|
||||
LibrarySupport_C support,
|
||||
mlir::concretelang::LibraryCompilationResult &result) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(compilationFeedback,
|
||||
support.support.loadCompilationFeedback(result));
|
||||
return *compilationFeedback;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda
|
||||
library_load_server_lambda(
|
||||
LibrarySupport_C support,
|
||||
|
||||
@@ -7,6 +7,7 @@ from concrete.compiler import (
|
||||
LibrarySupport,
|
||||
ClientSupport,
|
||||
CompilationOptions,
|
||||
CompilationFeedback,
|
||||
)
|
||||
|
||||
|
||||
@@ -27,19 +28,16 @@ def run(engine, args, compilation_result, keyset_cache):
|
||||
|
||||
Perform required loading, encryption, execution, and decryption."""
|
||||
# Dev
|
||||
compilation_feedback = engine.load_compilation_feedback(
|
||||
compilation_result)
|
||||
assert(compilation_feedback is not None)
|
||||
compilation_feedback = engine.load_compilation_feedback(compilation_result)
|
||||
assert isinstance(compilation_feedback, CompilationFeedback)
|
||||
# Client
|
||||
client_parameters = engine.load_client_parameters(compilation_result)
|
||||
key_set = ClientSupport.key_set(client_parameters, keyset_cache)
|
||||
public_arguments = ClientSupport.encrypt_arguments(
|
||||
client_parameters, key_set, args)
|
||||
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
|
||||
# Server
|
||||
server_lambda = engine.load_server_lambda(compilation_result)
|
||||
evaluation_keys = key_set.get_evaluation_keys()
|
||||
public_result = engine.server_call(
|
||||
server_lambda, public_arguments, evaluation_keys)
|
||||
public_result = engine.server_call(server_lambda, public_arguments, evaluation_keys)
|
||||
# Client
|
||||
result = ClientSupport.decrypt_result(key_set, public_result)
|
||||
return result
|
||||
@@ -141,10 +139,8 @@ end_to_end_parallel_fixture = [
|
||||
}
|
||||
""",
|
||||
(
|
||||
np.array([[1, 2, 3, 4], [4, 2, 1, 0], [
|
||||
2, 3, 1, 5]], dtype=np.uint8),
|
||||
np.array([[1, 2, 3, 4], [4, 2, 1, 1], [
|
||||
2, 3, 1, 5]], dtype=np.uint8),
|
||||
np.array([[1, 2, 3, 4], [4, 2, 1, 0], [2, 3, 1, 5]], dtype=np.uint8),
|
||||
np.array([[1, 2, 3, 4], [4, 2, 1, 1], [2, 3, 1, 5]], dtype=np.uint8),
|
||||
),
|
||||
np.array([[52, 36], [31, 34], [42, 52]]),
|
||||
id="matmul_eint_int_uint8",
|
||||
@@ -228,8 +224,7 @@ def test_lib_compile_and_run_p_error(keyset_cache):
|
||||
options = CompilationOptions.new("main")
|
||||
options.set_p_error(0.00001)
|
||||
options.set_display_optimizer_choice(True)
|
||||
compile_run_assert(engine, mlir_input, args,
|
||||
expected_result, keyset_cache, options)
|
||||
compile_run_assert(engine, mlir_input, args, expected_result, keyset_cache, options)
|
||||
|
||||
|
||||
def test_lib_compile_and_run_p_error(keyset_cache):
|
||||
|
||||
Reference in New Issue
Block a user