// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include "llvm/ADT/SmallString.h" #include "concretelang/Bindings/Python/CompilerEngine.h" #include "concretelang/ClientLib/KeySetCache.h" #include "concretelang/ClientLib/Serializers.h" #include "concretelang/Runtime/DFRuntime.hpp" #include "concretelang/Support/CompilerEngine.h" #include "concretelang/Support/JITSupport.h" #include "concretelang/Support/Jit.h" #define GET_OR_THROW_LLVM_EXPECTED(VARNAME, EXPECTED) \ auto VARNAME = EXPECTED; \ if (auto err = VARNAME.takeError()) { \ throw std::runtime_error(llvm::toString(std::move(err))); \ } // JIT Support bindings /////////////////////////////////////////////////////// MLIR_CAPI_EXPORTED JITSupport_Py jit_support(std::string runtimeLibPath) { auto opt = runtimeLibPath.empty() ? std::nullopt : std::optional(runtimeLibPath); return JITSupport_Py{mlir::concretelang::JITSupport(opt)}; } MLIR_CAPI_EXPORTED std::unique_ptr jit_compile(JITSupport_Py support, const char *module, mlir::concretelang::CompilationOptions options) { GET_OR_THROW_LLVM_EXPECTED(compilationResult, support.support.compile(module, options)); return std::move(*compilationResult); } MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters jit_load_client_parameters(JITSupport_Py support, mlir::concretelang::JitCompilationResult &result) { GET_OR_THROW_LLVM_EXPECTED(clientParameters, support.support.loadClientParameters(result)); return *clientParameters; } MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback jit_load_compilation_feedback( JITSupport_Py support, mlir::concretelang::JitCompilationResult &result) { GET_OR_THROW_LLVM_EXPECTED(compilationFeedback, support.support.loadCompilationFeedback(result)); return *compilationFeedback; } MLIR_CAPI_EXPORTED std::shared_ptr jit_load_server_lambda(JITSupport_Py support, mlir::concretelang::JitCompilationResult &result) { GET_OR_THROW_LLVM_EXPECTED(serverLambda, support.support.loadServerLambda(result)); return *serverLambda; } MLIR_CAPI_EXPORTED std::unique_ptr jit_server_call(JITSupport_Py support, mlir::concretelang::JITLambda &lambda, concretelang::clientlib::PublicArguments &args, concretelang::clientlib::EvaluationKeys &evaluationKeys) { GET_OR_THROW_LLVM_EXPECTED(publicResult, lambda.call(args, evaluationKeys)); return std::move(*publicResult); } // Library Support bindings /////////////////////////////////////////////////// MLIR_CAPI_EXPORTED LibrarySupport_Py library_support(const char *outputPath, const char *runtimeLibraryPath, bool generateSharedLib, bool generateStaticLib, bool generateClientParameters, bool generateCompilationFeedback, bool generateCppHeader) { return LibrarySupport_Py{mlir::concretelang::LibrarySupport( outputPath, runtimeLibraryPath, generateSharedLib, generateStaticLib, generateClientParameters, generateCompilationFeedback, generateCppHeader)}; } MLIR_CAPI_EXPORTED std::unique_ptr library_compile(LibrarySupport_Py support, const char *module, mlir::concretelang::CompilationOptions options) { GET_OR_THROW_LLVM_EXPECTED(compilationResult, support.support.compile(module, options)); return std::move(*compilationResult); } MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters library_load_client_parameters( LibrarySupport_Py support, mlir::concretelang::LibraryCompilationResult &result) { GET_OR_THROW_LLVM_EXPECTED(clientParameters, support.support.loadClientParameters(result)); return *clientParameters; } MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback library_load_compilation_feedback( LibrarySupport_Py support, mlir::concretelang::LibraryCompilationResult &result) { GET_OR_THROW_LLVM_EXPECTED(compilationFeedback, support.support.loadCompilationFeedback(result)); return *compilationFeedback; } MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda library_load_server_lambda( LibrarySupport_Py support, mlir::concretelang::LibraryCompilationResult &result) { GET_OR_THROW_LLVM_EXPECTED(serverLambda, support.support.loadServerLambda(result)); return *serverLambda; } MLIR_CAPI_EXPORTED std::unique_ptr library_server_call(LibrarySupport_Py support, concretelang::serverlib::ServerLambda lambda, concretelang::clientlib::PublicArguments &args, concretelang::clientlib::EvaluationKeys &evaluationKeys) { GET_OR_THROW_LLVM_EXPECTED( publicResult, support.support.serverCall(lambda, args, evaluationKeys)); return std::move(*publicResult); } MLIR_CAPI_EXPORTED std::string library_get_shared_lib_path(LibrarySupport_Py support) { return support.support.getSharedLibPath(); } MLIR_CAPI_EXPORTED std::string library_get_client_parameters_path(LibrarySupport_Py support) { return support.support.getClientParametersPath(); } // Client Support bindings /////////////////////////////////////////////////// MLIR_CAPI_EXPORTED std::unique_ptr key_set(concretelang::clientlib::ClientParameters clientParameters, std::optional cache) { GET_OR_THROW_LLVM_EXPECTED( ks, (mlir::concretelang::LambdaSupport::keySet(clientParameters, cache))); return std::move(*ks); } MLIR_CAPI_EXPORTED std::unique_ptr encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters, concretelang::clientlib::KeySet &keySet, llvm::ArrayRef args) { GET_OR_THROW_LLVM_EXPECTED( publicArguments, (mlir::concretelang::LambdaSupport::exportArguments( clientParameters, keySet, args))); return std::move(*publicArguments); } MLIR_CAPI_EXPORTED lambdaArgument decrypt_result(concretelang::clientlib::KeySet &keySet, concretelang::clientlib::PublicResult &publicResult) { GET_OR_THROW_LLVM_EXPECTED( result, mlir::concretelang::typedResult< std::unique_ptr>( keySet, publicResult)); lambdaArgument result_{std::move(*result)}; return result_; } MLIR_CAPI_EXPORTED std::unique_ptr publicArgumentsUnserialize( mlir::concretelang::ClientParameters &clientParameters, const std::string &buffer) { std::stringstream istream(buffer); auto argsOrError = concretelang::clientlib::PublicArguments::unserialize( clientParameters, istream); if (!argsOrError) { throw std::runtime_error(argsOrError.error().mesg); } return std::move(argsOrError.value()); } MLIR_CAPI_EXPORTED std::string publicArgumentsSerialize( concretelang::clientlib::PublicArguments &publicArguments) { std::ostringstream buffer(std::ios::binary); auto voidOrError = publicArguments.serialize(buffer); if (!voidOrError) { throw std::runtime_error(voidOrError.error().mesg); } return buffer.str(); } MLIR_CAPI_EXPORTED std::unique_ptr publicResultUnserialize(mlir::concretelang::ClientParameters &clientParameters, const std::string &buffer) { std::stringstream istream(buffer); auto publicResultOrError = concretelang::clientlib::PublicResult::unserialize( clientParameters, istream); if (!publicResultOrError) { throw std::runtime_error(publicResultOrError.error().mesg); } return std::move(publicResultOrError.value()); } MLIR_CAPI_EXPORTED std::string publicResultSerialize(concretelang::clientlib::PublicResult &publicResult) { std::ostringstream buffer(std::ios::binary); auto voidOrError = publicResult.serialize(buffer); if (!voidOrError) { throw std::runtime_error(voidOrError.error().mesg); } return buffer.str(); } MLIR_CAPI_EXPORTED concretelang::clientlib::EvaluationKeys evaluationKeysUnserialize(const std::string &buffer) { std::stringstream istream(buffer); concretelang::clientlib::EvaluationKeys evaluationKeys = concretelang::clientlib::readEvaluationKeys(istream); if (istream.fail()) { throw std::runtime_error("Cannot read evaluation keys"); } return evaluationKeys; } MLIR_CAPI_EXPORTED std::string evaluationKeysSerialize( concretelang::clientlib::EvaluationKeys &evaluationKeys) { std::ostringstream buffer(std::ios::binary); concretelang::clientlib::operator<<(buffer, evaluationKeys); return buffer.str(); } MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters clientParametersUnserialize(const std::string &json) { GET_OR_THROW_LLVM_EXPECTED( clientParams, llvm::json::parse(json)); return clientParams.get(); } MLIR_CAPI_EXPORTED std::string clientParametersSerialize(mlir::concretelang::ClientParameters ¶ms) { llvm::json::Value value(params); std::string jsonParams; llvm::raw_string_ostream buffer(jsonParams); buffer << value; return jsonParams; } MLIR_CAPI_EXPORTED void terminateDataflowParallelization() { _dfr_terminate(); } MLIR_CAPI_EXPORTED void initDataflowParallelization() { mlir::concretelang::dfr::_dfr_set_required(true); } MLIR_CAPI_EXPORTED std::string roundTrip(const char *module) { std::shared_ptr ccx = mlir::concretelang::CompilationContext::createShared(); mlir::concretelang::CompilerEngine ce{ccx}; std::string backingString; llvm::raw_string_ostream os(backingString); llvm::Expected retOrErr = ce.compile( module, mlir::concretelang::CompilerEngine::Target::ROUND_TRIP); if (!retOrErr) { os << "MLIR parsing failed: " << llvm::toString(retOrErr.takeError()); throw std::runtime_error(os.str()); } retOrErr->mlirModuleRef->get().print(os); return os.str(); } MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg) { return lambda_arg.ptr->isa>>() || lambda_arg.ptr->isa>>() || lambda_arg.ptr->isa>>() || lambda_arg.ptr->isa>>(); } template MLIR_CAPI_EXPORTED std::vector copyTensorLambdaArgumentTo64bitsvector( mlir::concretelang::TensorLambdaArgument< mlir::concretelang::IntLambdaArgument> *tensor) { auto numElements = tensor->getNumElements(); if (!numElements) { std::string backingString; llvm::raw_string_ostream os(backingString); os << "Couldn't get size of tensor: " << llvm::toString(std::move(numElements.takeError())); throw std::runtime_error(os.str()); } std::vector res; res.reserve(*numElements); T *data = tensor->getValue(); for (size_t i = 0; i < *numElements; i++) { res.push_back(data[i]); } return res; } MLIR_CAPI_EXPORTED std::vector lambdaArgumentGetTensorData(lambdaArgument &lambda_arg) { if (auto arg = lambda_arg.ptr->dyn_cast>>()) { llvm::Expected sizeOrErr = arg->getNumElements(); if (!sizeOrErr) { std::string backingString; llvm::raw_string_ostream os(backingString); os << "Couldn't get size of tensor: " << llvm::toString(sizeOrErr.takeError()); throw std::runtime_error(os.str()); } std::vector data(arg->getValue(), arg->getValue() + *sizeOrErr); return data; } if (auto arg = lambda_arg.ptr->dyn_cast>>()) { return copyTensorLambdaArgumentTo64bitsvector(arg); } if (auto arg = lambda_arg.ptr->dyn_cast>>()) { return copyTensorLambdaArgumentTo64bitsvector(arg); } if (auto arg = lambda_arg.ptr->dyn_cast>>()) { return copyTensorLambdaArgumentTo64bitsvector(arg); } throw std::invalid_argument( "LambdaArgument isn't a tensor or has an unsupported bitwidth"); } MLIR_CAPI_EXPORTED std::vector lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg) { if (auto arg = lambda_arg.ptr->dyn_cast>>()) { return arg->getDimensions(); } if (auto arg = lambda_arg.ptr->dyn_cast>>()) { return arg->getDimensions(); } if (auto arg = lambda_arg.ptr->dyn_cast>>()) { return arg->getDimensions(); } if (auto arg = lambda_arg.ptr->dyn_cast>>()) { return arg->getDimensions(); } throw std::invalid_argument( "LambdaArgument isn't a tensor, should " "be a TensorLambdaArgument>"); } MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg) { return lambda_arg.ptr->isa>(); } MLIR_CAPI_EXPORTED uint64_t lambdaArgumentGetScalar(lambdaArgument &lambda_arg) { mlir::concretelang::IntLambdaArgument *arg = lambda_arg.ptr ->dyn_cast>(); if (arg == nullptr) { throw std::invalid_argument("LambdaArgument isn't a scalar, should " "be an IntLambdaArgument"); } return arg->getValue(); } MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU8( std::vector data, std::vector dimensions) { lambdaArgument tensor_arg{ std::make_shared>>(data, dimensions)}; return tensor_arg; } MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU16( std::vector data, std::vector dimensions) { lambdaArgument tensor_arg{ std::make_shared>>(data, dimensions)}; return tensor_arg; } MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU32( std::vector data, std::vector dimensions) { lambdaArgument tensor_arg{ std::make_shared>>(data, dimensions)}; return tensor_arg; } MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU64( std::vector data, std::vector dimensions) { lambdaArgument tensor_arg{ std::make_shared>>(data, dimensions)}; return tensor_arg; } MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromScalar(uint64_t scalar) { lambdaArgument scalar_arg{ std::make_shared>( scalar)}; return scalar_arg; }