diff --git a/compiler/include/zamalang/Conversion/LowLFHEToConcreteCAPI/Pass.h b/compiler/include/zamalang/Conversion/LowLFHEToConcreteCAPI/Pass.h index 3236ee183..33427ce54 100644 --- a/compiler/include/zamalang/Conversion/LowLFHEToConcreteCAPI/Pass.h +++ b/compiler/include/zamalang/Conversion/LowLFHEToConcreteCAPI/Pass.h @@ -4,12 +4,14 @@ #include "mlir/Pass/Pass.h" +#include "zamalang/Conversion/Utils/GlobalFHEContext.h" + namespace mlir { namespace zamalang { /// Create a pass to convert `LowLFHE` operators to function call to the /// `ConcreteCAPI` std::unique_ptr> -createConvertLowLFHEToConcreteCAPIPass(uint64_t lweSize); +createConvertLowLFHEToConcreteCAPIPass(V0FHEContext &context); } // namespace zamalang } // namespace mlir diff --git a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h index 1efbbead8..0753a64bd 100644 --- a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h +++ b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h @@ -20,14 +20,14 @@ LweCiphertextType convertTypeGLWEToLWE(mlir::MLIRContext *context, return LweCiphertextType::get(context); } -PlaintextType convertIntToPlaintextType(mlir::MLIRContext *context, - IntegerType &type) { - return PlaintextType::get(context, type.getWidth()); +PlaintextType convertPlaintextTypeFromGlwe(mlir::MLIRContext *context, + GLWECipherTextType &type) { + return PlaintextType::get(context, type.getP() + 1); } -CleartextType convertIntToCleartextType(mlir::MLIRContext *context, - IntegerType &type) { - return CleartextType::get(context, type.getWidth()); +CleartextType convertCleartextTypeFromGlwe(mlir::MLIRContext *context, + GLWECipherTextType &type) { + return CleartextType::get(context, type.getP() + 1); } template @@ -44,13 +44,11 @@ mlir::Value createLowLFHEOpFromMidLFHE(mlir::PatternRewriter rewriter, return op.getODSResults(0).front(); } -mlir::Value createAddPlainLweCiphertext(mlir::PatternRewriter rewriter, - mlir::Location loc, mlir::Value arg0, - mlir::Value arg1, - mlir::OpResult result) { - auto integer_type = arg1.getType().cast(); +mlir::Value createAddPlainLweCiphertextWithGlwe( + mlir::PatternRewriter rewriter, mlir::Location loc, mlir::Value arg0, + mlir::Value arg1, mlir::OpResult result, GLWECipherTextType glwe) { PlaintextType encoded_type = - convertIntToPlaintextType(rewriter.getContext(), integer_type); + convertPlaintextTypeFromGlwe(rewriter.getContext(), glwe); // encode int into plaintext mlir::Value encoded = rewriter @@ -67,6 +65,15 @@ mlir::Value createAddPlainLweCiphertext(mlir::PatternRewriter rewriter, return op.getODSResults(0).front(); } +mlir::Value createAddPlainLweCiphertext(mlir::PatternRewriter rewriter, + mlir::Location loc, mlir::Value arg0, + mlir::Value arg1, + mlir::OpResult result) { + auto glwe = arg0.getType().cast(); + return createAddPlainLweCiphertextWithGlwe(rewriter, loc, arg0, arg1, result, + glwe); +} + mlir::Value createSubIntLweCiphertext(mlir::PatternRewriter rewriter, mlir::Location loc, mlir::Value arg0, mlir::Value arg1, mlir::OpResult result) { @@ -76,16 +83,17 @@ mlir::Value createSubIntLweCiphertext(mlir::PatternRewriter rewriter, .create( loc, convertTypeGLWEToLWE(rewriter.getContext(), arg1_type), arg1) .result(); - return createAddPlainLweCiphertext(rewriter, loc, negated_arg1, arg0, result); + return createAddPlainLweCiphertextWithGlwe(rewriter, loc, negated_arg1, arg0, + result, arg1_type); } mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter rewriter, mlir::Location loc, mlir::Value arg0, mlir::Value arg1, mlir::OpResult result) { - auto integer_type = arg1.getType().cast(); + auto glwe = arg0.getType().cast(); CleartextType encoded_type = - convertIntToCleartextType(rewriter.getContext(), integer_type); + convertCleartextTypeFromGlwe(rewriter.getContext(), glwe); // encode int into plaintext mlir::Value encoded = rewriter .create( diff --git a/compiler/include/zamalang/Conversion/Utils/GlobalFHEContext.h b/compiler/include/zamalang/Conversion/Utils/GlobalFHEContext.h new file mode 100644 index 000000000..85aa2a0eb --- /dev/null +++ b/compiler/include/zamalang/Conversion/Utils/GlobalFHEContext.h @@ -0,0 +1,40 @@ +#ifndef ZAMALANG_CONVERSION_GLOBALFHECONTEXT_PATTERNS_H_ +#define ZAMALANG_CONVERSION_GLOBALFHECONTEXT_PATTERNS_H_ +#include + +namespace mlir { +namespace zamalang { + +struct V0FHEConstraint { + size_t norm2; + size_t p; +}; + +struct V0Parameter { + size_t k; + size_t polynomialSize; + size_t nSmall; + size_t brLevel; + size_t brLogBase; + size_t ksLevel; + size_t ksLogBase; + + V0Parameter() {} + + V0Parameter(size_t k, size_t polynomialSize, size_t nSmall, size_t brLevel, + size_t brLogBase, size_t ksLevel, size_t ksLogBase) + : k(k), polynomialSize(polynomialSize), nSmall(nSmall), brLevel(brLevel), + brLogBase(brLogBase), ksLevel(ksLevel), ksLogBase(ksLogBase) {} + + size_t getNBigGlweSize() { return k * (1 << polynomialSize); } +}; + +struct V0FHEContext { + V0FHEConstraint constraint; + V0Parameter parameter; +}; + +} // namespace zamalang +} // namespace mlir + +#endif \ No newline at end of file diff --git a/compiler/include/zamalang/Support/ClientParameters.h b/compiler/include/zamalang/Support/ClientParameters.h index 2075bd362..af94647f7 100644 --- a/compiler/include/zamalang/Support/ClientParameters.h +++ b/compiler/include/zamalang/Support/ClientParameters.h @@ -73,8 +73,8 @@ struct ClientParameters { }; llvm::Expected -createClientParametersForV0(V0Parameter &v0Param, Precision precision, - llvm::StringRef name, mlir::ModuleOp module); +createClientParametersForV0(V0FHEContext context, llvm::StringRef name, + mlir::ModuleOp module); } // namespace zamalang } // namespace mlir diff --git a/compiler/include/zamalang/Support/CompilerTools.h b/compiler/include/zamalang/Support/CompilerTools.h index 3b8e5f411..6a33dd44c 100644 --- a/compiler/include/zamalang/Support/CompilerTools.h +++ b/compiler/include/zamalang/Support/CompilerTools.h @@ -12,13 +12,6 @@ namespace mlir { namespace zamalang { -/// For the v0 we compute a global constraint, this is defined here as the -/// high-level verification pass is not yet implemented. -struct FHECircuitConstraint { - size_t norm2; - size_t p; -}; - class CompilerTools { public: /// lowerHLFHEToMlirLLVMDialect run all passes to lower FHE dialects to mlir @@ -26,7 +19,7 @@ public: /// The given module MLIR operation would be modified and the constraints set. static mlir::LogicalResult lowerHLFHEToMlirStdsDialect( mlir::MLIRContext &context, mlir::Operation *module, - FHECircuitConstraint &constraint, V0Parameter &v0Parameter, + V0FHEContext &fheContext, llvm::function_ref enablePass = [](std::string pass) { return true; }); diff --git a/compiler/include/zamalang/Support/V0Parameters.h b/compiler/include/zamalang/Support/V0Parameters.h index c2bb186b7..78538dfc8 100644 --- a/compiler/include/zamalang/Support/V0Parameters.h +++ b/compiler/include/zamalang/Support/V0Parameters.h @@ -1,30 +1,13 @@ #ifndef ZAMALANG_SUPPORT_V0Parameter_H_ #define ZAMALANG_SUPPORT_V0Parameter_H_ +#include "zamalang/Conversion/Utils/GlobalFHEContext.h" #include namespace mlir { namespace zamalang { -typedef struct V0Parameter { - size_t k; - size_t polynomialSize; - size_t nSmall; - size_t brLevel; - size_t brLogBase; - size_t ksLevel; - size_t ksLogBase; - - V0Parameter() {} - - V0Parameter(size_t k, size_t polynomialSize, size_t nSmall, size_t brLevel, - size_t brLogBase, size_t ksLevel, size_t ksLogBase) - : k(k), polynomialSize(polynomialSize), nSmall(nSmall), brLevel(brLevel), - brLogBase(brLogBase), ksLevel(ksLevel), ksLogBase(ksLogBase) {} - -} V0Parameter; - -V0Parameter *getV0Parameter(size_t norm, size_t p); +V0Parameter *getV0Parameter(V0FHEConstraint constraint); } // namespace zamalang } // namespace mlir diff --git a/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp b/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp index 62da06db3..5e5b78054 100644 --- a/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp +++ b/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp @@ -202,9 +202,10 @@ void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns, namespace { struct LowLFHEToConcreteCAPIPass : public LowLFHEToConcreteCAPIBase { - LowLFHEToConcreteCAPIPass(uint64_t lweSize) : lweSize(lweSize){}; + LowLFHEToConcreteCAPIPass(mlir::zamalang::V0FHEContext &fheContext) + : fheContext(fheContext){}; void runOnOperation() final; - uint64_t lweSize; + mlir::zamalang::V0FHEContext &fheContext; }; } // namespace @@ -217,6 +218,7 @@ void LowLFHEToConcreteCAPIPass::runOnOperation() { // Setup rewrite patterns mlir::RewritePatternSet patterns(&getContext()); + auto lweSize = 1 << fheContext.parameter.polynomialSize; populateLowLFHEToConcreteCAPICall(patterns, lweSize); // Apply the conversion @@ -229,8 +231,8 @@ void LowLFHEToConcreteCAPIPass::runOnOperation() { namespace mlir { namespace zamalang { std::unique_ptr> -createConvertLowLFHEToConcreteCAPIPass(uint64_t lweSize) { - return std::make_unique(lweSize); +createConvertLowLFHEToConcreteCAPIPass(V0FHEContext &fheContext) { + return std::make_unique(fheContext); } } // namespace zamalang } // namespace mlir \ No newline at end of file diff --git a/compiler/lib/Support/ClientParameters.cpp b/compiler/lib/Support/ClientParameters.cpp index c52adc019..841dbb1f2 100644 --- a/compiler/lib/Support/ClientParameters.cpp +++ b/compiler/lib/Support/ClientParameters.cpp @@ -45,13 +45,14 @@ llvm::Expected gateFromMLIRType(std::string secretKeyID, } llvm::Expected -createClientParametersForV0(V0Parameter &v0Param, Precision precision, - llvm::StringRef name, mlir::ModuleOp module) { +createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name, + mlir::ModuleOp module) { + auto v0Param = fheContext.parameter; // Static client parameters from global parameters for v0 ClientParameters c{ .secretKeys{ {"small", {.size = v0Param.nSmall}}, - {"big", {.size = v0Param.k * (1 << v0Param.polynomialSize)}}, + {"big", {.size = v0Param.getNBigGlweSize()}}, }, .bootstrapKeys{ { @@ -92,7 +93,8 @@ createClientParametersForV0(V0Parameter &v0Param, Precision precision, } // For the v0 the precision is global - Encoding encoding = {.precision = precision}; + auto precision = fheContext.constraint.p; + Encoding encoding = {.precision = fheContext.constraint.p}; // Create input and output circuit gate parameters auto funcType = (*funcOp).getType(); diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 6e708fabb..69477f65f 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -30,18 +30,17 @@ CompilerEngine::compileFHE(std::string mlir_input) { return llvm::make_error("mlir parsing failed", llvm::inconvertibleErrorCode()); } - mlir::zamalang::FHECircuitConstraint constraint; - mlir::zamalang::V0Parameter v0Parameter; + mlir::zamalang::V0FHEContext fheContext; // Lower to MLIR Std if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect( - *context, module_ref.get(), constraint, v0Parameter) + *context, module_ref.get(), fheContext) .failed()) { return llvm::make_error("failed to lower to MLIR Std", llvm::inconvertibleErrorCode()); } // Create the client parameters auto clientParameter = mlir::zamalang::createClientParametersForV0( - v0Parameter, constraint.p, "main", module_ref.get()); + fheContext, "main", module_ref.get()); if (auto err = clientParameter.takeError()) { return llvm::make_error( "cannot generate client parameters", llvm::inconvertibleErrorCode()); diff --git a/compiler/lib/Support/CompilerTools.cpp b/compiler/lib/Support/CompilerTools.cpp index d6617cc9f..9334b1186 100644 --- a/compiler/lib/Support/CompilerTools.cpp +++ b/compiler/lib/Support/CompilerTools.cpp @@ -9,7 +9,7 @@ namespace mlir { namespace zamalang { // This is temporary while we doesn't yet have the high-level verification pass -FHECircuitConstraint defaultGlobalFHECircuitConstraint{.norm2 = 10, .p = 7}; +V0FHEConstraint defaultGlobalFHECircuitConstraint{.norm2 = 10, .p = 7}; void initLLVMNativeTarget() { // Initialize LLVM targets. @@ -32,23 +32,23 @@ void addFilteredPassToPassManager( mlir::LogicalResult CompilerTools::lowerHLFHEToMlirStdsDialect( mlir::MLIRContext &context, mlir::Operation *module, - FHECircuitConstraint &constraint, V0Parameter &v0Parameter, + V0FHEContext &fheContext, llvm::function_ref enablePass) { mlir::PassManager pm(&context); - constraint = defaultGlobalFHECircuitConstraint; - v0Parameter = *getV0Parameter(constraint.norm2, constraint.p); + fheContext.constraint = defaultGlobalFHECircuitConstraint; + fheContext.parameter = *getV0Parameter(fheContext.constraint); // Add all passes to lower from HLFHE to LLVM Dialect addFilteredPassToPassManager( pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), enablePass); addFilteredPassToPassManager( pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass); + addFilteredPassToPassManager( + pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass); addFilteredPassToPassManager( pm, mlir::zamalang::createConvertMidLFHEToLowLFHEPass(), enablePass); addFilteredPassToPassManager( - pm, - mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass( - 1 << v0Parameter.polynomialSize), + pm, mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass(fheContext), enablePass); // Run the passes diff --git a/compiler/lib/Support/V0Parameters.cpp b/compiler/lib/Support/V0Parameters.cpp index 2973bc231..d77ae3cca 100644 --- a/compiler/lib/Support/V0Parameters.cpp +++ b/compiler/lib/Support/V0Parameters.cpp @@ -115,15 +115,15 @@ V0Parameter parameters[NORM2_MAX][P_MAX] = { V0Parameter(0, 0, 0, 0, 0, 0, 0), V0Parameter(0, 0, 0, 0, 0, 0, 0), V0Parameter(0, 0, 0, 0, 0, 0, 0)}}; -V0Parameter *getV0Parameter(size_t norm, size_t p) { - if (norm > NORM2_MAX) { +V0Parameter *getV0Parameter(V0FHEConstraint constraint) { + if (constraint.norm2 > NORM2_MAX) { return nullptr; } - if (p > P_MAX) { + if (constraint.p > P_MAX) { return nullptr; } // - 1 is an offset as norm and p are in [1, ...] and not [0, ...] - auto param = ¶meters[norm - 1][p - 1]; + auto param = ¶meters[constraint.norm2 - 1][constraint.p - 1]; if (param->k == 0) { return nullptr; } diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 9f40c2439..550ff4ae1 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -175,30 +175,32 @@ processInputBuffer(mlir::MLIRContext &context, }; // Lower to MLIR Stds Dialects and compute the constraint on the FHE Circuit. - mlir::zamalang::FHECircuitConstraint constraint; - mlir::zamalang::V0Parameter v0Parameter; + mlir::zamalang::V0FHEContext fheContext; LOG_VERBOSE("### Lower from HLFHE to MLIR standards \n"); if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect( - context, *module, constraint, v0Parameter, enablePass) + context, *module, fheContext, enablePass) .failed()) { return mlir::failure(); } - LOG_VERBOSE("### Global FHE constraint: {norm2:" << constraint.norm2 << ", p:" - << constraint.p << "}\n"); + LOG_VERBOSE("### Global FHE constraint: {norm2:" + << fheContext.constraint.norm2 + << ", p:" << fheContext.constraint.p << "}\n"); LOG_VERBOSE("### FHE parameters for the atomic pattern: {k: " - << v0Parameter.k - << ", polynomialSize: " << v0Parameter.polynomialSize - << ", nSmall: " << v0Parameter.nSmall << ", brLevel: " - << v0Parameter.brLevel << ", brLogBase: " << v0Parameter.brLogBase - << ", ksLevel: " << v0Parameter.ksLevel - << ", polynomialSize: " << v0Parameter.ksLogBase << "}\n"); + << fheContext.parameter.k + << ", polynomialSize: " << fheContext.parameter.polynomialSize + << ", nSmall: " << fheContext.parameter.nSmall + << ", brLevel: " << fheContext.parameter.brLevel + << ", brLogBase: " << fheContext.parameter.brLogBase + << ", ksLevel: " << fheContext.parameter.ksLevel + << ", polynomialSize: " << fheContext.parameter.ksLogBase + << "}\n"); // Generate the keySet std::unique_ptr keySet; if (cmdline::generateKeySet || cmdline::runJit) { // Create the client parameters auto clientParameter = mlir::zamalang::createClientParametersForV0( - v0Parameter, constraint.p, cmdline::jitFuncname, *module); + fheContext, cmdline::jitFuncname, *module); if (auto err = clientParameter.takeError()) { LOG_ERROR("cannot generate client parameters: " << err << "\n"); return mlir::failure();