diff --git a/compiler/include/zamalang/Conversion/LowLFHEToConcreteCAPI/Pass.h b/compiler/include/zamalang/Conversion/LowLFHEToConcreteCAPI/Pass.h index 891121627..3236ee183 100644 --- a/compiler/include/zamalang/Conversion/LowLFHEToConcreteCAPI/Pass.h +++ b/compiler/include/zamalang/Conversion/LowLFHEToConcreteCAPI/Pass.h @@ -9,7 +9,7 @@ namespace zamalang { /// Create a pass to convert `LowLFHE` operators to function call to the /// `ConcreteCAPI` std::unique_ptr> -createConvertLowLFHEToConcreteCAPIPass(); +createConvertLowLFHEToConcreteCAPIPass(uint64_t lweSize); } // namespace zamalang } // namespace mlir diff --git a/compiler/include/zamalang/Support/ClientParameters.h b/compiler/include/zamalang/Support/ClientParameters.h index d2fbac27b..2075bd362 100644 --- a/compiler/include/zamalang/Support/ClientParameters.h +++ b/compiler/include/zamalang/Support/ClientParameters.h @@ -73,7 +73,7 @@ struct ClientParameters { }; llvm::Expected -createClientParametersForV0(V0Parameter *v0Param, Precision precision, +createClientParametersForV0(V0Parameter &v0Param, Precision precision, llvm::StringRef name, mlir::ModuleOp module); } // namespace zamalang } // namespace mlir diff --git a/compiler/include/zamalang/Support/CompilerTools.h b/compiler/include/zamalang/Support/CompilerTools.h index 3e9b67b75..3b8e5f411 100644 --- a/compiler/include/zamalang/Support/CompilerTools.h +++ b/compiler/include/zamalang/Support/CompilerTools.h @@ -26,7 +26,7 @@ public: /// The given module MLIR operation would be modified and the constraints set. static mlir::LogicalResult lowerHLFHEToMlirStdsDialect( mlir::MLIRContext &context, mlir::Operation *module, - FHECircuitConstraint &constraint, + FHECircuitConstraint &constraint, V0Parameter &v0Parameter, llvm::function_ref enablePass = [](std::string pass) { return true; }); diff --git a/compiler/include/zamalang/Support/V0Parameters.h b/compiler/include/zamalang/Support/V0Parameters.h index 45b887d18..c2bb186b7 100644 --- a/compiler/include/zamalang/Support/V0Parameters.h +++ b/compiler/include/zamalang/Support/V0Parameters.h @@ -15,6 +15,8 @@ typedef struct V0Parameter { size_t ksLevel; size_t ksLogBase; + V0Parameter() {} + V0Parameter(size_t k, size_t polynomialSize, size_t nSmall, size_t brLevel, size_t brLogBase, size_t ksLevel, size_t ksLogBase) : k(k), polynomialSize(polynomialSize), nSmall(nSmall), brLevel(brLevel), diff --git a/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp b/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp index ae2d07c56..d320df5e7 100644 --- a/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp +++ b/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp @@ -30,9 +30,10 @@ struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern { LowLFHEOpToConcreteCAPICallPattern(mlir::MLIRContext *context, mlir::StringRef funcName, mlir::StringRef allocName, + uint64_t lweSize, mlir::PatternBenefit benefit = 1) : mlir::OpRewritePattern(context, benefit), funcName(funcName), - allocName(allocName) {} + allocName(allocName), lweSize(lweSize) {} mlir::LogicalResult static insertForwardDeclaration( Op op, mlir::PatternRewriter &rewriter, llvm::StringRef funcName, @@ -92,17 +93,16 @@ struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern { // Replace the operation with a call to the `funcName` { // Create the err value - auto err = rewriter.create(op.getLoc(), errType); + auto errOp = rewriter.create(op.getLoc(), errType); // Add the call to the allocation - // TODO - 2018 - auto lweSize = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(2048)); - mlir::SmallVector allocOperands{err, lweSize}; + auto lweSizeOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(lweSize)); + mlir::SmallVector allocOperands{errOp, lweSizeOp}; auto alloc = rewriter.replaceOpWithNewOp( op, allocName, op.getType(), allocOperands); // Add err and allocated value to operands - mlir::SmallVector newOperands{err, alloc.getResult(0)}; + mlir::SmallVector newOperands{errOp, alloc.getResult(0)}; for (auto operand : op->getOperands()) { newOperands.push_back(operand); } @@ -115,21 +115,24 @@ struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern { private: std::string funcName; std::string allocName; + uint64_t lweSize; }; /// Populate the RewritePatternSet with all patterns that rewrite LowLFHE /// operators to the corresponding function call to the `Concrete C API`. -void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns) { +void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns, uint64_t lweSize) { patterns.add>( patterns.getContext(), "add_lwe_ciphertexts_u64", - "allocate_lwe_ciphertext_u64"); + "allocate_lwe_ciphertext_u64", lweSize); } namespace { struct LowLFHEToConcreteCAPIPass : public LowLFHEToConcreteCAPIBase { + LowLFHEToConcreteCAPIPass(uint64_t lweSize): lweSize(lweSize){}; void runOnOperation() final; + uint64_t lweSize; }; } // namespace @@ -142,7 +145,7 @@ void LowLFHEToConcreteCAPIPass::runOnOperation() { // Setup rewrite patterns mlir::RewritePatternSet patterns(&getContext()); - populateLowLFHEToConcreteCAPICall(patterns); + populateLowLFHEToConcreteCAPICall(patterns, lweSize); // Apply the conversion mlir::ModuleOp op = getOperation(); @@ -154,8 +157,8 @@ void LowLFHEToConcreteCAPIPass::runOnOperation() { namespace mlir { namespace zamalang { std::unique_ptr> -createConvertLowLFHEToConcreteCAPIPass() { - return std::make_unique(); +createConvertLowLFHEToConcreteCAPIPass(uint64_t lweSize) { + return std::make_unique(lweSize); } } // namespace zamalang } // namespace mlir \ No newline at end of file diff --git a/compiler/lib/Support/ClientParameters.cpp b/compiler/lib/Support/ClientParameters.cpp index 423646550..5f1d0b0c7 100644 --- a/compiler/lib/Support/ClientParameters.cpp +++ b/compiler/lib/Support/ClientParameters.cpp @@ -45,13 +45,13 @@ llvm::Expected gateFromMLIRType(std::string secretKeyID, } llvm::Expected -createClientParametersForV0(V0Parameter *v0Param, Precision precision, +createClientParametersForV0(V0Parameter &v0Param, Precision precision, llvm::StringRef name, mlir::ModuleOp module) { // Static client parameters from global parameters for v0 ClientParameters c{ .secretKeys{ - {"small", {.size = v0Param->nSmall}}, - {"big", {.size = v0Param->k * (1 << v0Param->polynomialSize)}}, + {"small", {.size = v0Param.nSmall}}, + {"big", {.size = v0Param.k * (1 << v0Param.polynomialSize)}}, }, .bootstrapKeys{ { @@ -59,9 +59,9 @@ createClientParametersForV0(V0Parameter *v0Param, Precision precision, { .inputSecretKeyID = "small", .outputSecretKeyID = "big", - .level = v0Param->brLevel, - .baseLog = v0Param->brLogBase, - .k = v0Param->k, + .level = v0Param.brLevel, + .baseLog = v0Param.brLogBase, + .k = v0Param.k, // TODO - Compute variance, wait for security estimator .variance = 0, }, @@ -73,8 +73,8 @@ createClientParametersForV0(V0Parameter *v0Param, Precision precision, { .inputSecretKeyID = "big", .outputSecretKeyID = "small", - .level = v0Param->ksLevel, - .baseLog = v0Param->ksLogBase, + .level = v0Param.ksLevel, + .baseLog = v0Param.ksLogBase, // TODO - Compute variance, wait for security estimator .variance = 0, }, diff --git a/compiler/lib/Support/CompilerTools.cpp b/compiler/lib/Support/CompilerTools.cpp index c4780acb1..af241b7be 100644 --- a/compiler/lib/Support/CompilerTools.cpp +++ b/compiler/lib/Support/CompilerTools.cpp @@ -32,18 +32,19 @@ void addFilteredPassToPassManager( mlir::LogicalResult CompilerTools::lowerHLFHEToMlirStdsDialect( mlir::MLIRContext &context, mlir::Operation *module, - FHECircuitConstraint &constraint, + FHECircuitConstraint &constraint, V0Parameter &v0Parameter, llvm::function_ref enablePass) { mlir::PassManager pm(&context); + constraint = defaultGlobalFHECircuitConstraint; + v0Parameter = *getV0Parameter(constraint.norm2, constraint.p); // Add all passes to lower from HLFHE to LLVM Dialect addFilteredPassToPassManager( pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), enablePass); addFilteredPassToPassManager( pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass); addFilteredPassToPassManager( - pm, mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass(), enablePass); - constraint = defaultGlobalFHECircuitConstraint; + pm, mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass(1 << v0Parameter.polynomialSize), enablePass); // Run the passes if (pm.run(module).failed()) { diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 0c56fa967..7a908ce15 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -175,33 +175,30 @@ processInputBuffer(mlir::MLIRContext &context, // Lower to MLIR Stds Dialects and compute the constraint on the FHE Circuit. mlir::zamalang::FHECircuitConstraint constraint; + mlir::zamalang::V0Parameter v0Parameter; LOG_VERBOSE("### Lower from HLFHE to MLIR standards \n"); if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect( - context, *module, constraint, enablePass) + context, *module, constraint, v0Parameter, enablePass) .failed()) { return mlir::failure(); } LOG_VERBOSE("### Global FHE constraint: {norm2:" << constraint.norm2 << ", p:" << constraint.p << "}\n"); - - // Retreive the parameters for the v0 approach - mlir::zamalang::V0Parameter *fheParameter = - mlir::zamalang::getV0Parameter(constraint.norm2, constraint.p); LOG_VERBOSE("### FHE parameters for the atomic pattern: {k: " - << fheParameter->k - << ", polynomialSize: " << fheParameter->polynomialSize - << ", nSmall: " << fheParameter->nSmall - << ", brLevel: " << fheParameter->brLevel - << ", brLogBase: " << fheParameter->brLogBase - << ", ksLevel: " << fheParameter->ksLevel - << ", polynomialSize: " << fheParameter->ksLogBase << "}\n"); + << v0Parameter.k + << ", polynomialSize: " << v0Parameter.polynomialSize + << ", nSmall: " << v0Parameter.nSmall + << ", brLevel: " << v0Parameter.brLevel + << ", brLogBase: " << v0Parameter.brLogBase + << ", ksLevel: " << v0Parameter.ksLevel + << ", polynomialSize: " << v0Parameter.ksLogBase << "}\n"); // Generate the keySet std::unique_ptr keySet; if (cmdline::generateKeySet || cmdline::runJit) { // Create the client parameters auto clientParameter = mlir::zamalang::createClientParametersForV0( - fheParameter, constraint.p, cmdline::jitFuncname, *module); + v0Parameter, constraint.p, cmdline::jitFuncname, *module); if (auto err = clientParameter.takeError()) { LOG_ERROR("cannot generate client parameters: " << err << "\n"); return mlir::failure();