diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.h b/compiler/include/concretelang/Bindings/Python/CompilerAPIModule.h similarity index 80% rename from compiler/lib/Bindings/Python/CompilerAPIModule.h rename to compiler/include/concretelang/Bindings/Python/CompilerAPIModule.h index 1a6778666..db2328afe 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.h +++ b/compiler/include/concretelang/Bindings/Python/CompilerAPIModule.h @@ -3,8 +3,8 @@ // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. -#ifndef CONCRETELANG_PYTHON_COMPILER_API_MODULE_H -#define CONCRETELANG_PYTHON_COMPILER_API_MODULE_H +#ifndef CONCRETELANG_BINDINGS_PYTHON_COMPILER_API_MODULE_H +#define CONCRETELANG_BINDINGS_PYTHON_COMPILER_API_MODULE_H #include diff --git a/compiler/include/concretelang/Bindings/Python/CompilerEngine.h b/compiler/include/concretelang/Bindings/Python/CompilerEngine.h new file mode 100644 index 000000000..9bd602875 --- /dev/null +++ b/compiler/include/concretelang/Bindings/Python/CompilerEngine.h @@ -0,0 +1,181 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_BINDINGS_PYTHON_COMPILER_ENGINE_H +#define CONCRETELANG_BINDINGS_PYTHON_COMPILER_ENGINE_H + +#include "concretelang/Support/CompilerEngine.h" +#include "concretelang/Support/JITSupport.h" +#include "concretelang/Support/Jit.h" +#include "concretelang/Support/LibrarySupport.h" +#include "mlir-c/IR.h" + +/// MLIR_CAPI_EXPORTED is used here throughout the API, because of the way the +/// python extension is built using MLIR cmake functions, which will cause +/// undefined symbols during runtime if those aren't present. + +/// Wrapper of the mlir::concretelang::LambdaArgument +struct lambdaArgument { + std::shared_ptr ptr; +}; +typedef struct lambdaArgument lambdaArgument; + +/// Hold a list of lambdaArgument to represent execution arguments +struct executionArguments { + lambdaArgument *data; + size_t size; +}; +typedef struct executionArguments executionArguments; + +// JIT Support bindings /////////////////////////////////////////////////////// + +struct JITSupport_Py { + mlir::concretelang::JITSupport support; +}; +typedef struct JITSupport_Py JITSupport_Py; + +MLIR_CAPI_EXPORTED JITSupport_Py jit_support(std::string runtimeLibPath); + +MLIR_CAPI_EXPORTED std::unique_ptr +jit_compile(JITSupport_Py support, const char *module, + mlir::concretelang::CompilationOptions options); + +MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters +jit_load_client_parameters(JITSupport_Py support, + mlir::concretelang::JitCompilationResult &); + +MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback +jit_load_compilation_feedback(JITSupport_Py support, + mlir::concretelang::JitCompilationResult &); + +MLIR_CAPI_EXPORTED std::shared_ptr +jit_load_server_lambda(JITSupport_Py support, + mlir::concretelang::JitCompilationResult &); + +MLIR_CAPI_EXPORTED std::unique_ptr +jit_server_call(JITSupport_Py support, mlir::concretelang::JITLambda &lambda, + concretelang::clientlib::PublicArguments &args, + concretelang::clientlib::EvaluationKeys &evaluationKeys); + +// Library Support bindings /////////////////////////////////////////////////// + +struct LibrarySupport_Py { + mlir::concretelang::LibrarySupport support; +}; +typedef struct LibrarySupport_Py LibrarySupport_Py; + +MLIR_CAPI_EXPORTED LibrarySupport_Py +library_support(const char *outputPath, const char *runtimeLibraryPath, + bool generateSharedLib, bool generateStaticLib, + bool generateClientParameters, bool generateCompilationFeedback, + bool generateCppHeader); + +MLIR_CAPI_EXPORTED std::unique_ptr +library_compile(LibrarySupport_Py support, const char *module, + mlir::concretelang::CompilationOptions options); + +MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters +library_load_client_parameters(LibrarySupport_Py support, + mlir::concretelang::LibraryCompilationResult &); + +MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback +library_load_compilation_feedback( + LibrarySupport_Py support, mlir::concretelang::LibraryCompilationResult &); + +MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda +library_load_server_lambda(LibrarySupport_Py support, + mlir::concretelang::LibraryCompilationResult &); + +MLIR_CAPI_EXPORTED std::unique_ptr +library_server_call(LibrarySupport_Py support, + concretelang::serverlib::ServerLambda lambda, + concretelang::clientlib::PublicArguments &args, + concretelang::clientlib::EvaluationKeys &evaluationKeys); + +MLIR_CAPI_EXPORTED std::string +library_get_shared_lib_path(LibrarySupport_Py support); + +MLIR_CAPI_EXPORTED std::string +library_get_client_parameters_path(LibrarySupport_Py support); + +// Client Support bindings /////////////////////////////////////////////////// + +MLIR_CAPI_EXPORTED std::unique_ptr +key_set(concretelang::clientlib::ClientParameters clientParameters, + llvm::Optional cache); + +MLIR_CAPI_EXPORTED std::unique_ptr +encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters, + concretelang::clientlib::KeySet &keySet, + llvm::ArrayRef args); + +MLIR_CAPI_EXPORTED lambdaArgument +decrypt_result(concretelang::clientlib::KeySet &keySet, + concretelang::clientlib::PublicResult &publicResult); + +// Serialization //////////////////////////////////////////////////////////// + +MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters +clientParametersUnserialize(const std::string &json); + +MLIR_CAPI_EXPORTED std::string +clientParametersSerialize(mlir::concretelang::ClientParameters ¶ms); + +MLIR_CAPI_EXPORTED std::unique_ptr +publicArgumentsUnserialize( + mlir::concretelang::ClientParameters &clientParameters, + const std::string &buffer); + +MLIR_CAPI_EXPORTED std::string publicArgumentsSerialize( + concretelang::clientlib::PublicArguments &publicArguments); + +MLIR_CAPI_EXPORTED std::unique_ptr +publicResultUnserialize(mlir::concretelang::ClientParameters &clientParameters, + const std::string &buffer); + +MLIR_CAPI_EXPORTED std::string +publicResultSerialize(concretelang::clientlib::PublicResult &publicResult); + +MLIR_CAPI_EXPORTED concretelang::clientlib::EvaluationKeys +evaluationKeysUnserialize(const std::string &buffer); + +MLIR_CAPI_EXPORTED std::string evaluationKeysSerialize( + concretelang::clientlib::EvaluationKeys &evaluationKeys); + +/// Parse then print a textual representation of an MLIR module +MLIR_CAPI_EXPORTED std::string roundTrip(const char *module); + +/// Terminate parallelization +MLIR_CAPI_EXPORTED void terminateParallelization(); + +/// Create a lambdaArgument from a tensor of different data types +MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU8( + std::vector data, std::vector dimensions); +MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU16( + std::vector data, std::vector dimensions); +MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU32( + std::vector data, std::vector dimensions); +MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU64( + std::vector data, std::vector dimensions); +/// Create a lambdaArgument from a scalar +MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromScalar(uint64_t scalar); +/// Check if a lambdaArgument holds a tensor +MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg); +/// Get tensor data from lambdaArgument +MLIR_CAPI_EXPORTED std::vector +lambdaArgumentGetTensorData(lambdaArgument &lambda_arg); +/// Get tensor dimensions from lambdaArgument +MLIR_CAPI_EXPORTED std::vector +lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg); +/// Check if a lambdaArgument holds a scalar +MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg); +/// Get scalar value from lambdaArgument +MLIR_CAPI_EXPORTED uint64_t lambdaArgumentGetScalar(lambdaArgument &lambda_arg); + +/// Compile the textual representation of MLIR modules to a library. +MLIR_CAPI_EXPORTED std::string library(std::string libraryPath, + std::vector modules); + +#endif // CONCRETELANG_BINDINGS_PYTHON_COMPILER_ENGINE_H diff --git a/compiler/lib/Bindings/Python/DialectModules.h b/compiler/include/concretelang/Bindings/Python/DialectModules.h similarity index 81% rename from compiler/lib/Bindings/Python/DialectModules.h rename to compiler/include/concretelang/Bindings/Python/DialectModules.h index dcf164543..06441216e 100644 --- a/compiler/lib/Bindings/Python/DialectModules.h +++ b/compiler/include/concretelang/Bindings/Python/DialectModules.h @@ -3,8 +3,8 @@ // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. -#ifndef CONCRETELANG_PYTHON_DIALECTMODULES_H -#define CONCRETELANG_PYTHON_DIALECTMODULES_H +#ifndef CONCRETELANG_BINDINGS_PYTHON_DIALECTMODULES_H +#define CONCRETELANG_BINDINGS_PYTHON_DIALECTMODULES_H #include diff --git a/compiler/lib/Bindings/Python/CMakeLists.txt b/compiler/lib/Bindings/Python/CMakeLists.txt index c6f605428..f5c6bb01c 100644 --- a/compiler/lib/Bindings/Python/CMakeLists.txt +++ b/compiler/lib/Bindings/Python/CMakeLists.txt @@ -1,5 +1,15 @@ include(AddMLIRPython) +# Python bindings need to throw exceptions for proper handling of errors on the python-side +add_compile_options(-fexceptions) + +# ###################################################################################################################### +# Support wrapper library for Python +# ###################################################################################################################### +set(LLVM_OPTIONAL_SOURCES CompilerAPIModule.cpp ConcretelangModule.cpp FHEModule.cpp) + +add_mlir_public_c_api_library(CONCRETELANGPySupport CompilerEngine.cpp LINK_LIBS PUBLIC MLIRCAPIIR ConcretelangSupport) + # ###################################################################################################################### # Decalare native Python extension # ###################################################################################################################### @@ -19,7 +29,7 @@ declare_mlir_python_extension( EMBED_CAPI_LINK_LIBS CONCRETELANGCAPIFHE CONCRETELANGCAPIFHELINALG - CONCRETELANGCAPISupport) + CONCRETELANGPySupport) # ###################################################################################################################### # Declare python sources diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index cafbd8b76..0aa395f03 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -3,8 +3,8 @@ // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. -#include "CompilerAPIModule.h" -#include "concretelang-c/Support/CompilerEngine.h" +#include "concretelang/Bindings/Python/CompilerAPIModule.h" +#include "concretelang/Bindings/Python/CompilerEngine.h" #include "concretelang/Dialect/FHE/IR/FHEOpsDialect.h.inc" #include "concretelang/Support/JITSupport.h" #include "concretelang/Support/Jit.h" @@ -95,34 +95,34 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( pybind11::class_>(m, "JITLambda"); - pybind11::class_(m, "JITSupport") + pybind11::class_(m, "JITSupport") .def(pybind11::init([](std::string runtimeLibPath) { return jit_support(runtimeLibPath); })) .def("compile", - [](JITSupport_C &support, std::string mlir_program, + [](JITSupport_Py &support, std::string mlir_program, CompilationOptions options) { return jit_compile(support, mlir_program.c_str(), options); }) .def("load_client_parameters", - [](JITSupport_C &support, + [](JITSupport_Py &support, mlir::concretelang::JitCompilationResult &result) { return jit_load_client_parameters(support, result); }) .def("load_compilation_feedback", - [](JITSupport_C &support, + [](JITSupport_Py &support, mlir::concretelang::JitCompilationResult &result) { return jit_load_compilation_feedback(support, result); }) .def( "load_server_lambda", - [](JITSupport_C &support, + [](JITSupport_Py &support, mlir::concretelang::JitCompilationResult &result) { return jit_load_server_lambda(support, result); }, pybind11::return_value_policy::reference) .def("server_call", - [](JITSupport_C &support, concretelang::JITLambda &lambda, + [](JITSupport_Py &support, concretelang::JITLambda &lambda, clientlib::PublicArguments &publicArguments, clientlib::EvaluationKeys &evaluationKeys) { return jit_server_call(support, lambda, publicArguments, @@ -138,7 +138,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( }; })); pybind11::class_(m, "LibraryLambda"); - pybind11::class_(m, "LibrarySupport") + pybind11::class_(m, "LibrarySupport") .def(pybind11::init( [](std::string outputPath, std::string runtimeLibraryPath, bool generateSharedLib, bool generateStaticLib, @@ -150,39 +150,39 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( generateCompilationFeedback, generateCppHeader); })) .def("compile", - [](LibrarySupport_C &support, std::string mlir_program, + [](LibrarySupport_Py &support, std::string mlir_program, mlir::concretelang::CompilationOptions options) { return library_compile(support, mlir_program.c_str(), options); }) .def("load_client_parameters", - [](LibrarySupport_C &support, + [](LibrarySupport_Py &support, mlir::concretelang::LibraryCompilationResult &result) { return library_load_client_parameters(support, result); }) .def("load_compilation_feedback", - [](LibrarySupport_C &support, + [](LibrarySupport_Py &support, mlir::concretelang::LibraryCompilationResult &result) { return library_load_compilation_feedback(support, result); }) .def( "load_server_lambda", - [](LibrarySupport_C &support, + [](LibrarySupport_Py &support, mlir::concretelang::LibraryCompilationResult &result) { return library_load_server_lambda(support, result); }, pybind11::return_value_policy::reference) .def("server_call", - [](LibrarySupport_C &support, serverlib::ServerLambda lambda, + [](LibrarySupport_Py &support, serverlib::ServerLambda lambda, clientlib::PublicArguments &publicArguments, clientlib::EvaluationKeys &evaluationKeys) { return library_server_call(support, lambda, publicArguments, evaluationKeys); }) .def("get_shared_lib_path", - [](LibrarySupport_C &support) { + [](LibrarySupport_Py &support) { return library_get_shared_lib_path(support); }) - .def("get_client_parameters_path", [](LibrarySupport_C &support) { + .def("get_client_parameters_path", [](LibrarySupport_Py &support) { return library_get_client_parameters_path(support); }); diff --git a/compiler/lib/Bindings/Python/CompilerEngine.cpp b/compiler/lib/Bindings/Python/CompilerEngine.cpp new file mode 100644 index 000000000..80598cc1a --- /dev/null +++ b/compiler/lib/Bindings/Python/CompilerEngine.cpp @@ -0,0 +1,428 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "llvm/ADT/SmallString.h" + +#include "concretelang/Bindings/Python/CompilerEngine.h" +#include "concretelang/ClientLib/KeySetCache.h" +#include "concretelang/ClientLib/Serializers.h" +#include "concretelang/Runtime/DFRuntime.hpp" +#include "concretelang/Support/CompilerEngine.h" +#include "concretelang/Support/JITSupport.h" +#include "concretelang/Support/Jit.h" + +#define GET_OR_THROW_LLVM_EXPECTED(VARNAME, EXPECTED) \ + auto VARNAME = EXPECTED; \ + if (auto err = VARNAME.takeError()) { \ + throw std::runtime_error(llvm::toString(std::move(err))); \ + } + +// JIT Support bindings /////////////////////////////////////////////////////// + +MLIR_CAPI_EXPORTED JITSupport_Py jit_support(std::string runtimeLibPath) { + auto opt = runtimeLibPath.empty() + ? llvm::None + : llvm::Optional(runtimeLibPath); + return JITSupport_Py{mlir::concretelang::JITSupport(opt)}; +} + +MLIR_CAPI_EXPORTED std::unique_ptr +jit_compile(JITSupport_Py support, const char *module, + mlir::concretelang::CompilationOptions options) { + GET_OR_THROW_LLVM_EXPECTED(compilationResult, + support.support.compile(module, options)); + return std::move(*compilationResult); +} + +MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters +jit_load_client_parameters(JITSupport_Py support, + mlir::concretelang::JitCompilationResult &result) { + GET_OR_THROW_LLVM_EXPECTED(clientParameters, + support.support.loadClientParameters(result)); + return *clientParameters; +} + +MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback +jit_load_compilation_feedback( + JITSupport_Py support, mlir::concretelang::JitCompilationResult &result) { + GET_OR_THROW_LLVM_EXPECTED(compilationFeedback, + support.support.loadCompilationFeedback(result)); + return *compilationFeedback; +} + +MLIR_CAPI_EXPORTED std::shared_ptr +jit_load_server_lambda(JITSupport_Py support, + mlir::concretelang::JitCompilationResult &result) { + GET_OR_THROW_LLVM_EXPECTED(serverLambda, + support.support.loadServerLambda(result)); + return *serverLambda; +} + +MLIR_CAPI_EXPORTED std::unique_ptr +jit_server_call(JITSupport_Py support, mlir::concretelang::JITLambda &lambda, + concretelang::clientlib::PublicArguments &args, + concretelang::clientlib::EvaluationKeys &evaluationKeys) { + GET_OR_THROW_LLVM_EXPECTED(publicResult, lambda.call(args, evaluationKeys)); + return std::move(*publicResult); +} + +// Library Support bindings /////////////////////////////////////////////////// +MLIR_CAPI_EXPORTED LibrarySupport_Py +library_support(const char *outputPath, const char *runtimeLibraryPath, + bool generateSharedLib, bool generateStaticLib, + bool generateClientParameters, bool generateCompilationFeedback, + bool generateCppHeader) { + return LibrarySupport_Py{mlir::concretelang::LibrarySupport( + outputPath, runtimeLibraryPath, generateSharedLib, generateStaticLib, + generateClientParameters, generateCompilationFeedback, + generateCppHeader)}; +} + +MLIR_CAPI_EXPORTED std::unique_ptr +library_compile(LibrarySupport_Py support, const char *module, + mlir::concretelang::CompilationOptions options) { + GET_OR_THROW_LLVM_EXPECTED(compilationResult, + support.support.compile(module, options)); + return std::move(*compilationResult); +} + +MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters +library_load_client_parameters( + LibrarySupport_Py support, + mlir::concretelang::LibraryCompilationResult &result) { + GET_OR_THROW_LLVM_EXPECTED(clientParameters, + support.support.loadClientParameters(result)); + return *clientParameters; +} + +MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback +library_load_compilation_feedback( + LibrarySupport_Py support, + mlir::concretelang::LibraryCompilationResult &result) { + GET_OR_THROW_LLVM_EXPECTED(compilationFeedback, + support.support.loadCompilationFeedback(result)); + return *compilationFeedback; +} + +MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda +library_load_server_lambda( + LibrarySupport_Py support, + mlir::concretelang::LibraryCompilationResult &result) { + GET_OR_THROW_LLVM_EXPECTED(serverLambda, + support.support.loadServerLambda(result)); + return *serverLambda; +} + +MLIR_CAPI_EXPORTED std::unique_ptr +library_server_call(LibrarySupport_Py support, + concretelang::serverlib::ServerLambda lambda, + concretelang::clientlib::PublicArguments &args, + concretelang::clientlib::EvaluationKeys &evaluationKeys) { + GET_OR_THROW_LLVM_EXPECTED( + publicResult, support.support.serverCall(lambda, args, evaluationKeys)); + return std::move(*publicResult); +} + +MLIR_CAPI_EXPORTED std::string +library_get_shared_lib_path(LibrarySupport_Py support) { + return support.support.getSharedLibPath(); +} + +MLIR_CAPI_EXPORTED std::string +library_get_client_parameters_path(LibrarySupport_Py support) { + return support.support.getClientParametersPath(); +} + +// Client Support bindings /////////////////////////////////////////////////// + +MLIR_CAPI_EXPORTED std::unique_ptr +key_set(concretelang::clientlib::ClientParameters clientParameters, + llvm::Optional cache) { + GET_OR_THROW_LLVM_EXPECTED( + ks, (mlir::concretelang::LambdaSupport::keySet(clientParameters, + cache))); + return std::move(*ks); +} + +MLIR_CAPI_EXPORTED std::unique_ptr +encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters, + concretelang::clientlib::KeySet &keySet, + llvm::ArrayRef args) { + GET_OR_THROW_LLVM_EXPECTED( + publicArguments, + (mlir::concretelang::LambdaSupport::exportArguments( + clientParameters, keySet, args))); + return std::move(*publicArguments); +} + +MLIR_CAPI_EXPORTED lambdaArgument +decrypt_result(concretelang::clientlib::KeySet &keySet, + concretelang::clientlib::PublicResult &publicResult) { + GET_OR_THROW_LLVM_EXPECTED( + result, mlir::concretelang::typedResult< + std::unique_ptr>( + keySet, publicResult)); + lambdaArgument result_{std::move(*result)}; + return result_; +} + +MLIR_CAPI_EXPORTED std::unique_ptr +publicArgumentsUnserialize( + mlir::concretelang::ClientParameters &clientParameters, + const std::string &buffer) { + std::stringstream istream(buffer); + auto argsOrError = concretelang::clientlib::PublicArguments::unserialize( + clientParameters, istream); + if (!argsOrError) { + throw std::runtime_error(argsOrError.error().mesg); + } + return std::move(argsOrError.value()); +} + +MLIR_CAPI_EXPORTED std::string publicArgumentsSerialize( + concretelang::clientlib::PublicArguments &publicArguments) { + + std::ostringstream buffer(std::ios::binary); + auto voidOrError = publicArguments.serialize(buffer); + if (!voidOrError) { + throw std::runtime_error(voidOrError.error().mesg); + } + return buffer.str(); +} + +MLIR_CAPI_EXPORTED std::unique_ptr +publicResultUnserialize(mlir::concretelang::ClientParameters &clientParameters, + const std::string &buffer) { + std::stringstream istream(buffer); + auto publicResultOrError = concretelang::clientlib::PublicResult::unserialize( + clientParameters, istream); + if (!publicResultOrError) { + throw std::runtime_error(publicResultOrError.error().mesg); + } + return std::move(publicResultOrError.value()); +} + +MLIR_CAPI_EXPORTED std::string +publicResultSerialize(concretelang::clientlib::PublicResult &publicResult) { + std::ostringstream buffer(std::ios::binary); + auto voidOrError = publicResult.serialize(buffer); + if (!voidOrError) { + throw std::runtime_error(voidOrError.error().mesg); + } + return buffer.str(); +} + +MLIR_CAPI_EXPORTED concretelang::clientlib::EvaluationKeys +evaluationKeysUnserialize(const std::string &buffer) { + std::stringstream istream(buffer); + + concretelang::clientlib::EvaluationKeys evaluationKeys; + concretelang::clientlib::operator>>(istream, evaluationKeys); + + if (istream.fail()) { + throw std::runtime_error("Cannot read evaluation keys"); + } + + return evaluationKeys; +} + +MLIR_CAPI_EXPORTED std::string evaluationKeysSerialize( + concretelang::clientlib::EvaluationKeys &evaluationKeys) { + std::ostringstream buffer(std::ios::binary); + concretelang::clientlib::operator<<(buffer, evaluationKeys); + return buffer.str(); +} + +MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters +clientParametersUnserialize(const std::string &json) { + GET_OR_THROW_LLVM_EXPECTED( + clientParams, + llvm::json::parse(json)); + return clientParams.get(); +} + +MLIR_CAPI_EXPORTED std::string +clientParametersSerialize(mlir::concretelang::ClientParameters ¶ms) { + llvm::json::Value value(params); + std::string jsonParams; + llvm::raw_string_ostream buffer(jsonParams); + buffer << value; + return jsonParams; +} + +MLIR_CAPI_EXPORTED void terminateParallelization() { +#ifdef CONCRETELANG_DATAFLOW_EXECUTION_ENABLED + _dfr_terminate(); +#endif +} + +MLIR_CAPI_EXPORTED std::string roundTrip(const char *module) { + std::shared_ptr ccx = + mlir::concretelang::CompilationContext::createShared(); + mlir::concretelang::CompilerEngine ce{ccx}; + + std::string backingString; + llvm::raw_string_ostream os(backingString); + + llvm::Expected + retOrErr = ce.compile( + module, mlir::concretelang::CompilerEngine::Target::ROUND_TRIP); + if (!retOrErr) { + os << "MLIR parsing failed: " << llvm::toString(retOrErr.takeError()); + throw std::runtime_error(os.str()); + } + + retOrErr->mlirModuleRef->get().print(os); + return os.str(); +} + +MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg) { + return lambda_arg.ptr->isa>>() || + lambda_arg.ptr->isa>>() || + lambda_arg.ptr->isa>>() || + lambda_arg.ptr->isa>>(); +} + +template +MLIR_CAPI_EXPORTED std::vector copyTensorLambdaArgumentTo64bitsvector( + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> *tensor) { + auto numElements = tensor->getNumElements(); + if (!numElements) { + std::string backingString; + llvm::raw_string_ostream os(backingString); + os << "Couldn't get size of tensor: " + << llvm::toString(std::move(numElements.takeError())); + throw std::runtime_error(os.str()); + } + std::vector res; + res.reserve(*numElements); + T *data = tensor->getValue(); + for (size_t i = 0; i < *numElements; i++) { + res.push_back(data[i]); + } + return res; +} + +MLIR_CAPI_EXPORTED std::vector +lambdaArgumentGetTensorData(lambdaArgument &lambda_arg) { + if (auto arg = + lambda_arg.ptr->dyn_cast>>()) { + llvm::Expected sizeOrErr = arg->getNumElements(); + if (!sizeOrErr) { + std::string backingString; + llvm::raw_string_ostream os(backingString); + os << "Couldn't get size of tensor: " + << llvm::toString(sizeOrErr.takeError()); + throw std::runtime_error(os.str()); + } + std::vector data(arg->getValue(), arg->getValue() + *sizeOrErr); + return data; + } + if (auto arg = + lambda_arg.ptr->dyn_cast>>()) { + return copyTensorLambdaArgumentTo64bitsvector(arg); + } + if (auto arg = + lambda_arg.ptr->dyn_cast>>()) { + return copyTensorLambdaArgumentTo64bitsvector(arg); + } + if (auto arg = + lambda_arg.ptr->dyn_cast>>()) { + return copyTensorLambdaArgumentTo64bitsvector(arg); + } + throw std::invalid_argument( + "LambdaArgument isn't a tensor or has an unsupported bitwidth"); +} + +MLIR_CAPI_EXPORTED std::vector +lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg) { + if (auto arg = + lambda_arg.ptr->dyn_cast>>()) { + return arg->getDimensions(); + } + if (auto arg = + lambda_arg.ptr->dyn_cast>>()) { + return arg->getDimensions(); + } + if (auto arg = + lambda_arg.ptr->dyn_cast>>()) { + return arg->getDimensions(); + } + if (auto arg = + lambda_arg.ptr->dyn_cast>>()) { + return arg->getDimensions(); + } + throw std::invalid_argument( + "LambdaArgument isn't a tensor, should " + "be a TensorLambdaArgument>"); +} + +MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg) { + return lambda_arg.ptr->isa>(); +} + +MLIR_CAPI_EXPORTED uint64_t +lambdaArgumentGetScalar(lambdaArgument &lambda_arg) { + mlir::concretelang::IntLambdaArgument *arg = + lambda_arg.ptr + ->dyn_cast>(); + if (arg == nullptr) { + throw std::invalid_argument("LambdaArgument isn't a scalar, should " + "be an IntLambdaArgument"); + } + return arg->getValue(); +} + +MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU8( + std::vector data, std::vector dimensions) { + lambdaArgument tensor_arg{ + std::make_shared>>(data, dimensions)}; + return tensor_arg; +} + +MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU16( + std::vector data, std::vector dimensions) { + lambdaArgument tensor_arg{ + std::make_shared>>(data, dimensions)}; + return tensor_arg; +} + +MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU32( + std::vector data, std::vector dimensions) { + lambdaArgument tensor_arg{ + std::make_shared>>(data, dimensions)}; + return tensor_arg; +} + +MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU64( + std::vector data, std::vector dimensions) { + lambdaArgument tensor_arg{ + std::make_shared>>(data, dimensions)}; + return tensor_arg; +} + +MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromScalar(uint64_t scalar) { + lambdaArgument scalar_arg{ + std::make_shared>( + scalar)}; + return scalar_arg; +} diff --git a/compiler/lib/Bindings/Python/ConcretelangModule.cpp b/compiler/lib/Bindings/Python/ConcretelangModule.cpp index ce10c38bf..35f49f67e 100644 --- a/compiler/lib/Bindings/Python/ConcretelangModule.cpp +++ b/compiler/lib/Bindings/Python/ConcretelangModule.cpp @@ -3,11 +3,10 @@ // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. -#include "CompilerAPIModule.h" -#include "DialectModules.h" - #include "concretelang-c/Dialect/FHE.h" #include "concretelang-c/Dialect/FHELinalg.h" +#include "concretelang/Bindings/Python/CompilerAPIModule.h" +#include "concretelang/Bindings/Python/DialectModules.h" #include "concretelang/Support/Constants.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Registration.h" diff --git a/compiler/lib/Bindings/Python/FHEModule.cpp b/compiler/lib/Bindings/Python/FHEModule.cpp index 6729d0987..a83e51866 100644 --- a/compiler/lib/Bindings/Python/FHEModule.cpp +++ b/compiler/lib/Bindings/Python/FHEModule.cpp @@ -3,9 +3,8 @@ // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. -#include "DialectModules.h" - #include "concretelang-c/Dialect/FHE.h" +#include "concretelang/Bindings/Python/DialectModules.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir/Bindings/Python/PybindAdaptors.h"