mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat(compiler): Output client parameters when compile to a library
close #198
This commit is contained in:
@@ -19,6 +19,7 @@
|
||||
#include <mlir/Support/ToolUtilities.h>
|
||||
#include <sstream>
|
||||
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
|
||||
@@ -30,7 +31,6 @@
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "concretelang/Support/JitCompilerEngine.h"
|
||||
#include "concretelang/Support/KeySet.h"
|
||||
#include "concretelang/Support/LLVMEmitFile.h"
|
||||
#include "concretelang/Support/Pipeline.h"
|
||||
#include "concretelang/Support/logging.h"
|
||||
@@ -112,7 +112,7 @@ static llvm::cl::opt<enum Action> action(
|
||||
"Lower to LLVM-IR, optimize and dump result")),
|
||||
llvm::cl::values(clEnumValN(Action::JIT_INVOKE, "jit-invoke",
|
||||
"Lower and JIT-compile input module and invoke "
|
||||
"function specified with --jit-funcname")),
|
||||
"function specified with --funcname")),
|
||||
llvm::cl::values(clEnumValN(Action::COMPILE, "compile",
|
||||
"Lower to LLVM-IR, compile to a file")));
|
||||
|
||||
@@ -133,10 +133,10 @@ llvm::cl::opt<bool> autoParallelize(
|
||||
llvm::cl::desc("Generate (and execute if JIT) parallel code"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
llvm::cl::opt<std::string> jitFuncName(
|
||||
"jit-funcname",
|
||||
llvm::cl::desc("Name of the function to execute, default 'main'"),
|
||||
llvm::cl::init<std::string>("main"));
|
||||
llvm::cl::opt<std::string>
|
||||
funcName("funcname",
|
||||
llvm::cl::desc("Name of the function to compile, default 'main'"),
|
||||
llvm::cl::init<std::string>("main"));
|
||||
|
||||
llvm::cl::list<uint64_t>
|
||||
jitArgs("jit-args",
|
||||
@@ -216,7 +216,7 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
|
||||
// The parameter `action` specifies how the buffer should be processed
|
||||
// and thus defines the output.
|
||||
//
|
||||
// If the specified action involves JIT compilation, `jitFuncName`
|
||||
// If the specified action involves JIT compilation, `funcName`
|
||||
// designates the function to JIT compile. This function is invoked
|
||||
// using the parameters given in `jitArgs`.
|
||||
//
|
||||
@@ -239,7 +239,7 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
|
||||
// Compilation output is written to the stream specified by `os`.
|
||||
mlir::LogicalResult processInputBuffer(
|
||||
std::unique_ptr<llvm::MemoryBuffer> buffer, std::string sourceFileName,
|
||||
enum Action action, const std::string &jitFuncName,
|
||||
enum Action action, const std::string &funcName,
|
||||
llvm::ArrayRef<uint64_t> jitArgs,
|
||||
llvm::Optional<size_t> overrideMaxEintPrecision,
|
||||
llvm::Optional<size_t> overrideMaxMANP, bool verifyDiagnostics,
|
||||
@@ -269,17 +269,18 @@ mlir::LogicalResult processInputBuffer(
|
||||
if (overrideMaxMANP.hasValue())
|
||||
ce.setMaxMANP(overrideMaxMANP.getValue());
|
||||
|
||||
ce.setClientParametersFuncName(funcName);
|
||||
if (fhelinalgTileSizes.hasValue())
|
||||
ce.setFHELinalgTileSizes(*fhelinalgTileSizes);
|
||||
|
||||
if (action == Action::JIT_INVOKE) {
|
||||
llvm::Expected<mlir::concretelang::JitCompilerEngine::Lambda> lambdaOrErr =
|
||||
ce.buildLambda(std::move(buffer), jitFuncName, keySetCache);
|
||||
ce.buildLambda(std::move(buffer), funcName, keySetCache);
|
||||
|
||||
if (!lambdaOrErr) {
|
||||
mlir::concretelang::log_error()
|
||||
<< "Failed to JIT-compile " << jitFuncName << ": "
|
||||
<< llvm::toString(std::move(lambdaOrErr.takeError()));
|
||||
<< "Failed to JIT-compile " << funcName << ": "
|
||||
<< llvm::toString(lambdaOrErr.takeError());
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
@@ -287,7 +288,7 @@ mlir::LogicalResult processInputBuffer(
|
||||
|
||||
if (!resOrErr) {
|
||||
mlir::concretelang::log_error()
|
||||
<< "Failed to JIT-invoke " << jitFuncName << " with arguments "
|
||||
<< "Failed to JIT-invoke " << funcName << " with arguments "
|
||||
<< jitArgs << ": " << llvm::toString(std::move(resOrErr.takeError()));
|
||||
return mlir::failure();
|
||||
}
|
||||
@@ -425,11 +426,11 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
|
||||
auto process = [&](std::unique_ptr<llvm::MemoryBuffer> inputBuffer,
|
||||
llvm::raw_ostream &os) {
|
||||
return processInputBuffer(
|
||||
std::move(inputBuffer), fileName, cmdline::action,
|
||||
cmdline::jitFuncName, cmdline::jitArgs,
|
||||
cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP,
|
||||
cmdline::verifyDiagnostics, fhelinalgTileSizes,
|
||||
cmdline::autoParallelize, jitKeySetCache, os, outputLib);
|
||||
std::move(inputBuffer), fileName, cmdline::action, cmdline::funcName,
|
||||
cmdline::jitArgs, cmdline::assumeMaxEintPrecision,
|
||||
cmdline::assumeMaxMANP, cmdline::verifyDiagnostics,
|
||||
fhelinalgTileSizes, cmdline::autoParallelize, jitKeySetCache, os,
|
||||
outputLib);
|
||||
};
|
||||
auto &os = output->os();
|
||||
auto res = mlir::failure();
|
||||
@@ -446,12 +447,8 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
|
||||
}
|
||||
|
||||
if (cmdline::action == Action::COMPILE) {
|
||||
auto libPath = outputLib->emitShared();
|
||||
if (!libPath) {
|
||||
return mlir::failure();
|
||||
}
|
||||
libPath = outputLib->emitStatic();
|
||||
if (!libPath) {
|
||||
auto err = outputLib->emitArtifacts();
|
||||
if (err) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user