cleanup(capi/python-bindings): Remove references to JitCompilerEngine

This commit is contained in:
Quentin Bourgerie
2022-03-17 15:47:07 +01:00
parent fc996fc698
commit 1620259807
3 changed files with 2 additions and 129 deletions

View File

@@ -8,7 +8,6 @@
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Jit.h"
#include "concretelang/Support/JitCompilerEngine.h"
#include "concretelang/Support/JitLambdaSupport.h"
#include "concretelang/Support/LibraryLambdaSupport.h"
#include "mlir-c/IR.h"
@@ -18,12 +17,6 @@
extern "C" {
#endif
// C wrapper of the mlir::concretelang::JitCompilerEngine::Lambda
struct lambda {
mlir::concretelang::JitCompilerEngine::Lambda *ptr;
};
typedef struct lambda lambda;
// C wrapper of the mlir::concretelang::LambdaArgument
struct lambdaArgument {
std::shared_ptr<mlir::concretelang::LambdaArgument> ptr;
@@ -106,23 +99,9 @@ MLIR_CAPI_EXPORTED lambdaArgument
decrypt_result(concretelang::clientlib::KeySet &keySet,
concretelang::clientlib::PublicResult &publicResult);
// Build lambda from a textual representation of an MLIR module
// The lambda will have `funcName` as entrypoint, and use runtimeLibPath (if
// not null) as a shared library during compilation, and a path to activate
// the use of a cache for encryption keys for testing purposes (insecure).
MLIR_CAPI_EXPORTED mlir::concretelang::JitCompilerEngine::Lambda
buildLambda(const char *module, const char *funcName,
const char *runtimeLibPath, const char *keySetCachePath,
bool autoParallelize, bool loopParallelize, bool dfParallelize);
// Parse then print a textual representation of an MLIR module
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
// Execute the lambda with executionArguments and get the result as
// lambdaArgument
MLIR_CAPI_EXPORTED lambdaArgument invokeLambda(lambda l,
executionArguments args);
// Terminate parallelization
MLIR_CAPI_EXPORTED void terminateParallelization();

View File

@@ -7,7 +7,6 @@
#include "concretelang-c/Support/CompilerEngine.h"
#include "concretelang/Dialect/FHE/IR/FHEOpsDialect.h.inc"
#include "concretelang/Support/Jit.h"
#include "concretelang/Support/JitCompilerEngine.h"
#include "concretelang/Support/JitLambdaSupport.h"
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
@@ -21,14 +20,9 @@
#include <string>
using mlir::concretelang::CompilationOptions;
using mlir::concretelang::JitCompilerEngine;
using mlir::concretelang::JitLambdaSupport;
using mlir::concretelang::LambdaArgument;
const char *noEmptyStringPtr(std::string &s) {
return (s.empty()) ? nullptr : s.c_str();
}
/// Populate the compiler API python module.
void mlir::concretelang::python::populateCompilerAPISubmodule(
pybind11::module &m) {
@@ -37,11 +31,6 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
m.def("round_trip",
[](std::string mlir_input) { return roundTrip(mlir_input.c_str()); });
m.def("library",
[](std::string library_path, std::vector<std::string> mlir_modules) {
return library(library_path, mlir_modules);
});
m.def("terminate_parallelization", &terminateParallelization);
pybind11::class_<CompilationOptions>(m, "CompilationOptions")
@@ -157,7 +146,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
clientlib::PublicResult &publicResult) {
return decrypt_result(keySet, publicResult);
});
pybind11::class_<KeySetCache>(m, "KeySetCache")
pybind11::class_<clientlib::KeySetCache>(m, "KeySetCache")
.def(pybind11::init<std::string &>());
pybind11::class_<mlir::concretelang::ClientParameters>(m, "ClientParameters");
@@ -203,13 +192,4 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.def("get_scalar", [](lambdaArgument &lambda_arg) {
return lambdaArgumentGetScalar(lambda_arg);
});
pybind11::class_<JitCompilerEngine::Lambda>(m, "Lambda")
.def("invoke", [](JitCompilerEngine::Lambda &py_lambda,
std::vector<lambdaArgument> args) {
// wrap and call CAPI
lambda c_lambda{&py_lambda};
executionArguments a{args.data(), args.size()};
return invokeLambda(c_lambda, a);
});
}

View File

@@ -10,11 +10,8 @@
#include "concretelang/Runtime/runtime_api.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Jit.h"
#include "concretelang/Support/JitCompilerEngine.h"
#include "concretelang/Support/JitLambdaSupport.h"
using mlir::concretelang::JitCompilerEngine;
#define GET_OR_THROW_LLVM_EXPECTED(VARNAME, EXPECTED) \
auto VARNAME = EXPECTED; \
if (auto err = VARNAME.takeError()) { \
@@ -139,40 +136,6 @@ decrypt_result(concretelang::clientlib::KeySet &keySet,
return std::move(result_);
}
mlir::concretelang::JitCompilerEngine::Lambda
buildLambda(const char *module, const char *funcName,
const char *runtimeLibPath, const char *keySetCachePath,
bool autoParallelize, bool loopParallelize, bool dfParallelize) {
// Set the runtime library path if not nullptr
llvm::Optional<llvm::StringRef> runtimeLibPathOptional = {};
if (runtimeLibPath != nullptr)
runtimeLibPathOptional = runtimeLibPath;
mlir::concretelang::JitCompilerEngine engine;
// Set parallelization flags
engine.setAutoParallelize(autoParallelize);
engine.setLoopParallelize(loopParallelize);
engine.setDataflowParallelize(dfParallelize);
using KeySetCache = mlir::concretelang::KeySetCache;
using optKeySetCache = llvm::Optional<mlir::concretelang::KeySetCache>;
auto cacheOpt = optKeySetCache();
if (keySetCachePath != nullptr) {
cacheOpt = KeySetCache(std::string(keySetCachePath));
}
llvm::Expected<mlir::concretelang::JitCompilerEngine::Lambda> lambdaOrErr =
engine.buildLambda(module, funcName, cacheOpt, runtimeLibPathOptional);
if (!lambdaOrErr) {
std::string backingString;
llvm::raw_string_ostream os(backingString);
os << "Compilation failed: "
<< llvm::toString(std::move(lambdaOrErr.takeError()));
throw std::runtime_error(os.str());
}
return std::move(*lambdaOrErr);
}
void terminateParallelization() {
#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
_dfr_terminate();
@@ -182,42 +145,10 @@ void terminateParallelization() {
#endif
}
lambdaArgument invokeLambda(lambda l, executionArguments args) {
mlir::concretelang::JitCompilerEngine::Lambda *lambda_ptr =
(mlir::concretelang::JitCompilerEngine::Lambda *)l.ptr;
if (args.size != lambda_ptr->getNumArguments()) {
throw std::invalid_argument("wrong number of arguments");
}
// Set the integer/tensor arguments
std::vector<mlir::concretelang::LambdaArgument *> lambdaArgumentsRef;
for (auto i = 0u; i < args.size; i++) {
lambdaArgumentsRef.push_back(args.data[i].ptr.get());
}
// Run lambda
llvm::Expected<std::unique_ptr<mlir::concretelang::LambdaArgument>>
resOrError =
(*lambda_ptr)
.
operator()<std::unique_ptr<mlir::concretelang::LambdaArgument>>(
llvm::ArrayRef<mlir::concretelang::LambdaArgument *>(
lambdaArgumentsRef));
if (!resOrError) {
std::string backingString;
llvm::raw_string_ostream os(backingString);
os << "Lambda invocation failed: "
<< llvm::toString(std::move(resOrError.takeError()));
throw std::runtime_error(os.str());
}
lambdaArgument result{std::move(*resOrError)};
return std::move(result);
}
std::string roundTrip(const char *module) {
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
mlir::concretelang::CompilationContext::createShared();
mlir::concretelang::JitCompilerEngine ce{ccx};
mlir::concretelang::CompilerEngine ce{ccx};
std::string backingString;
llvm::raw_string_ostream os(backingString);
@@ -330,20 +261,3 @@ lambdaArgument lambdaArgumentFromScalar(uint64_t scalar) {
scalar)};
return scalar_arg;
}
template <class T>
std::runtime_error library_error(std::string prefix, llvm::Expected<T> &error) {
return std::runtime_error(prefix + llvm::toString(error.takeError()));
}
std::string library(std::string libraryPath,
std::vector<std::string> mlir_modules) {
using namespace mlir::concretelang;
JitCompilerEngine ce{CompilationContext::createShared()};
auto lib = ce.compile(mlir_modules, libraryPath);
if (!lib) {
throw std::runtime_error("Can't link: " + llvm::toString(lib.takeError()));
}
return lib->sharedLibraryPath;
}