feat(compiler): Output client parameters when compile to a library

close #198
This commit is contained in:
rudy
2021-12-29 11:34:54 +01:00
committed by Quentin Bourgerie
parent a4e8227692
commit b8bd38dd6c
26 changed files with 889 additions and 271 deletions

View File

@@ -19,6 +19,7 @@
#include <mlir/Support/ToolUtilities.h>
#include <sstream>
#include "concretelang/ClientLib/KeySet.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
@@ -30,7 +31,6 @@
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
#include "concretelang/Support/Error.h"
#include "concretelang/Support/JitCompilerEngine.h"
#include "concretelang/Support/KeySet.h"
#include "concretelang/Support/LLVMEmitFile.h"
#include "concretelang/Support/Pipeline.h"
#include "concretelang/Support/logging.h"
@@ -112,7 +112,7 @@ static llvm::cl::opt<enum Action> action(
"Lower to LLVM-IR, optimize and dump result")),
llvm::cl::values(clEnumValN(Action::JIT_INVOKE, "jit-invoke",
"Lower and JIT-compile input module and invoke "
"function specified with --jit-funcname")),
"function specified with --funcname")),
llvm::cl::values(clEnumValN(Action::COMPILE, "compile",
"Lower to LLVM-IR, compile to a file")));
@@ -133,10 +133,10 @@ llvm::cl::opt<bool> autoParallelize(
llvm::cl::desc("Generate (and execute if JIT) parallel code"),
llvm::cl::init(false));
llvm::cl::opt<std::string> jitFuncName(
"jit-funcname",
llvm::cl::desc("Name of the function to execute, default 'main'"),
llvm::cl::init<std::string>("main"));
llvm::cl::opt<std::string>
funcName("funcname",
llvm::cl::desc("Name of the function to compile, default 'main'"),
llvm::cl::init<std::string>("main"));
llvm::cl::list<uint64_t>
jitArgs("jit-args",
@@ -216,7 +216,7 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
// The parameter `action` specifies how the buffer should be processed
// and thus defines the output.
//
// If the specified action involves JIT compilation, `jitFuncName`
// If the specified action involves JIT compilation, `funcName`
// designates the function to JIT compile. This function is invoked
// using the parameters given in `jitArgs`.
//
@@ -239,7 +239,7 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
// Compilation output is written to the stream specified by `os`.
mlir::LogicalResult processInputBuffer(
std::unique_ptr<llvm::MemoryBuffer> buffer, std::string sourceFileName,
enum Action action, const std::string &jitFuncName,
enum Action action, const std::string &funcName,
llvm::ArrayRef<uint64_t> jitArgs,
llvm::Optional<size_t> overrideMaxEintPrecision,
llvm::Optional<size_t> overrideMaxMANP, bool verifyDiagnostics,
@@ -269,17 +269,18 @@ mlir::LogicalResult processInputBuffer(
if (overrideMaxMANP.hasValue())
ce.setMaxMANP(overrideMaxMANP.getValue());
ce.setClientParametersFuncName(funcName);
if (fhelinalgTileSizes.hasValue())
ce.setFHELinalgTileSizes(*fhelinalgTileSizes);
if (action == Action::JIT_INVOKE) {
llvm::Expected<mlir::concretelang::JitCompilerEngine::Lambda> lambdaOrErr =
ce.buildLambda(std::move(buffer), jitFuncName, keySetCache);
ce.buildLambda(std::move(buffer), funcName, keySetCache);
if (!lambdaOrErr) {
mlir::concretelang::log_error()
<< "Failed to JIT-compile " << jitFuncName << ": "
<< llvm::toString(std::move(lambdaOrErr.takeError()));
<< "Failed to JIT-compile " << funcName << ": "
<< llvm::toString(lambdaOrErr.takeError());
return mlir::failure();
}
@@ -287,7 +288,7 @@ mlir::LogicalResult processInputBuffer(
if (!resOrErr) {
mlir::concretelang::log_error()
<< "Failed to JIT-invoke " << jitFuncName << " with arguments "
<< "Failed to JIT-invoke " << funcName << " with arguments "
<< jitArgs << ": " << llvm::toString(std::move(resOrErr.takeError()));
return mlir::failure();
}
@@ -425,11 +426,11 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
auto process = [&](std::unique_ptr<llvm::MemoryBuffer> inputBuffer,
llvm::raw_ostream &os) {
return processInputBuffer(
std::move(inputBuffer), fileName, cmdline::action,
cmdline::jitFuncName, cmdline::jitArgs,
cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP,
cmdline::verifyDiagnostics, fhelinalgTileSizes,
cmdline::autoParallelize, jitKeySetCache, os, outputLib);
std::move(inputBuffer), fileName, cmdline::action, cmdline::funcName,
cmdline::jitArgs, cmdline::assumeMaxEintPrecision,
cmdline::assumeMaxMANP, cmdline::verifyDiagnostics,
fhelinalgTileSizes, cmdline::autoParallelize, jitKeySetCache, os,
outputLib);
};
auto &os = output->os();
auto res = mlir::failure();
@@ -446,12 +447,8 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
}
if (cmdline::action == Action::COMPILE) {
auto libPath = outputLib->emitShared();
if (!libPath) {
return mlir::failure();
}
libPath = outputLib->emitStatic();
if (!libPath) {
auto err = outputLib->emitArtifacts();
if (err) {
return mlir::failure();
}
}