diff --git a/compiler/include/zamalang/Support/CompilerEngine.h b/compiler/include/zamalang/Support/CompilerEngine.h index e9c496b1a..a2eb2ed28 100644 --- a/compiler/include/zamalang/Support/CompilerEngine.h +++ b/compiler/include/zamalang/Support/CompilerEngine.h @@ -8,7 +8,6 @@ #include #include #include -#include namespace mlir { namespace zamalang { @@ -43,7 +42,6 @@ public: llvm::Optional mlirModuleRef; llvm::Optional clientParameters; - std::unique_ptr keySet; std::unique_ptr llvmModule; llvm::Optional fheContext; @@ -93,8 +91,8 @@ public: CompilerEngine(std::shared_ptr compilationContext) : overrideMaxEintPrecision(), overrideMaxMANP(), clientParametersFuncName(), verifyDiagnostics(false), - generateKeySet(false), generateClientParameters(false), - parametrizeMidLFHE(true), compilationContext(compilationContext) {} + generateClientParameters(false), parametrizeMidLFHE(true), + compilationContext(compilationContext) {} llvm::Expected compile(llvm::StringRef s, Target target); @@ -107,7 +105,6 @@ public: void setMaxEintPrecision(size_t v); void setMaxMANP(size_t v); void setVerifyDiagnostics(bool v); - void setGenerateKeySet(bool v); void setGenerateClientParameters(bool v); void setParametrizeMidLFHE(bool v); void setClientParametersFuncName(const llvm::StringRef &name); @@ -117,7 +114,6 @@ protected: llvm::Optional overrideMaxMANP; llvm::Optional clientParametersFuncName; bool verifyDiagnostics; - bool generateKeySet; bool generateClientParameters; bool parametrizeMidLFHE; diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 50d98b721..6224569a5 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -74,8 +74,6 @@ void CompilerEngine::setVerifyDiagnostics(bool v) { this->verifyDiagnostics = v; } -void CompilerEngine::setGenerateKeySet(bool v) { this->generateKeySet = v; } - void CompilerEngine::setGenerateClientParameters(bool v) { this->generateClientParameters = v; } @@ -349,22 +347,6 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target) { res.clientParameters = clientParametersOrErr.get(); } - // Generate Key set if requested - if (this->generateKeySet) { - if (!res.clientParameters.hasValue()) { - return StreamStringError("Generation of keyset requested without request " - "for generation of client parameters"); - } - - llvm::Expected> keySetOrErr = - mlir::zamalang::KeySet::generate(*res.clientParameters, 0, 0); - - if (auto err = keySetOrErr.takeError()) - return std::move(err); - - res.keySet = std::move(*keySetOrErr); - } - // MLIR canonical dialects -> LLVM Dialect if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(mlirContext, module, false) diff --git a/compiler/lib/Support/JitCompilerEngine.cpp b/compiler/lib/Support/JitCompilerEngine.cpp index 05359ac4f..95626182d 100644 --- a/compiler/lib/Support/JitCompilerEngine.cpp +++ b/compiler/lib/Support/JitCompilerEngine.cpp @@ -64,7 +64,6 @@ llvm::Expected JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName) { MLIRContext &mlirContext = *this->compilationContext->getMLIRContext(); - this->setGenerateKeySet(true); this->setGenerateClientParameters(true); this->setClientParametersFuncName(funcName); @@ -95,11 +94,25 @@ JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName) { llvm::Expected> lambdaOrErr = mlir::zamalang::JITLambda::create(funcName, module, optPipeline); + // Generate the KeySet for encrypting lambda arguments, decrypting lambda + // results + if (!compResOrErr->clientParameters.hasValue()) { + return StreamStringError("Cannot generate the keySet since client " + "parameters has not been computed"); + } + + llvm::Expected> keySetOrErr = + mlir::zamalang::KeySet::generate(*compResOrErr->clientParameters, 0, 0); + + if (auto err = keySetOrErr.takeError()) + return std::move(err); + if (!lambdaOrErr) return std::move(lambdaOrErr.takeError()); return Lambda{this->compilationContext, std::move(lambdaOrErr.get()), - std::move(compResOrErr->keySet)}; + std::move(*keySetOrErr)}; } + } // namespace zamalang } // namespace mlir