mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
refactor: separate python bindings wrapper from CAPI
current CAPI of CompilerEngine isn't really a CAPI. It's initial need was for the python bindings to have access to the CompilerEngine through a convenient API. So we now make a clear separation of CAPI and python wrappers. So we now have wrappers functions, that can be implemented using C/C++, and will be exposed to python via pybind11. And we have a CAPI (still need fixing as it still contains C++ code), that can be used as is, or to build bindings for other languages (such as Rust).
This commit is contained in:
@@ -3,8 +3,8 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_PYTHON_COMPILER_API_MODULE_H
|
||||
#define CONCRETELANG_PYTHON_COMPILER_API_MODULE_H
|
||||
#ifndef CONCRETELANG_BINDINGS_PYTHON_COMPILER_API_MODULE_H
|
||||
#define CONCRETELANG_BINDINGS_PYTHON_COMPILER_API_MODULE_H
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
181
compiler/include/concretelang/Bindings/Python/CompilerEngine.h
Normal file
181
compiler/include/concretelang/Bindings/Python/CompilerEngine.h
Normal file
@@ -0,0 +1,181 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_BINDINGS_PYTHON_COMPILER_ENGINE_H
|
||||
#define CONCRETELANG_BINDINGS_PYTHON_COMPILER_ENGINE_H
|
||||
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/JITSupport.h"
|
||||
#include "concretelang/Support/Jit.h"
|
||||
#include "concretelang/Support/LibrarySupport.h"
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
/// MLIR_CAPI_EXPORTED is used here throughout the API, because of the way the
|
||||
/// python extension is built using MLIR cmake functions, which will cause
|
||||
/// undefined symbols during runtime if those aren't present.
|
||||
|
||||
/// Wrapper of the mlir::concretelang::LambdaArgument
|
||||
struct lambdaArgument {
|
||||
std::shared_ptr<mlir::concretelang::LambdaArgument> ptr;
|
||||
};
|
||||
typedef struct lambdaArgument lambdaArgument;
|
||||
|
||||
/// Hold a list of lambdaArgument to represent execution arguments
|
||||
struct executionArguments {
|
||||
lambdaArgument *data;
|
||||
size_t size;
|
||||
};
|
||||
typedef struct executionArguments executionArguments;
|
||||
|
||||
// JIT Support bindings ///////////////////////////////////////////////////////
|
||||
|
||||
struct JITSupport_Py {
|
||||
mlir::concretelang::JITSupport support;
|
||||
};
|
||||
typedef struct JITSupport_Py JITSupport_Py;
|
||||
|
||||
MLIR_CAPI_EXPORTED JITSupport_Py jit_support(std::string runtimeLibPath);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::JitCompilationResult>
|
||||
jit_compile(JITSupport_Py support, const char *module,
|
||||
mlir::concretelang::CompilationOptions options);
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
jit_load_client_parameters(JITSupport_Py support,
|
||||
mlir::concretelang::JitCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback
|
||||
jit_load_compilation_feedback(JITSupport_Py support,
|
||||
mlir::concretelang::JitCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::shared_ptr<mlir::concretelang::JITLambda>
|
||||
jit_load_server_lambda(JITSupport_Py support,
|
||||
mlir::concretelang::JitCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
jit_server_call(JITSupport_Py support, mlir::concretelang::JITLambda &lambda,
|
||||
concretelang::clientlib::PublicArguments &args,
|
||||
concretelang::clientlib::EvaluationKeys &evaluationKeys);
|
||||
|
||||
// Library Support bindings ///////////////////////////////////////////////////
|
||||
|
||||
struct LibrarySupport_Py {
|
||||
mlir::concretelang::LibrarySupport support;
|
||||
};
|
||||
typedef struct LibrarySupport_Py LibrarySupport_Py;
|
||||
|
||||
MLIR_CAPI_EXPORTED LibrarySupport_Py
|
||||
library_support(const char *outputPath, const char *runtimeLibraryPath,
|
||||
bool generateSharedLib, bool generateStaticLib,
|
||||
bool generateClientParameters, bool generateCompilationFeedback,
|
||||
bool generateCppHeader);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
|
||||
library_compile(LibrarySupport_Py support, const char *module,
|
||||
mlir::concretelang::CompilationOptions options);
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
library_load_client_parameters(LibrarySupport_Py support,
|
||||
mlir::concretelang::LibraryCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback
|
||||
library_load_compilation_feedback(
|
||||
LibrarySupport_Py support, mlir::concretelang::LibraryCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda
|
||||
library_load_server_lambda(LibrarySupport_Py support,
|
||||
mlir::concretelang::LibraryCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
library_server_call(LibrarySupport_Py support,
|
||||
concretelang::serverlib::ServerLambda lambda,
|
||||
concretelang::clientlib::PublicArguments &args,
|
||||
concretelang::clientlib::EvaluationKeys &evaluationKeys);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
library_get_shared_lib_path(LibrarySupport_Py support);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
library_get_client_parameters_path(LibrarySupport_Py support);
|
||||
|
||||
// Client Support bindings ///////////////////////////////////////////////////
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::KeySet>
|
||||
key_set(concretelang::clientlib::ClientParameters clientParameters,
|
||||
llvm::Optional<concretelang::clientlib::KeySetCache> cache);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicArguments>
|
||||
encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,
|
||||
concretelang::clientlib::KeySet &keySet,
|
||||
llvm::ArrayRef<mlir::concretelang::LambdaArgument *> args);
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument
|
||||
decrypt_result(concretelang::clientlib::KeySet &keySet,
|
||||
concretelang::clientlib::PublicResult &publicResult);
|
||||
|
||||
// Serialization ////////////////////////////////////////////////////////////
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
clientParametersUnserialize(const std::string &json);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
clientParametersSerialize(mlir::concretelang::ClientParameters ¶ms);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicArguments>
|
||||
publicArgumentsUnserialize(
|
||||
mlir::concretelang::ClientParameters &clientParameters,
|
||||
const std::string &buffer);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string publicArgumentsSerialize(
|
||||
concretelang::clientlib::PublicArguments &publicArguments);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
publicResultUnserialize(mlir::concretelang::ClientParameters &clientParameters,
|
||||
const std::string &buffer);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
publicResultSerialize(concretelang::clientlib::PublicResult &publicResult);
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::clientlib::EvaluationKeys
|
||||
evaluationKeysUnserialize(const std::string &buffer);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string evaluationKeysSerialize(
|
||||
concretelang::clientlib::EvaluationKeys &evaluationKeys);
|
||||
|
||||
/// Parse then print a textual representation of an MLIR module
|
||||
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
|
||||
|
||||
/// Terminate parallelization
|
||||
MLIR_CAPI_EXPORTED void terminateParallelization();
|
||||
|
||||
/// Create a lambdaArgument from a tensor of different data types
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU8(
|
||||
std::vector<uint8_t> data, std::vector<int64_t> dimensions);
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU16(
|
||||
std::vector<uint16_t> data, std::vector<int64_t> dimensions);
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU32(
|
||||
std::vector<uint32_t> data, std::vector<int64_t> dimensions);
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU64(
|
||||
std::vector<uint64_t> data, std::vector<int64_t> dimensions);
|
||||
/// Create a lambdaArgument from a scalar
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromScalar(uint64_t scalar);
|
||||
/// Check if a lambdaArgument holds a tensor
|
||||
MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg);
|
||||
/// Get tensor data from lambdaArgument
|
||||
MLIR_CAPI_EXPORTED std::vector<uint64_t>
|
||||
lambdaArgumentGetTensorData(lambdaArgument &lambda_arg);
|
||||
/// Get tensor dimensions from lambdaArgument
|
||||
MLIR_CAPI_EXPORTED std::vector<int64_t>
|
||||
lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg);
|
||||
/// Check if a lambdaArgument holds a scalar
|
||||
MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg);
|
||||
/// Get scalar value from lambdaArgument
|
||||
MLIR_CAPI_EXPORTED uint64_t lambdaArgumentGetScalar(lambdaArgument &lambda_arg);
|
||||
|
||||
/// Compile the textual representation of MLIR modules to a library.
|
||||
MLIR_CAPI_EXPORTED std::string library(std::string libraryPath,
|
||||
std::vector<std::string> modules);
|
||||
|
||||
#endif // CONCRETELANG_BINDINGS_PYTHON_COMPILER_ENGINE_H
|
||||
@@ -3,8 +3,8 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_PYTHON_DIALECTMODULES_H
|
||||
#define CONCRETELANG_PYTHON_DIALECTMODULES_H
|
||||
#ifndef CONCRETELANG_BINDINGS_PYTHON_DIALECTMODULES_H
|
||||
#define CONCRETELANG_BINDINGS_PYTHON_DIALECTMODULES_H
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
@@ -1,5 +1,15 @@
|
||||
include(AddMLIRPython)
|
||||
|
||||
# Python bindings need to throw exceptions for proper handling of errors on the python-side
|
||||
add_compile_options(-fexceptions)
|
||||
|
||||
# ######################################################################################################################
|
||||
# Support wrapper library for Python
|
||||
# ######################################################################################################################
|
||||
set(LLVM_OPTIONAL_SOURCES CompilerAPIModule.cpp ConcretelangModule.cpp FHEModule.cpp)
|
||||
|
||||
add_mlir_public_c_api_library(CONCRETELANGPySupport CompilerEngine.cpp LINK_LIBS PUBLIC MLIRCAPIIR ConcretelangSupport)
|
||||
|
||||
# ######################################################################################################################
|
||||
# Decalare native Python extension
|
||||
# ######################################################################################################################
|
||||
@@ -19,7 +29,7 @@ declare_mlir_python_extension(
|
||||
EMBED_CAPI_LINK_LIBS
|
||||
CONCRETELANGCAPIFHE
|
||||
CONCRETELANGCAPIFHELINALG
|
||||
CONCRETELANGCAPISupport)
|
||||
CONCRETELANGPySupport)
|
||||
|
||||
# ######################################################################################################################
|
||||
# Declare python sources
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "CompilerAPIModule.h"
|
||||
#include "concretelang-c/Support/CompilerEngine.h"
|
||||
#include "concretelang/Bindings/Python/CompilerAPIModule.h"
|
||||
#include "concretelang/Bindings/Python/CompilerEngine.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOpsDialect.h.inc"
|
||||
#include "concretelang/Support/JITSupport.h"
|
||||
#include "concretelang/Support/Jit.h"
|
||||
@@ -95,34 +95,34 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
pybind11::class_<mlir::concretelang::JITLambda,
|
||||
std::shared_ptr<mlir::concretelang::JITLambda>>(m,
|
||||
"JITLambda");
|
||||
pybind11::class_<JITSupport_C>(m, "JITSupport")
|
||||
pybind11::class_<JITSupport_Py>(m, "JITSupport")
|
||||
.def(pybind11::init([](std::string runtimeLibPath) {
|
||||
return jit_support(runtimeLibPath);
|
||||
}))
|
||||
.def("compile",
|
||||
[](JITSupport_C &support, std::string mlir_program,
|
||||
[](JITSupport_Py &support, std::string mlir_program,
|
||||
CompilationOptions options) {
|
||||
return jit_compile(support, mlir_program.c_str(), options);
|
||||
})
|
||||
.def("load_client_parameters",
|
||||
[](JITSupport_C &support,
|
||||
[](JITSupport_Py &support,
|
||||
mlir::concretelang::JitCompilationResult &result) {
|
||||
return jit_load_client_parameters(support, result);
|
||||
})
|
||||
.def("load_compilation_feedback",
|
||||
[](JITSupport_C &support,
|
||||
[](JITSupport_Py &support,
|
||||
mlir::concretelang::JitCompilationResult &result) {
|
||||
return jit_load_compilation_feedback(support, result);
|
||||
})
|
||||
.def(
|
||||
"load_server_lambda",
|
||||
[](JITSupport_C &support,
|
||||
[](JITSupport_Py &support,
|
||||
mlir::concretelang::JitCompilationResult &result) {
|
||||
return jit_load_server_lambda(support, result);
|
||||
},
|
||||
pybind11::return_value_policy::reference)
|
||||
.def("server_call",
|
||||
[](JITSupport_C &support, concretelang::JITLambda &lambda,
|
||||
[](JITSupport_Py &support, concretelang::JITLambda &lambda,
|
||||
clientlib::PublicArguments &publicArguments,
|
||||
clientlib::EvaluationKeys &evaluationKeys) {
|
||||
return jit_server_call(support, lambda, publicArguments,
|
||||
@@ -138,7 +138,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
};
|
||||
}));
|
||||
pybind11::class_<concretelang::serverlib::ServerLambda>(m, "LibraryLambda");
|
||||
pybind11::class_<LibrarySupport_C>(m, "LibrarySupport")
|
||||
pybind11::class_<LibrarySupport_Py>(m, "LibrarySupport")
|
||||
.def(pybind11::init(
|
||||
[](std::string outputPath, std::string runtimeLibraryPath,
|
||||
bool generateSharedLib, bool generateStaticLib,
|
||||
@@ -150,39 +150,39 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
generateCompilationFeedback, generateCppHeader);
|
||||
}))
|
||||
.def("compile",
|
||||
[](LibrarySupport_C &support, std::string mlir_program,
|
||||
[](LibrarySupport_Py &support, std::string mlir_program,
|
||||
mlir::concretelang::CompilationOptions options) {
|
||||
return library_compile(support, mlir_program.c_str(), options);
|
||||
})
|
||||
.def("load_client_parameters",
|
||||
[](LibrarySupport_C &support,
|
||||
[](LibrarySupport_Py &support,
|
||||
mlir::concretelang::LibraryCompilationResult &result) {
|
||||
return library_load_client_parameters(support, result);
|
||||
})
|
||||
.def("load_compilation_feedback",
|
||||
[](LibrarySupport_C &support,
|
||||
[](LibrarySupport_Py &support,
|
||||
mlir::concretelang::LibraryCompilationResult &result) {
|
||||
return library_load_compilation_feedback(support, result);
|
||||
})
|
||||
.def(
|
||||
"load_server_lambda",
|
||||
[](LibrarySupport_C &support,
|
||||
[](LibrarySupport_Py &support,
|
||||
mlir::concretelang::LibraryCompilationResult &result) {
|
||||
return library_load_server_lambda(support, result);
|
||||
},
|
||||
pybind11::return_value_policy::reference)
|
||||
.def("server_call",
|
||||
[](LibrarySupport_C &support, serverlib::ServerLambda lambda,
|
||||
[](LibrarySupport_Py &support, serverlib::ServerLambda lambda,
|
||||
clientlib::PublicArguments &publicArguments,
|
||||
clientlib::EvaluationKeys &evaluationKeys) {
|
||||
return library_server_call(support, lambda, publicArguments,
|
||||
evaluationKeys);
|
||||
})
|
||||
.def("get_shared_lib_path",
|
||||
[](LibrarySupport_C &support) {
|
||||
[](LibrarySupport_Py &support) {
|
||||
return library_get_shared_lib_path(support);
|
||||
})
|
||||
.def("get_client_parameters_path", [](LibrarySupport_C &support) {
|
||||
.def("get_client_parameters_path", [](LibrarySupport_Py &support) {
|
||||
return library_get_client_parameters_path(support);
|
||||
});
|
||||
|
||||
|
||||
428
compiler/lib/Bindings/Python/CompilerEngine.cpp
Normal file
428
compiler/lib/Bindings/Python/CompilerEngine.cpp
Normal file
@@ -0,0 +1,428 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
|
||||
#include "concretelang/Bindings/Python/CompilerEngine.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/ClientLib/Serializers.h"
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/JITSupport.h"
|
||||
#include "concretelang/Support/Jit.h"
|
||||
|
||||
#define GET_OR_THROW_LLVM_EXPECTED(VARNAME, EXPECTED) \
|
||||
auto VARNAME = EXPECTED; \
|
||||
if (auto err = VARNAME.takeError()) { \
|
||||
throw std::runtime_error(llvm::toString(std::move(err))); \
|
||||
}
|
||||
|
||||
// JIT Support bindings ///////////////////////////////////////////////////////
|
||||
|
||||
MLIR_CAPI_EXPORTED JITSupport_Py jit_support(std::string runtimeLibPath) {
|
||||
auto opt = runtimeLibPath.empty()
|
||||
? llvm::None
|
||||
: llvm::Optional<std::string>(runtimeLibPath);
|
||||
return JITSupport_Py{mlir::concretelang::JITSupport(opt)};
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::JitCompilationResult>
|
||||
jit_compile(JITSupport_Py support, const char *module,
|
||||
mlir::concretelang::CompilationOptions options) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(compilationResult,
|
||||
support.support.compile(module, options));
|
||||
return std::move(*compilationResult);
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
jit_load_client_parameters(JITSupport_Py support,
|
||||
mlir::concretelang::JitCompilationResult &result) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(clientParameters,
|
||||
support.support.loadClientParameters(result));
|
||||
return *clientParameters;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback
|
||||
jit_load_compilation_feedback(
|
||||
JITSupport_Py support, mlir::concretelang::JitCompilationResult &result) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(compilationFeedback,
|
||||
support.support.loadCompilationFeedback(result));
|
||||
return *compilationFeedback;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::shared_ptr<mlir::concretelang::JITLambda>
|
||||
jit_load_server_lambda(JITSupport_Py support,
|
||||
mlir::concretelang::JitCompilationResult &result) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(serverLambda,
|
||||
support.support.loadServerLambda(result));
|
||||
return *serverLambda;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
jit_server_call(JITSupport_Py support, mlir::concretelang::JITLambda &lambda,
|
||||
concretelang::clientlib::PublicArguments &args,
|
||||
concretelang::clientlib::EvaluationKeys &evaluationKeys) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(publicResult, lambda.call(args, evaluationKeys));
|
||||
return std::move(*publicResult);
|
||||
}
|
||||
|
||||
// Library Support bindings ///////////////////////////////////////////////////
|
||||
MLIR_CAPI_EXPORTED LibrarySupport_Py
|
||||
library_support(const char *outputPath, const char *runtimeLibraryPath,
|
||||
bool generateSharedLib, bool generateStaticLib,
|
||||
bool generateClientParameters, bool generateCompilationFeedback,
|
||||
bool generateCppHeader) {
|
||||
return LibrarySupport_Py{mlir::concretelang::LibrarySupport(
|
||||
outputPath, runtimeLibraryPath, generateSharedLib, generateStaticLib,
|
||||
generateClientParameters, generateCompilationFeedback,
|
||||
generateCppHeader)};
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
|
||||
library_compile(LibrarySupport_Py support, const char *module,
|
||||
mlir::concretelang::CompilationOptions options) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(compilationResult,
|
||||
support.support.compile(module, options));
|
||||
return std::move(*compilationResult);
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
library_load_client_parameters(
|
||||
LibrarySupport_Py support,
|
||||
mlir::concretelang::LibraryCompilationResult &result) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(clientParameters,
|
||||
support.support.loadClientParameters(result));
|
||||
return *clientParameters;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback
|
||||
library_load_compilation_feedback(
|
||||
LibrarySupport_Py support,
|
||||
mlir::concretelang::LibraryCompilationResult &result) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(compilationFeedback,
|
||||
support.support.loadCompilationFeedback(result));
|
||||
return *compilationFeedback;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda
|
||||
library_load_server_lambda(
|
||||
LibrarySupport_Py support,
|
||||
mlir::concretelang::LibraryCompilationResult &result) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(serverLambda,
|
||||
support.support.loadServerLambda(result));
|
||||
return *serverLambda;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
library_server_call(LibrarySupport_Py support,
|
||||
concretelang::serverlib::ServerLambda lambda,
|
||||
concretelang::clientlib::PublicArguments &args,
|
||||
concretelang::clientlib::EvaluationKeys &evaluationKeys) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(
|
||||
publicResult, support.support.serverCall(lambda, args, evaluationKeys));
|
||||
return std::move(*publicResult);
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
library_get_shared_lib_path(LibrarySupport_Py support) {
|
||||
return support.support.getSharedLibPath();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
library_get_client_parameters_path(LibrarySupport_Py support) {
|
||||
return support.support.getClientParametersPath();
|
||||
}
|
||||
|
||||
// Client Support bindings ///////////////////////////////////////////////////
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::KeySet>
|
||||
key_set(concretelang::clientlib::ClientParameters clientParameters,
|
||||
llvm::Optional<concretelang::clientlib::KeySetCache> cache) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(
|
||||
ks, (mlir::concretelang::LambdaSupport<int, int>::keySet(clientParameters,
|
||||
cache)));
|
||||
return std::move(*ks);
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicArguments>
|
||||
encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,
|
||||
concretelang::clientlib::KeySet &keySet,
|
||||
llvm::ArrayRef<mlir::concretelang::LambdaArgument *> args) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(
|
||||
publicArguments,
|
||||
(mlir::concretelang::LambdaSupport<int, int>::exportArguments(
|
||||
clientParameters, keySet, args)));
|
||||
return std::move(*publicArguments);
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument
|
||||
decrypt_result(concretelang::clientlib::KeySet &keySet,
|
||||
concretelang::clientlib::PublicResult &publicResult) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(
|
||||
result, mlir::concretelang::typedResult<
|
||||
std::unique_ptr<mlir::concretelang::LambdaArgument>>(
|
||||
keySet, publicResult));
|
||||
lambdaArgument result_{std::move(*result)};
|
||||
return result_;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicArguments>
|
||||
publicArgumentsUnserialize(
|
||||
mlir::concretelang::ClientParameters &clientParameters,
|
||||
const std::string &buffer) {
|
||||
std::stringstream istream(buffer);
|
||||
auto argsOrError = concretelang::clientlib::PublicArguments::unserialize(
|
||||
clientParameters, istream);
|
||||
if (!argsOrError) {
|
||||
throw std::runtime_error(argsOrError.error().mesg);
|
||||
}
|
||||
return std::move(argsOrError.value());
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string publicArgumentsSerialize(
|
||||
concretelang::clientlib::PublicArguments &publicArguments) {
|
||||
|
||||
std::ostringstream buffer(std::ios::binary);
|
||||
auto voidOrError = publicArguments.serialize(buffer);
|
||||
if (!voidOrError) {
|
||||
throw std::runtime_error(voidOrError.error().mesg);
|
||||
}
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
publicResultUnserialize(mlir::concretelang::ClientParameters &clientParameters,
|
||||
const std::string &buffer) {
|
||||
std::stringstream istream(buffer);
|
||||
auto publicResultOrError = concretelang::clientlib::PublicResult::unserialize(
|
||||
clientParameters, istream);
|
||||
if (!publicResultOrError) {
|
||||
throw std::runtime_error(publicResultOrError.error().mesg);
|
||||
}
|
||||
return std::move(publicResultOrError.value());
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
publicResultSerialize(concretelang::clientlib::PublicResult &publicResult) {
|
||||
std::ostringstream buffer(std::ios::binary);
|
||||
auto voidOrError = publicResult.serialize(buffer);
|
||||
if (!voidOrError) {
|
||||
throw std::runtime_error(voidOrError.error().mesg);
|
||||
}
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::clientlib::EvaluationKeys
|
||||
evaluationKeysUnserialize(const std::string &buffer) {
|
||||
std::stringstream istream(buffer);
|
||||
|
||||
concretelang::clientlib::EvaluationKeys evaluationKeys;
|
||||
concretelang::clientlib::operator>>(istream, evaluationKeys);
|
||||
|
||||
if (istream.fail()) {
|
||||
throw std::runtime_error("Cannot read evaluation keys");
|
||||
}
|
||||
|
||||
return evaluationKeys;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string evaluationKeysSerialize(
|
||||
concretelang::clientlib::EvaluationKeys &evaluationKeys) {
|
||||
std::ostringstream buffer(std::ios::binary);
|
||||
concretelang::clientlib::operator<<(buffer, evaluationKeys);
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
clientParametersUnserialize(const std::string &json) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(
|
||||
clientParams,
|
||||
llvm::json::parse<mlir::concretelang::ClientParameters>(json));
|
||||
return clientParams.get();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
clientParametersSerialize(mlir::concretelang::ClientParameters ¶ms) {
|
||||
llvm::json::Value value(params);
|
||||
std::string jsonParams;
|
||||
llvm::raw_string_ostream buffer(jsonParams);
|
||||
buffer << value;
|
||||
return jsonParams;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED void terminateParallelization() {
|
||||
#ifdef CONCRETELANG_DATAFLOW_EXECUTION_ENABLED
|
||||
_dfr_terminate();
|
||||
#endif
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module) {
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
|
||||
mlir::concretelang::CompilationContext::createShared();
|
||||
mlir::concretelang::CompilerEngine ce{ccx};
|
||||
|
||||
std::string backingString;
|
||||
llvm::raw_string_ostream os(backingString);
|
||||
|
||||
llvm::Expected<mlir::concretelang::CompilerEngine::CompilationResult>
|
||||
retOrErr = ce.compile(
|
||||
module, mlir::concretelang::CompilerEngine::Target::ROUND_TRIP);
|
||||
if (!retOrErr) {
|
||||
os << "MLIR parsing failed: " << llvm::toString(retOrErr.takeError());
|
||||
throw std::runtime_error(os.str());
|
||||
}
|
||||
|
||||
retOrErr->mlirModuleRef->get().print(os);
|
||||
return os.str();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg) {
|
||||
return lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>>() ||
|
||||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint16_t>>>() ||
|
||||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint32_t>>>() ||
|
||||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
MLIR_CAPI_EXPORTED std::vector<uint64_t> copyTensorLambdaArgumentTo64bitsvector(
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<T>> *tensor) {
|
||||
auto numElements = tensor->getNumElements();
|
||||
if (!numElements) {
|
||||
std::string backingString;
|
||||
llvm::raw_string_ostream os(backingString);
|
||||
os << "Couldn't get size of tensor: "
|
||||
<< llvm::toString(std::move(numElements.takeError()));
|
||||
throw std::runtime_error(os.str());
|
||||
}
|
||||
std::vector<uint64_t> res;
|
||||
res.reserve(*numElements);
|
||||
T *data = tensor->getValue();
|
||||
for (size_t i = 0; i < *numElements; i++) {
|
||||
res.push_back(data[i]);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::vector<uint64_t>
|
||||
lambdaArgumentGetTensorData(lambdaArgument &lambda_arg) {
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
|
||||
llvm::Expected<size_t> sizeOrErr = arg->getNumElements();
|
||||
if (!sizeOrErr) {
|
||||
std::string backingString;
|
||||
llvm::raw_string_ostream os(backingString);
|
||||
os << "Couldn't get size of tensor: "
|
||||
<< llvm::toString(sizeOrErr.takeError());
|
||||
throw std::runtime_error(os.str());
|
||||
}
|
||||
std::vector<uint64_t> data(arg->getValue(), arg->getValue() + *sizeOrErr);
|
||||
return data;
|
||||
}
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
|
||||
return copyTensorLambdaArgumentTo64bitsvector(arg);
|
||||
}
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
|
||||
return copyTensorLambdaArgumentTo64bitsvector(arg);
|
||||
}
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
|
||||
return copyTensorLambdaArgumentTo64bitsvector(arg);
|
||||
}
|
||||
throw std::invalid_argument(
|
||||
"LambdaArgument isn't a tensor or has an unsupported bitwidth");
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::vector<int64_t>
|
||||
lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg) {
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
|
||||
return arg->getDimensions();
|
||||
}
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
|
||||
return arg->getDimensions();
|
||||
}
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
|
||||
return arg->getDimensions();
|
||||
}
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
|
||||
return arg->getDimensions();
|
||||
}
|
||||
throw std::invalid_argument(
|
||||
"LambdaArgument isn't a tensor, should "
|
||||
"be a TensorLambdaArgument<IntLambdaArgument<uint64_t>>");
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg) {
|
||||
return lambda_arg.ptr->isa<mlir::concretelang::IntLambdaArgument<uint64_t>>();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED uint64_t
|
||||
lambdaArgumentGetScalar(lambdaArgument &lambda_arg) {
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t> *arg =
|
||||
lambda_arg.ptr
|
||||
->dyn_cast<mlir::concretelang::IntLambdaArgument<uint64_t>>();
|
||||
if (arg == nullptr) {
|
||||
throw std::invalid_argument("LambdaArgument isn't a scalar, should "
|
||||
"be an IntLambdaArgument<uint64_t>");
|
||||
}
|
||||
return arg->getValue();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU8(
|
||||
std::vector<uint8_t> data, std::vector<int64_t> dimensions) {
|
||||
lambdaArgument tensor_arg{
|
||||
std::make_shared<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>>(data, dimensions)};
|
||||
return tensor_arg;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU16(
|
||||
std::vector<uint16_t> data, std::vector<int64_t> dimensions) {
|
||||
lambdaArgument tensor_arg{
|
||||
std::make_shared<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint16_t>>>(data, dimensions)};
|
||||
return tensor_arg;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU32(
|
||||
std::vector<uint32_t> data, std::vector<int64_t> dimensions) {
|
||||
lambdaArgument tensor_arg{
|
||||
std::make_shared<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint32_t>>>(data, dimensions)};
|
||||
return tensor_arg;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU64(
|
||||
std::vector<uint64_t> data, std::vector<int64_t> dimensions) {
|
||||
lambdaArgument tensor_arg{
|
||||
std::make_shared<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>>(data, dimensions)};
|
||||
return tensor_arg;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromScalar(uint64_t scalar) {
|
||||
lambdaArgument scalar_arg{
|
||||
std::make_shared<mlir::concretelang::IntLambdaArgument<uint64_t>>(
|
||||
scalar)};
|
||||
return scalar_arg;
|
||||
}
|
||||
@@ -3,11 +3,10 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "CompilerAPIModule.h"
|
||||
#include "DialectModules.h"
|
||||
|
||||
#include "concretelang-c/Dialect/FHE.h"
|
||||
#include "concretelang-c/Dialect/FHELinalg.h"
|
||||
#include "concretelang/Bindings/Python/CompilerAPIModule.h"
|
||||
#include "concretelang/Bindings/Python/DialectModules.h"
|
||||
#include "concretelang/Support/Constants.h"
|
||||
#include "mlir-c/Bindings/Python/Interop.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
|
||||
@@ -3,9 +3,8 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "DialectModules.h"
|
||||
|
||||
#include "concretelang-c/Dialect/FHE.h"
|
||||
#include "concretelang/Bindings/Python/DialectModules.h"
|
||||
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
|
||||
Reference in New Issue
Block a user