diff --git a/compiler/Makefile b/compiler/Makefile index 9bb447204..148df066b 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -165,7 +165,7 @@ rust-bindings: build-initialized concretecompiler CAPI cd lib/Bindings/Rust && CONCRETE_COMPILER_BUILD_DIR=$(abspath $(BUILD_DIR)) cargo build --release CAPI: - cmake --build $(BUILD_DIR) --target CONCRETELANGCAPIFHE CONCRETELANGCAPIFHELINALG + cmake --build $(BUILD_DIR) --target CONCRETELANGCAPIFHE CONCRETELANGCAPIFHELINALG CONCRETELANGCAPISupport clientlib: build-initialized cmake --build $(BUILD_DIR) --target ConcretelangClientLib diff --git a/compiler/include/concretelang-c/Support/CompilerEngine.h b/compiler/include/concretelang-c/Support/CompilerEngine.h index b730793de..9fbe67b2c 100644 --- a/compiler/include/concretelang-c/Support/CompilerEngine.h +++ b/compiler/include/concretelang-c/Support/CompilerEngine.h @@ -6,182 +6,67 @@ #ifndef CONCRETELANG_C_SUPPORT_COMPILER_ENGINE_H #define CONCRETELANG_C_SUPPORT_COMPILER_ENGINE_H -#include "concretelang/Support/CompilerEngine.h" -#include "concretelang/Support/JITSupport.h" -#include "concretelang/Support/Jit.h" -#include "concretelang/Support/LibrarySupport.h" #include "mlir-c/IR.h" -#include "mlir-c/Registration.h" -// TODO: make this file C-compatible and uncomment the 3 following lines -// #ifdef __cplusplus -// extern "C" { -// #endif +#ifdef __cplusplus +extern "C" { +#endif -/// C wrapper of the mlir::concretelang::LambdaArgument -struct lambdaArgument { - std::shared_ptr ptr; -}; -typedef struct lambdaArgument lambdaArgument; +/// Opaque type declarations. Refer to llvm-project/mlir/include/mlir-c/IR.h for +/// more info +#define DEFINE_C_API_STRUCT(name, storage) \ + struct name { \ + storage *ptr; \ + }; \ + typedef struct name name -/// Hold a list of lambdaArgument to represent execution arguments -struct executionArguments { - lambdaArgument *data; - size_t size; -}; -typedef struct executionArguments executionArguments; +DEFINE_C_API_STRUCT(CompilerEngine, void); +DEFINE_C_API_STRUCT(CompilationContext, void); +DEFINE_C_API_STRUCT(CompilationResult, void); -// JIT Support bindings /////////////////////////////////////////////////////// +#undef DEFINE_C_API_STRUCT -struct JITSupport_C { - mlir::concretelang::JITSupport support; -}; -typedef struct JITSupport_C JITSupport_C; +/// NULL Pointer checkers. Generate functions to check if the struct contains a +/// null pointer. +#define DEFINE_NULL_PTR_CHECKER(funcname, storage) \ + bool funcname(storage s) { return s.ptr == NULL; } -MLIR_CAPI_EXPORTED JITSupport_C jit_support(std::string runtimeLibPath); +DEFINE_NULL_PTR_CHECKER(compilerEngineIsNull, CompilerEngine); +DEFINE_NULL_PTR_CHECKER(compilationResultIsNull, CompilationResult); -MLIR_CAPI_EXPORTED std::unique_ptr -jit_compile(JITSupport_C support, const char *module, - mlir::concretelang::CompilationOptions options); +#undef DEFINE_NULL_PTR_CHECKER -MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters -jit_load_client_parameters(JITSupport_C support, - mlir::concretelang::JitCompilationResult &); +/// Each struct has a creator function that allocates memory for the underlying +/// Cpp object referenced, and a destroy function that does free this allocated +/// memory. -MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback -jit_load_compilation_feedback(JITSupport_C support, - mlir::concretelang::JitCompilationResult &); +/// ********** CompilationTarget CAPI ****************************************** -MLIR_CAPI_EXPORTED std::shared_ptr -jit_load_server_lambda(JITSupport_C support, - mlir::concretelang::JitCompilationResult &); +enum CompilationTarget { ROUND_TRIP, OTHER }; +typedef enum CompilationTarget CompilationTarget; -MLIR_CAPI_EXPORTED std::unique_ptr -jit_server_call(JITSupport_C support, mlir::concretelang::JITLambda &lambda, - concretelang::clientlib::PublicArguments &args, - concretelang::clientlib::EvaluationKeys &evaluationKeys); +/// ********** CompilerEngine CAPI ********************************************* -// Library Support bindings /////////////////////////////////////////////////// +MLIR_CAPI_EXPORTED CompilerEngine compilerEngineCreate(); -struct LibrarySupport_C { - mlir::concretelang::LibrarySupport support; -}; -typedef struct LibrarySupport_C LibrarySupport_C; +MLIR_CAPI_EXPORTED void compilerEngineDestroy(CompilerEngine engine); -MLIR_CAPI_EXPORTED LibrarySupport_C -library_support(const char *outputPath, const char *runtimeLibraryPath, - bool generateSharedLib, bool generateStaticLib, - bool generateClientParameters, bool generateCompilationFeedback, - bool generateCppHeader); +MLIR_CAPI_EXPORTED CompilationResult compilerEngineCompile( + CompilerEngine engine, MlirStringRef module, CompilationTarget target); -MLIR_CAPI_EXPORTED std::unique_ptr -library_compile(LibrarySupport_C support, const char *module, - mlir::concretelang::CompilationOptions options); +/// ********** CompilationResult CAPI ****************************************** -MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters -library_load_client_parameters(LibrarySupport_C support, - mlir::concretelang::LibraryCompilationResult &); +/// Get a string reference holding the textual representation of the compiled +/// module. The returned `MlirStringRef` should be destroyed using +/// `compilationResultDestroyModuleString` to free memory. +MLIR_CAPI_EXPORTED MlirStringRef +compilationResultGetModuleString(CompilationResult result); -MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback -library_load_compilation_feedback( - LibrarySupport_C support, mlir::concretelang::LibraryCompilationResult &); +/// Free memory allocated for the module string. +MLIR_CAPI_EXPORTED void compilationResultDestroyModuleString(MlirStringRef str); -MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda -library_load_server_lambda(LibrarySupport_C support, - mlir::concretelang::LibraryCompilationResult &); - -MLIR_CAPI_EXPORTED std::unique_ptr -library_server_call(LibrarySupport_C support, - concretelang::serverlib::ServerLambda lambda, - concretelang::clientlib::PublicArguments &args, - concretelang::clientlib::EvaluationKeys &evaluationKeys); - -MLIR_CAPI_EXPORTED std::string -library_get_shared_lib_path(LibrarySupport_C support); - -MLIR_CAPI_EXPORTED std::string -library_get_client_parameters_path(LibrarySupport_C support); - -// Client Support bindings /////////////////////////////////////////////////// - -MLIR_CAPI_EXPORTED std::unique_ptr -key_set(concretelang::clientlib::ClientParameters clientParameters, - llvm::Optional cache); - -MLIR_CAPI_EXPORTED std::unique_ptr -encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters, - concretelang::clientlib::KeySet &keySet, - llvm::ArrayRef args); - -MLIR_CAPI_EXPORTED lambdaArgument -decrypt_result(concretelang::clientlib::KeySet &keySet, - concretelang::clientlib::PublicResult &publicResult); - -// Serialization //////////////////////////////////////////////////////////// - -MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters -clientParametersUnserialize(const std::string &json); - -MLIR_CAPI_EXPORTED std::string -clientParametersSerialize(mlir::concretelang::ClientParameters ¶ms); - -MLIR_CAPI_EXPORTED std::unique_ptr -publicArgumentsUnserialize( - mlir::concretelang::ClientParameters &clientParameters, - const std::string &buffer); - -MLIR_CAPI_EXPORTED std::string publicArgumentsSerialize( - concretelang::clientlib::PublicArguments &publicArguments); - -MLIR_CAPI_EXPORTED std::unique_ptr -publicResultUnserialize(mlir::concretelang::ClientParameters &clientParameters, - const std::string &buffer); - -MLIR_CAPI_EXPORTED std::string -publicResultSerialize(concretelang::clientlib::PublicResult &publicResult); - -MLIR_CAPI_EXPORTED concretelang::clientlib::EvaluationKeys -evaluationKeysUnserialize(const std::string &buffer); - -MLIR_CAPI_EXPORTED std::string evaluationKeysSerialize( - concretelang::clientlib::EvaluationKeys &evaluationKeys); - -/// Parse then print a textual representation of an MLIR module -MLIR_CAPI_EXPORTED std::string roundTrip(const char *module); - -/// Terminate parallelization -MLIR_CAPI_EXPORTED void terminateParallelization(); - -/// Create a lambdaArgument from a tensor of different data types -MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU8( - std::vector data, std::vector dimensions); -MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU16( - std::vector data, std::vector dimensions); -MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU32( - std::vector data, std::vector dimensions); -MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU64( - std::vector data, std::vector dimensions); -/// Create a lambdaArgument from a scalar -MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromScalar(uint64_t scalar); -/// Check if a lambdaArgument holds a tensor -MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg); -/// Get tensor data from lambdaArgument -MLIR_CAPI_EXPORTED std::vector -lambdaArgumentGetTensorData(lambdaArgument &lambda_arg); -/// Get tensor dimensions from lambdaArgument -MLIR_CAPI_EXPORTED std::vector -lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg); -/// Check if a lambdaArgument holds a scalar -MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg); -/// Get scalar value from lambdaArgument -MLIR_CAPI_EXPORTED uint64_t lambdaArgumentGetScalar(lambdaArgument &lambda_arg); - -/// Compile the textual representation of MLIR modules to a library. -MLIR_CAPI_EXPORTED std::string library(std::string libraryPath, - std::vector modules); - -// #ifdef __cplusplus -// } -// #endif +#ifdef __cplusplus +} +#endif #endif // CONCRETELANG_C_SUPPORT_COMPILER_ENGINE_H diff --git a/compiler/include/concretelang/CAPI/Wrappers.h b/compiler/include/concretelang/CAPI/Wrappers.h new file mode 100644 index 000000000..541b6928f --- /dev/null +++ b/compiler/include/concretelang/CAPI/Wrappers.h @@ -0,0 +1,17 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_CAPI_WRAPPERS_H +#define CONCRETELANG_CAPI_WRAPPERS_H + +#include "concretelang-c/Support/CompilerEngine.h" +#include "concretelang/Support/CompilerEngine.h" +#include "mlir/CAPI/Wrap.h" + +DEFINE_C_API_PTR_METHODS(CompilerEngine, mlir::concretelang::CompilerEngine) +DEFINE_C_API_PTR_METHODS(CompilationResult, + mlir::concretelang::CompilerEngine::CompilationResult) + +#endif diff --git a/compiler/lib/CAPI/CMakeLists.txt b/compiler/lib/CAPI/CMakeLists.txt index dd5c25301..5825015fe 100644 --- a/compiler/lib/CAPI/CMakeLists.txt +++ b/compiler/lib/CAPI/CMakeLists.txt @@ -1,5 +1,2 @@ -# CAPI is mainly used by python and need to throw exceptions for proper handling of errors on the python-side -add_compile_options(-fexceptions) - add_subdirectory(Dialect) add_subdirectory(Support) diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index 253e7f874..a3f360372 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -3,424 +3,56 @@ // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. -#include "llvm/ADT/SmallString.h" - #include "concretelang-c/Support/CompilerEngine.h" -#include "concretelang/ClientLib/KeySetCache.h" -#include "concretelang/ClientLib/Serializers.h" -#include "concretelang/Runtime/DFRuntime.hpp" +#include "concretelang/CAPI/Wrappers.h" #include "concretelang/Support/CompilerEngine.h" -#include "concretelang/Support/JITSupport.h" -#include "concretelang/Support/Jit.h" +#include "mlir/IR/Diagnostics.h" -#define GET_OR_THROW_LLVM_EXPECTED(VARNAME, EXPECTED) \ - auto VARNAME = EXPECTED; \ - if (auto err = VARNAME.takeError()) { \ - throw std::runtime_error(llvm::toString(std::move(err))); \ - } +/// CompilerEngine CAPI -// JIT Support bindings /////////////////////////////////////////////////////// - -MLIR_CAPI_EXPORTED JITSupport_C jit_support(std::string runtimeLibPath) { - auto opt = runtimeLibPath.empty() - ? llvm::None - : llvm::Optional(runtimeLibPath); - return JITSupport_C{mlir::concretelang::JITSupport(opt)}; +CompilerEngine compilerEngineCreate() { + auto *engine = new mlir::concretelang::CompilerEngine( + mlir::concretelang::CompilationContext::createShared()); + return wrap(engine); } -std::unique_ptr -jit_compile(JITSupport_C support, const char *module, - mlir::concretelang::CompilationOptions options) { - GET_OR_THROW_LLVM_EXPECTED(compilationResult, - support.support.compile(module, options)); - return std::move(*compilationResult); -} +void compilerEngineDestroy(CompilerEngine engine) { delete unwrap(engine); } -MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters -jit_load_client_parameters(JITSupport_C support, - mlir::concretelang::JitCompilationResult &result) { - GET_OR_THROW_LLVM_EXPECTED(clientParameters, - support.support.loadClientParameters(result)); - return *clientParameters; -} - -MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback -jit_load_compilation_feedback( - JITSupport_C support, mlir::concretelang::JitCompilationResult &result) { - GET_OR_THROW_LLVM_EXPECTED(compilationFeedback, - support.support.loadCompilationFeedback(result)); - return *compilationFeedback; -} - -MLIR_CAPI_EXPORTED std::shared_ptr -jit_load_server_lambda(JITSupport_C support, - mlir::concretelang::JitCompilationResult &result) { - GET_OR_THROW_LLVM_EXPECTED(serverLambda, - support.support.loadServerLambda(result)); - return *serverLambda; -} - -MLIR_CAPI_EXPORTED std::unique_ptr -jit_server_call(JITSupport_C support, mlir::concretelang::JITLambda &lambda, - concretelang::clientlib::PublicArguments &args, - concretelang::clientlib::EvaluationKeys &evaluationKeys) { - GET_OR_THROW_LLVM_EXPECTED(publicResult, lambda.call(args, evaluationKeys)); - return std::move(*publicResult); -} - -// Library Support bindings /////////////////////////////////////////////////// -MLIR_CAPI_EXPORTED LibrarySupport_C -library_support(const char *outputPath, const char *runtimeLibraryPath, - bool generateSharedLib, bool generateStaticLib, - bool generateClientParameters, bool generateCompilationFeedback, - bool generateCppHeader) { - return LibrarySupport_C{mlir::concretelang::LibrarySupport( - outputPath, runtimeLibraryPath, generateSharedLib, generateStaticLib, - generateClientParameters, generateCompilationFeedback, - generateCppHeader)}; -} - -std::unique_ptr -library_compile(LibrarySupport_C support, const char *module, - mlir::concretelang::CompilationOptions options) { - GET_OR_THROW_LLVM_EXPECTED(compilationResult, - support.support.compile(module, options)); - return std::move(*compilationResult); -} - -MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters -library_load_client_parameters( - LibrarySupport_C support, - mlir::concretelang::LibraryCompilationResult &result) { - GET_OR_THROW_LLVM_EXPECTED(clientParameters, - support.support.loadClientParameters(result)); - return *clientParameters; -} - -MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback -library_load_compilation_feedback( - LibrarySupport_C support, - mlir::concretelang::LibraryCompilationResult &result) { - GET_OR_THROW_LLVM_EXPECTED(compilationFeedback, - support.support.loadCompilationFeedback(result)); - return *compilationFeedback; -} - -MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda -library_load_server_lambda( - LibrarySupport_C support, - mlir::concretelang::LibraryCompilationResult &result) { - GET_OR_THROW_LLVM_EXPECTED(serverLambda, - support.support.loadServerLambda(result)); - return *serverLambda; -} - -MLIR_CAPI_EXPORTED std::unique_ptr -library_server_call(LibrarySupport_C support, - concretelang::serverlib::ServerLambda lambda, - concretelang::clientlib::PublicArguments &args, - concretelang::clientlib::EvaluationKeys &evaluationKeys) { - GET_OR_THROW_LLVM_EXPECTED( - publicResult, support.support.serverCall(lambda, args, evaluationKeys)); - return std::move(*publicResult); -} - -MLIR_CAPI_EXPORTED std::string -library_get_shared_lib_path(LibrarySupport_C support) { - return support.support.getSharedLibPath(); -} - -MLIR_CAPI_EXPORTED std::string -library_get_client_parameters_path(LibrarySupport_C support) { - return support.support.getClientParametersPath(); -} - -// Client Support bindings /////////////////////////////////////////////////// - -MLIR_CAPI_EXPORTED std::unique_ptr -key_set(concretelang::clientlib::ClientParameters clientParameters, - llvm::Optional cache) { - GET_OR_THROW_LLVM_EXPECTED( - ks, (mlir::concretelang::LambdaSupport::keySet(clientParameters, - cache))); - return std::move(*ks); -} - -MLIR_CAPI_EXPORTED std::unique_ptr -encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters, - concretelang::clientlib::KeySet &keySet, - llvm::ArrayRef args) { - GET_OR_THROW_LLVM_EXPECTED( - publicArguments, - (mlir::concretelang::LambdaSupport::exportArguments( - clientParameters, keySet, args))); - return std::move(*publicArguments); -} - -MLIR_CAPI_EXPORTED lambdaArgument -decrypt_result(concretelang::clientlib::KeySet &keySet, - concretelang::clientlib::PublicResult &publicResult) { - GET_OR_THROW_LLVM_EXPECTED( - result, mlir::concretelang::typedResult< - std::unique_ptr>( - keySet, publicResult)); - lambdaArgument result_{std::move(*result)}; - return result_; -} - -MLIR_CAPI_EXPORTED std::unique_ptr -publicArgumentsUnserialize( - mlir::concretelang::ClientParameters &clientParameters, - const std::string &buffer) { - std::stringstream istream(buffer); - auto argsOrError = concretelang::clientlib::PublicArguments::unserialize( - clientParameters, istream); - if (!argsOrError) { - throw std::runtime_error(argsOrError.error().mesg); - } - return std::move(argsOrError.value()); -} - -MLIR_CAPI_EXPORTED std::string publicArgumentsSerialize( - concretelang::clientlib::PublicArguments &publicArguments) { - - std::ostringstream buffer(std::ios::binary); - auto voidOrError = publicArguments.serialize(buffer); - if (!voidOrError) { - throw std::runtime_error(voidOrError.error().mesg); - } - return buffer.str(); -} - -MLIR_CAPI_EXPORTED std::unique_ptr -publicResultUnserialize(mlir::concretelang::ClientParameters &clientParameters, - const std::string &buffer) { - std::stringstream istream(buffer); - auto publicResultOrError = concretelang::clientlib::PublicResult::unserialize( - clientParameters, istream); - if (!publicResultOrError) { - throw std::runtime_error(publicResultOrError.error().mesg); - } - return std::move(publicResultOrError.value()); -} - -MLIR_CAPI_EXPORTED std::string -publicResultSerialize(concretelang::clientlib::PublicResult &publicResult) { - std::ostringstream buffer(std::ios::binary); - auto voidOrError = publicResult.serialize(buffer); - if (!voidOrError) { - throw std::runtime_error(voidOrError.error().mesg); - } - return buffer.str(); -} - -MLIR_CAPI_EXPORTED concretelang::clientlib::EvaluationKeys -evaluationKeysUnserialize(const std::string &buffer) { - std::stringstream istream(buffer); - - concretelang::clientlib::EvaluationKeys evaluationKeys; - concretelang::clientlib::operator>>(istream, evaluationKeys); - - if (istream.fail()) { - throw std::runtime_error("Cannot read evaluation keys"); - } - - return evaluationKeys; -} - -MLIR_CAPI_EXPORTED std::string evaluationKeysSerialize( - concretelang::clientlib::EvaluationKeys &evaluationKeys) { - std::ostringstream buffer(std::ios::binary); - concretelang::clientlib::operator<<(buffer, evaluationKeys); - return buffer.str(); -} - -MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters -clientParametersUnserialize(const std::string &json) { - GET_OR_THROW_LLVM_EXPECTED( - clientParams, - llvm::json::parse(json)); - return clientParams.get(); -} - -MLIR_CAPI_EXPORTED std::string -clientParametersSerialize(mlir::concretelang::ClientParameters ¶ms) { - llvm::json::Value value(params); - std::string jsonParams; - llvm::raw_string_ostream buffer(jsonParams); - buffer << value; - return jsonParams; -} - -void terminateParallelization() { -#ifdef CONCRETELANG_DATAFLOW_EXECUTION_ENABLED - _dfr_terminate(); -#endif -} - -std::string roundTrip(const char *module) { - std::shared_ptr ccx = - mlir::concretelang::CompilationContext::createShared(); - mlir::concretelang::CompilerEngine ce{ccx}; - - std::string backingString; - llvm::raw_string_ostream os(backingString); - - llvm::Expected - retOrErr = ce.compile( - module, mlir::concretelang::CompilerEngine::Target::ROUND_TRIP); - if (!retOrErr) { - os << "MLIR parsing failed: " << llvm::toString(retOrErr.takeError()); - throw std::runtime_error(os.str()); - } - - retOrErr->mlirModuleRef->get().print(os); - return os.str(); -} - -bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg) { - return lambda_arg.ptr->isa>>() || - lambda_arg.ptr->isa>>() || - lambda_arg.ptr->isa>>() || - lambda_arg.ptr->isa>>(); -} - -template -std::vector copyTensorLambdaArgumentTo64bitsvector( - mlir::concretelang::TensorLambdaArgument< - mlir::concretelang::IntLambdaArgument> *tensor) { - auto numElements = tensor->getNumElements(); - if (!numElements) { - std::string backingString; - llvm::raw_string_ostream os(backingString); - os << "Couldn't get size of tensor: " - << llvm::toString(std::move(numElements.takeError())); - throw std::runtime_error(os.str()); - } - std::vector res; - res.reserve(*numElements); - T *data = tensor->getValue(); - for (size_t i = 0; i < *numElements; i++) { - res.push_back(data[i]); - } - return res; -} - -std::vector lambdaArgumentGetTensorData(lambdaArgument &lambda_arg) { - if (auto arg = - lambda_arg.ptr->dyn_cast>>()) { - llvm::Expected sizeOrErr = arg->getNumElements(); - if (!sizeOrErr) { - std::string backingString; - llvm::raw_string_ostream os(backingString); - os << "Couldn't get size of tensor: " - << llvm::toString(sizeOrErr.takeError()); - throw std::runtime_error(os.str()); +CompilationResult compilerEngineCompile(CompilerEngine engine, + MlirStringRef module, + CompilationTarget target) { + std::string module_str(module.data, module.length); + if (target == ROUND_TRIP) { + auto retOrError = unwrap(engine)->compile( + module_str, mlir::concretelang::CompilerEngine::Target::ROUND_TRIP); + if (!retOrError) { + // TODO: access the MlirContext + // mlir::emitError(mlir::UnknownLoc::get(unwrap(engine)) << "azeza"; + return wrap( + (mlir::concretelang::CompilerEngine::CompilationResult *)nullptr); } - std::vector data(arg->getValue(), arg->getValue() + *sizeOrErr); - return data; + return wrap(new mlir::concretelang::CompilerEngine::CompilationResult( + std::move(retOrError.get()))); } - if (auto arg = - lambda_arg.ptr->dyn_cast>>()) { - return copyTensorLambdaArgumentTo64bitsvector(arg); - } - if (auto arg = - lambda_arg.ptr->dyn_cast>>()) { - return copyTensorLambdaArgumentTo64bitsvector(arg); - } - if (auto arg = - lambda_arg.ptr->dyn_cast>>()) { - return copyTensorLambdaArgumentTo64bitsvector(arg); - } - throw std::invalid_argument( - "LambdaArgument isn't a tensor or has an unsupported bitwidth"); + return wrap((mlir::concretelang::CompilerEngine::CompilationResult *)nullptr); } -std::vector -lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg) { - if (auto arg = - lambda_arg.ptr->dyn_cast>>()) { - return arg->getDimensions(); - } - if (auto arg = - lambda_arg.ptr->dyn_cast>>()) { - return arg->getDimensions(); - } - if (auto arg = - lambda_arg.ptr->dyn_cast>>()) { - return arg->getDimensions(); - } - if (auto arg = - lambda_arg.ptr->dyn_cast>>()) { - return arg->getDimensions(); - } - throw std::invalid_argument( - "LambdaArgument isn't a tensor, should " - "be a TensorLambdaArgument>"); +/// CompilationResult CAPI +void compilationResultDestroy(CompilationResult result) { + delete unwrap(result); } -bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg) { - return lambda_arg.ptr->isa>(); +MlirStringRef compilationResultGetModuleString(CompilationResult result) { + // print the module into a string + std::string moduleString; + llvm::raw_string_ostream os(moduleString); + unwrap(result)->mlirModuleRef->get().print(os); + // allocate buffer and copy module string + char *buffer = new char[moduleString.length() + 1]; + strcpy(buffer, moduleString.c_str()); + return mlirStringRefCreate(buffer, moduleString.length()); } -uint64_t lambdaArgumentGetScalar(lambdaArgument &lambda_arg) { - mlir::concretelang::IntLambdaArgument *arg = - lambda_arg.ptr - ->dyn_cast>(); - if (arg == nullptr) { - throw std::invalid_argument("LambdaArgument isn't a scalar, should " - "be an IntLambdaArgument"); - } - return arg->getValue(); -} - -lambdaArgument lambdaArgumentFromTensorU8(std::vector data, - std::vector dimensions) { - lambdaArgument tensor_arg{ - std::make_shared>>(data, dimensions)}; - return tensor_arg; -} - -lambdaArgument lambdaArgumentFromTensorU16(std::vector data, - std::vector dimensions) { - lambdaArgument tensor_arg{ - std::make_shared>>(data, dimensions)}; - return tensor_arg; -} - -lambdaArgument lambdaArgumentFromTensorU32(std::vector data, - std::vector dimensions) { - lambdaArgument tensor_arg{ - std::make_shared>>(data, dimensions)}; - return tensor_arg; -} - -lambdaArgument lambdaArgumentFromTensorU64(std::vector data, - std::vector dimensions) { - lambdaArgument tensor_arg{ - std::make_shared>>(data, dimensions)}; - return tensor_arg; -} - -lambdaArgument lambdaArgumentFromScalar(uint64_t scalar) { - lambdaArgument scalar_arg{ - std::make_shared>( - scalar)}; - return scalar_arg; +void compilationResultDestroyModuleString(MlirStringRef str) { + delete str.data; }