mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
cleanup(capi/python-bindings): Remove reference to JitCompilerEngine
This commit is contained in:
@@ -8,7 +8,6 @@
|
||||
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Jit.h"
|
||||
#include "concretelang/Support/JitCompilerEngine.h"
|
||||
#include "concretelang/Support/JitLambdaSupport.h"
|
||||
#include "concretelang/Support/LibraryLambdaSupport.h"
|
||||
#include "mlir-c/IR.h"
|
||||
@@ -18,12 +17,6 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// C wrapper of the mlir::concretelang::JitCompilerEngine::Lambda
|
||||
struct lambda {
|
||||
mlir::concretelang::JitCompilerEngine::Lambda *ptr;
|
||||
};
|
||||
typedef struct lambda lambda;
|
||||
|
||||
// C wrapper of the mlir::concretelang::LambdaArgument
|
||||
struct lambdaArgument {
|
||||
std::shared_ptr<mlir::concretelang::LambdaArgument> ptr;
|
||||
@@ -106,23 +99,9 @@ MLIR_CAPI_EXPORTED lambdaArgument
|
||||
decrypt_result(concretelang::clientlib::KeySet &keySet,
|
||||
concretelang::clientlib::PublicResult &publicResult);
|
||||
|
||||
// Build lambda from a textual representation of an MLIR module
|
||||
// The lambda will have `funcName` as entrypoint, and use runtimeLibPath (if
|
||||
// not null) as a shared library during compilation, a path to activate the
|
||||
// use a cache for encryption keys for test purpose (unsecure).
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::JitCompilerEngine::Lambda
|
||||
buildLambda(const char *module, const char *funcName,
|
||||
const char *runtimeLibPath, const char *keySetCachePath,
|
||||
bool autoParallelize, bool loopParallelize, bool dfParallelize);
|
||||
|
||||
// Parse then print a textual representation of an MLIR module
|
||||
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
|
||||
|
||||
// Execute the lambda with executionArguments and get the result as
|
||||
// lambdaArgument
|
||||
MLIR_CAPI_EXPORTED lambdaArgument invokeLambda(lambda l,
|
||||
executionArguments args);
|
||||
|
||||
// Terminate parallelization
|
||||
MLIR_CAPI_EXPORTED void terminateParallelization();
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
#include "concretelang-c/Support/CompilerEngine.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOpsDialect.h.inc"
|
||||
#include "concretelang/Support/Jit.h"
|
||||
#include "concretelang/Support/JitCompilerEngine.h"
|
||||
#include "concretelang/Support/JitLambdaSupport.h"
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/Dialect/StandardOps/IR/Ops.h>
|
||||
@@ -21,14 +20,9 @@
|
||||
#include <string>
|
||||
|
||||
using mlir::concretelang::CompilationOptions;
|
||||
using mlir::concretelang::JitCompilerEngine;
|
||||
using mlir::concretelang::JitLambdaSupport;
|
||||
using mlir::concretelang::LambdaArgument;
|
||||
|
||||
const char *noEmptyStringPtr(std::string &s) {
|
||||
return (s.empty()) ? nullptr : s.c_str();
|
||||
}
|
||||
|
||||
/// Populate the compiler API python module.
|
||||
void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
pybind11::module &m) {
|
||||
@@ -37,11 +31,6 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
m.def("round_trip",
|
||||
[](std::string mlir_input) { return roundTrip(mlir_input.c_str()); });
|
||||
|
||||
m.def("library",
|
||||
[](std::string library_path, std::vector<std::string> mlir_modules) {
|
||||
return library(library_path, mlir_modules);
|
||||
});
|
||||
|
||||
m.def("terminate_parallelization", &terminateParallelization);
|
||||
|
||||
pybind11::class_<CompilationOptions>(m, "CompilationOptions")
|
||||
@@ -157,7 +146,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
clientlib::PublicResult &publicResult) {
|
||||
return decrypt_result(keySet, publicResult);
|
||||
});
|
||||
pybind11::class_<KeySetCache>(m, "KeySetCache")
|
||||
pybind11::class_<clientlib::KeySetCache>(m, "KeySetCache")
|
||||
.def(pybind11::init<std::string &>());
|
||||
|
||||
pybind11::class_<mlir::concretelang::ClientParameters>(m, "ClientParameters");
|
||||
@@ -203,13 +192,4 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
.def("get_scalar", [](lambdaArgument &lambda_arg) {
|
||||
return lambdaArgumentGetScalar(lambda_arg);
|
||||
});
|
||||
|
||||
pybind11::class_<JitCompilerEngine::Lambda>(m, "Lambda")
|
||||
.def("invoke", [](JitCompilerEngine::Lambda &py_lambda,
|
||||
std::vector<lambdaArgument> args) {
|
||||
// wrap and call CAPI
|
||||
lambda c_lambda{&py_lambda};
|
||||
executionArguments a{args.data(), args.size()};
|
||||
return invokeLambda(c_lambda, a);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -10,11 +10,8 @@
|
||||
#include "concretelang/Runtime/runtime_api.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Jit.h"
|
||||
#include "concretelang/Support/JitCompilerEngine.h"
|
||||
#include "concretelang/Support/JitLambdaSupport.h"
|
||||
|
||||
using mlir::concretelang::JitCompilerEngine;
|
||||
|
||||
#define GET_OR_THROW_LLVM_EXPECTED(VARNAME, EXPECTED) \
|
||||
auto VARNAME = EXPECTED; \
|
||||
if (auto err = VARNAME.takeError()) { \
|
||||
@@ -139,40 +136,6 @@ decrypt_result(concretelang::clientlib::KeySet &keySet,
|
||||
return std::move(result_);
|
||||
}
|
||||
|
||||
mlir::concretelang::JitCompilerEngine::Lambda
|
||||
buildLambda(const char *module, const char *funcName,
|
||||
const char *runtimeLibPath, const char *keySetCachePath,
|
||||
bool autoParallelize, bool loopParallelize, bool dfParallelize) {
|
||||
// Set the runtime library path if not nullptr
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPathOptional = {};
|
||||
if (runtimeLibPath != nullptr)
|
||||
runtimeLibPathOptional = runtimeLibPath;
|
||||
mlir::concretelang::JitCompilerEngine engine;
|
||||
|
||||
// Set parallelization flags
|
||||
engine.setAutoParallelize(autoParallelize);
|
||||
engine.setLoopParallelize(loopParallelize);
|
||||
engine.setDataflowParallelize(dfParallelize);
|
||||
|
||||
using KeySetCache = mlir::concretelang::KeySetCache;
|
||||
using optKeySetCache = llvm::Optional<mlir::concretelang::KeySetCache>;
|
||||
auto cacheOpt = optKeySetCache();
|
||||
if (keySetCachePath != nullptr) {
|
||||
cacheOpt = KeySetCache(std::string(keySetCachePath));
|
||||
}
|
||||
|
||||
llvm::Expected<mlir::concretelang::JitCompilerEngine::Lambda> lambdaOrErr =
|
||||
engine.buildLambda(module, funcName, cacheOpt, runtimeLibPathOptional);
|
||||
if (!lambdaOrErr) {
|
||||
std::string backingString;
|
||||
llvm::raw_string_ostream os(backingString);
|
||||
os << "Compilation failed: "
|
||||
<< llvm::toString(std::move(lambdaOrErr.takeError()));
|
||||
throw std::runtime_error(os.str());
|
||||
}
|
||||
return std::move(*lambdaOrErr);
|
||||
}
|
||||
|
||||
void terminateParallelization() {
|
||||
#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
|
||||
_dfr_terminate();
|
||||
@@ -182,42 +145,10 @@ void terminateParallelization() {
|
||||
#endif
|
||||
}
|
||||
|
||||
lambdaArgument invokeLambda(lambda l, executionArguments args) {
|
||||
mlir::concretelang::JitCompilerEngine::Lambda *lambda_ptr =
|
||||
(mlir::concretelang::JitCompilerEngine::Lambda *)l.ptr;
|
||||
|
||||
if (args.size != lambda_ptr->getNumArguments()) {
|
||||
throw std::invalid_argument("wrong number of arguments");
|
||||
}
|
||||
// Set the integer/tensor arguments
|
||||
std::vector<mlir::concretelang::LambdaArgument *> lambdaArgumentsRef;
|
||||
for (auto i = 0u; i < args.size; i++) {
|
||||
lambdaArgumentsRef.push_back(args.data[i].ptr.get());
|
||||
}
|
||||
// Run lambda
|
||||
llvm::Expected<std::unique_ptr<mlir::concretelang::LambdaArgument>>
|
||||
resOrError =
|
||||
(*lambda_ptr)
|
||||
.
|
||||
operator()<std::unique_ptr<mlir::concretelang::LambdaArgument>>(
|
||||
llvm::ArrayRef<mlir::concretelang::LambdaArgument *>(
|
||||
lambdaArgumentsRef));
|
||||
|
||||
if (!resOrError) {
|
||||
std::string backingString;
|
||||
llvm::raw_string_ostream os(backingString);
|
||||
os << "Lambda invocation failed: "
|
||||
<< llvm::toString(std::move(resOrError.takeError()));
|
||||
throw std::runtime_error(os.str());
|
||||
}
|
||||
lambdaArgument result{std::move(*resOrError)};
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
std::string roundTrip(const char *module) {
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
|
||||
mlir::concretelang::CompilationContext::createShared();
|
||||
mlir::concretelang::JitCompilerEngine ce{ccx};
|
||||
mlir::concretelang::CompilerEngine ce{ccx};
|
||||
|
||||
std::string backingString;
|
||||
llvm::raw_string_ostream os(backingString);
|
||||
@@ -330,20 +261,3 @@ lambdaArgument lambdaArgumentFromScalar(uint64_t scalar) {
|
||||
scalar)};
|
||||
return scalar_arg;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
std::runtime_error library_error(std::string prefix, llvm::Expected<T> &error) {
|
||||
return std::runtime_error(prefix + llvm::toString(error.takeError()));
|
||||
}
|
||||
|
||||
std::string library(std::string libraryPath,
|
||||
std::vector<std::string> mlir_modules) {
|
||||
using namespace mlir::concretelang;
|
||||
|
||||
JitCompilerEngine ce{CompilationContext::createShared()};
|
||||
auto lib = ce.compile(mlir_modules, libraryPath);
|
||||
if (!lib) {
|
||||
throw std::runtime_error("Can't link: " + llvm::toString(lib.takeError()));
|
||||
}
|
||||
return lib->sharedLibraryPath;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user