refactor(compiler): Move the keyset generation from CompilerEngine to JitCompilerEngine

This commit is contained in:
Quentin Bourgerie
2021-10-22 09:41:23 +02:00
committed by Andi Drebes
parent 1187cfbd62
commit 41cba63113
3 changed files with 17 additions and 26 deletions

View File

@@ -8,7 +8,6 @@
#include <mlir/IR/MLIRContext.h>
#include <zamalang/Conversion/Utils/GlobalFHEContext.h>
#include <zamalang/Support/ClientParameters.h>
#include <zamalang/Support/KeySet.h>
namespace mlir {
namespace zamalang {
@@ -43,7 +42,6 @@ public:
llvm::Optional<mlir::OwningModuleRef> mlirModuleRef;
llvm::Optional<mlir::zamalang::ClientParameters> clientParameters;
std::unique_ptr<mlir::zamalang::KeySet> keySet;
std::unique_ptr<llvm::Module> llvmModule;
llvm::Optional<mlir::zamalang::V0FHEContext> fheContext;
@@ -93,8 +91,8 @@ public:
CompilerEngine(std::shared_ptr<CompilationContext> compilationContext)
: overrideMaxEintPrecision(), overrideMaxMANP(),
clientParametersFuncName(), verifyDiagnostics(false),
generateKeySet(false), generateClientParameters(false),
parametrizeMidLFHE(true), compilationContext(compilationContext) {}
generateClientParameters(false), parametrizeMidLFHE(true),
compilationContext(compilationContext) {}
llvm::Expected<CompilationResult> compile(llvm::StringRef s, Target target);
@@ -107,7 +105,6 @@ public:
void setMaxEintPrecision(size_t v);
void setMaxMANP(size_t v);
void setVerifyDiagnostics(bool v);
void setGenerateKeySet(bool v);
void setGenerateClientParameters(bool v);
void setParametrizeMidLFHE(bool v);
void setClientParametersFuncName(const llvm::StringRef &name);
@@ -117,7 +114,6 @@ protected:
llvm::Optional<size_t> overrideMaxMANP;
llvm::Optional<std::string> clientParametersFuncName;
bool verifyDiagnostics;
bool generateKeySet;
bool generateClientParameters;
bool parametrizeMidLFHE;

View File

@@ -74,8 +74,6 @@ void CompilerEngine::setVerifyDiagnostics(bool v) {
this->verifyDiagnostics = v;
}
void CompilerEngine::setGenerateKeySet(bool v) { this->generateKeySet = v; }
void CompilerEngine::setGenerateClientParameters(bool v) {
this->generateClientParameters = v;
}
@@ -349,22 +347,6 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target) {
res.clientParameters = clientParametersOrErr.get();
}
// Generate Key set if requested
if (this->generateKeySet) {
if (!res.clientParameters.hasValue()) {
return StreamStringError("Generation of keyset requested without request "
"for generation of client parameters");
}
llvm::Expected<std::unique_ptr<mlir::zamalang::KeySet>> keySetOrErr =
mlir::zamalang::KeySet::generate(*res.clientParameters, 0, 0);
if (auto err = keySetOrErr.takeError())
return std::move(err);
res.keySet = std::move(*keySetOrErr);
}
// MLIR canonical dialects -> LLVM Dialect
if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(mlirContext, module,
false)

View File

@@ -64,7 +64,6 @@ llvm::Expected<JitCompilerEngine::Lambda>
JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName) {
MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();
this->setGenerateKeySet(true);
this->setGenerateClientParameters(true);
this->setClientParametersFuncName(funcName);
@@ -95,11 +94,25 @@ JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName) {
llvm::Expected<std::unique_ptr<JITLambda>> lambdaOrErr =
mlir::zamalang::JITLambda::create(funcName, module, optPipeline);
// Generate the KeySet for encrypting lambda arguments, decrypting lambda
// results
if (!compResOrErr->clientParameters.hasValue()) {
return StreamStringError("Cannot generate the keySet since client "
"parameters has not been computed");
}
llvm::Expected<std::unique_ptr<mlir::zamalang::KeySet>> keySetOrErr =
mlir::zamalang::KeySet::generate(*compResOrErr->clientParameters, 0, 0);
if (auto err = keySetOrErr.takeError())
return std::move(err);
if (!lambdaOrErr)
return std::move(lambdaOrErr.takeError());
return Lambda{this->compilationContext, std::move(lambdaOrErr.get()),
std::move(compResOrErr->keySet)};
std::move(*keySetOrErr)};
}
} // namespace zamalang
} // namespace mlir