mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
cleanup(compiler): Remove references to JitCompilerEngine to the main
This commit is contained in:
@@ -20,6 +20,7 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
|
||||
@@ -30,13 +31,16 @@
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
|
||||
#include "concretelang/Runtime/runtime_api.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "concretelang/Support/JitCompilerEngine.h"
|
||||
#include "concretelang/Support/JitLambdaSupport.h"
|
||||
#include "concretelang/Support/LLVMEmitFile.h"
|
||||
#include "concretelang/Support/Pipeline.h"
|
||||
#include "concretelang/Support/logging.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
||||
namespace clientlib = concretelang::clientlib;
|
||||
|
||||
enum Action {
|
||||
ROUND_TRIP,
|
||||
DUMP_FHE,
|
||||
@@ -153,7 +157,7 @@ llvm::cl::opt<bool> dataflowParallelize(
|
||||
llvm::cl::opt<std::string>
|
||||
funcName("funcname",
|
||||
llvm::cl::desc("Name of the function to compile, default 'main'"),
|
||||
llvm::cl::init<std::string>("main"));
|
||||
llvm::cl::init<std::string>(""));
|
||||
|
||||
llvm::cl::list<uint64_t>
|
||||
jitArgs("jit-args",
|
||||
@@ -262,46 +266,35 @@ mlir::LogicalResult processInputBuffer(
|
||||
llvm::Optional<size_t> overrideMaxMANP, bool verifyDiagnostics,
|
||||
llvm::Optional<llvm::ArrayRef<int64_t>> fhelinalgTileSizes,
|
||||
bool autoParallelize, bool loopParallelize, bool dataflowParallelize,
|
||||
llvm::Optional<mlir::concretelang::KeySetCache> keySetCache,
|
||||
llvm::raw_ostream &os,
|
||||
llvm::Optional<clientlib::KeySetCache> keySetCache, llvm::raw_ostream &os,
|
||||
std::shared_ptr<mlir::concretelang::CompilerEngine::Library> outputLib) {
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
|
||||
mlir::concretelang::CompilationContext::createShared();
|
||||
|
||||
mlir::concretelang::JitCompilerEngine ce{ccx};
|
||||
mlir::concretelang::CompilationOptions options;
|
||||
|
||||
ce.setVerifyDiagnostics(verifyDiagnostics);
|
||||
ce.setAutoParallelize(autoParallelize);
|
||||
ce.setLoopParallelize(loopParallelize);
|
||||
ce.setDataflowParallelize(dataflowParallelize);
|
||||
if (cmdline::passes.size() != 0) {
|
||||
ce.setEnablePass([](mlir::Pass *pass) {
|
||||
return std::any_of(
|
||||
cmdline::passes.begin(), cmdline::passes.end(),
|
||||
[&](const std::string &p) { return pass->getArgument() == p; });
|
||||
});
|
||||
}
|
||||
options.verifyDiagnostics = verifyDiagnostics;
|
||||
options.autoParallelize = autoParallelize;
|
||||
options.loopParallelize = loopParallelize;
|
||||
options.dataflowParallelize = dataflowParallelize;
|
||||
|
||||
if (overrideMaxEintPrecision.hasValue())
|
||||
ce.setMaxEintPrecision(overrideMaxEintPrecision.getValue());
|
||||
if (overrideMaxEintPrecision.hasValue() && overrideMaxMANP.hasValue())
|
||||
options.v0FHEConstraints = {
|
||||
overrideMaxMANP.hasValue(),
|
||||
overrideMaxEintPrecision.hasValue(),
|
||||
};
|
||||
|
||||
if (overrideMaxMANP.hasValue())
|
||||
ce.setMaxMANP(overrideMaxMANP.getValue());
|
||||
if (!funcName.empty())
|
||||
options.clientParametersFuncName = funcName;
|
||||
|
||||
ce.setClientParametersFuncName(funcName);
|
||||
if (fhelinalgTileSizes.hasValue())
|
||||
ce.setFHELinalgTileSizes(*fhelinalgTileSizes);
|
||||
options.fhelinalgTileSizes = *fhelinalgTileSizes;
|
||||
|
||||
if (action == Action::JIT_INVOKE) {
|
||||
llvm::Expected<mlir::concretelang::JitCompilerEngine::Lambda> lambdaOrErr =
|
||||
ce.buildLambda(std::move(buffer), funcName, keySetCache);
|
||||
|
||||
if (!lambdaOrErr) {
|
||||
mlir::concretelang::log_error()
|
||||
<< "Failed to JIT-compile " << funcName << ": "
|
||||
<< llvm::toString(lambdaOrErr.takeError());
|
||||
return mlir::failure();
|
||||
}
|
||||
auto lambdaOrErr =
|
||||
mlir::concretelang::ClientServer<mlir::concretelang::JitLambdaSupport>::
|
||||
create(buffer->getBuffer(), options, keySetCache,
|
||||
mlir::concretelang::JitLambdaSupport());
|
||||
|
||||
llvm::Expected<uint64_t> resOrErr = (*lambdaOrErr)(jitArgs);
|
||||
|
||||
@@ -314,6 +307,16 @@ mlir::LogicalResult processInputBuffer(
|
||||
|
||||
os << *resOrErr << "\n";
|
||||
} else {
|
||||
mlir::concretelang::CompilerEngine ce{ccx};
|
||||
ce.setCompilationOptions(options);
|
||||
|
||||
if (cmdline::passes.size() != 0) {
|
||||
ce.setEnablePass([](mlir::Pass *pass) {
|
||||
return std::any_of(
|
||||
cmdline::passes.begin(), cmdline::passes.end(),
|
||||
[&](const std::string &p) { return pass->getArgument() == p; });
|
||||
});
|
||||
}
|
||||
enum mlir::concretelang::CompilerEngine::Target target;
|
||||
|
||||
switch (action) {
|
||||
@@ -412,10 +415,9 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
|
||||
if (!cmdline::fhelinalgTileSizes.empty())
|
||||
fhelinalgTileSizes.emplace(cmdline::fhelinalgTileSizes);
|
||||
|
||||
llvm::Optional<mlir::concretelang::KeySetCache> jitKeySetCache;
|
||||
llvm::Optional<clientlib::KeySetCache> jitKeySetCache;
|
||||
if (!cmdline::jitKeySetCachePath.empty()) {
|
||||
jitKeySetCache =
|
||||
mlir::concretelang::KeySetCache(cmdline::jitKeySetCachePath);
|
||||
jitKeySetCache = clientlib::KeySetCache(cmdline::jitKeySetCachePath);
|
||||
}
|
||||
|
||||
// In case of compilation to library, the real output is the library.
|
||||
|
||||
Reference in New Issue
Block a user