mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor(compiler): Move the keyset generation from CompilerEngine to JitCompilerEngine
This commit is contained in:
committed by
Andi Drebes
parent
1187cfbd62
commit
41cba63113
@@ -8,7 +8,6 @@
|
||||
#include <mlir/IR/MLIRContext.h>
|
||||
#include <zamalang/Conversion/Utils/GlobalFHEContext.h>
|
||||
#include <zamalang/Support/ClientParameters.h>
|
||||
#include <zamalang/Support/KeySet.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
@@ -43,7 +42,6 @@ public:
|
||||
|
||||
llvm::Optional<mlir::OwningModuleRef> mlirModuleRef;
|
||||
llvm::Optional<mlir::zamalang::ClientParameters> clientParameters;
|
||||
std::unique_ptr<mlir::zamalang::KeySet> keySet;
|
||||
std::unique_ptr<llvm::Module> llvmModule;
|
||||
llvm::Optional<mlir::zamalang::V0FHEContext> fheContext;
|
||||
|
||||
@@ -93,8 +91,8 @@ public:
|
||||
CompilerEngine(std::shared_ptr<CompilationContext> compilationContext)
|
||||
: overrideMaxEintPrecision(), overrideMaxMANP(),
|
||||
clientParametersFuncName(), verifyDiagnostics(false),
|
||||
generateKeySet(false), generateClientParameters(false),
|
||||
parametrizeMidLFHE(true), compilationContext(compilationContext) {}
|
||||
generateClientParameters(false), parametrizeMidLFHE(true),
|
||||
compilationContext(compilationContext) {}
|
||||
|
||||
llvm::Expected<CompilationResult> compile(llvm::StringRef s, Target target);
|
||||
|
||||
@@ -107,7 +105,6 @@ public:
|
||||
void setMaxEintPrecision(size_t v);
|
||||
void setMaxMANP(size_t v);
|
||||
void setVerifyDiagnostics(bool v);
|
||||
void setGenerateKeySet(bool v);
|
||||
void setGenerateClientParameters(bool v);
|
||||
void setParametrizeMidLFHE(bool v);
|
||||
void setClientParametersFuncName(const llvm::StringRef &name);
|
||||
@@ -117,7 +114,6 @@ protected:
|
||||
llvm::Optional<size_t> overrideMaxMANP;
|
||||
llvm::Optional<std::string> clientParametersFuncName;
|
||||
bool verifyDiagnostics;
|
||||
bool generateKeySet;
|
||||
bool generateClientParameters;
|
||||
bool parametrizeMidLFHE;
|
||||
|
||||
|
||||
@@ -74,8 +74,6 @@ void CompilerEngine::setVerifyDiagnostics(bool v) {
|
||||
this->verifyDiagnostics = v;
|
||||
}
|
||||
|
||||
void CompilerEngine::setGenerateKeySet(bool v) { this->generateKeySet = v; }
|
||||
|
||||
void CompilerEngine::setGenerateClientParameters(bool v) {
|
||||
this->generateClientParameters = v;
|
||||
}
|
||||
@@ -349,22 +347,6 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target) {
|
||||
res.clientParameters = clientParametersOrErr.get();
|
||||
}
|
||||
|
||||
// Generate Key set if requested
|
||||
if (this->generateKeySet) {
|
||||
if (!res.clientParameters.hasValue()) {
|
||||
return StreamStringError("Generation of keyset requested without request "
|
||||
"for generation of client parameters");
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<mlir::zamalang::KeySet>> keySetOrErr =
|
||||
mlir::zamalang::KeySet::generate(*res.clientParameters, 0, 0);
|
||||
|
||||
if (auto err = keySetOrErr.takeError())
|
||||
return std::move(err);
|
||||
|
||||
res.keySet = std::move(*keySetOrErr);
|
||||
}
|
||||
|
||||
// MLIR canonical dialects -> LLVM Dialect
|
||||
if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(mlirContext, module,
|
||||
false)
|
||||
|
||||
@@ -64,7 +64,6 @@ llvm::Expected<JitCompilerEngine::Lambda>
|
||||
JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName) {
|
||||
MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();
|
||||
|
||||
this->setGenerateKeySet(true);
|
||||
this->setGenerateClientParameters(true);
|
||||
this->setClientParametersFuncName(funcName);
|
||||
|
||||
@@ -95,11 +94,25 @@ JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName) {
|
||||
llvm::Expected<std::unique_ptr<JITLambda>> lambdaOrErr =
|
||||
mlir::zamalang::JITLambda::create(funcName, module, optPipeline);
|
||||
|
||||
// Generate the KeySet for encrypting lambda arguments, decrypting lambda
|
||||
// results
|
||||
if (!compResOrErr->clientParameters.hasValue()) {
|
||||
return StreamStringError("Cannot generate the keySet since client "
|
||||
"parameters has not been computed");
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<mlir::zamalang::KeySet>> keySetOrErr =
|
||||
mlir::zamalang::KeySet::generate(*compResOrErr->clientParameters, 0, 0);
|
||||
|
||||
if (auto err = keySetOrErr.takeError())
|
||||
return std::move(err);
|
||||
|
||||
if (!lambdaOrErr)
|
||||
return std::move(lambdaOrErr.takeError());
|
||||
|
||||
return Lambda{this->compilationContext, std::move(lambdaOrErr.get()),
|
||||
std::move(compResOrErr->keySet)};
|
||||
std::move(*keySetOrErr)};
|
||||
}
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
Reference in New Issue
Block a user