diff --git a/compiler/include/zamalang/Support/CompilerTools.h b/compiler/include/zamalang/Support/CompilerTools.h index c69e27246..1017e0906 100644 --- a/compiler/include/zamalang/Support/CompilerTools.h +++ b/compiler/include/zamalang/Support/CompilerTools.h @@ -14,24 +14,29 @@ namespace zamalang { class CompilerTools { public: + struct LowerOptions { + llvm::function_ref enablePass; + bool verbose; + + LowerOptions() + : verbose(false), enablePass([](std::string pass) { return true; }){}; + }; + /// lowerHLFHEToMlirLLVMDialect run all passes to lower FHE dialects to mlir /// lowerable to llvm dialect. /// The given module MLIR operation would be modified and the constraints set. - static mlir::LogicalResult lowerHLFHEToMlirStdsDialect( - mlir::MLIRContext &context, mlir::Operation *module, - V0FHEContext &fheContext, - llvm::function_ref enablePass = [](std::string pass) { - return true; - }); + static mlir::LogicalResult + lowerHLFHEToMlirStdsDialect(mlir::MLIRContext &context, + mlir::Operation *module, V0FHEContext &fheContext, + LowerOptions options = LowerOptions()); /// lowerMlirStdsDialectToMlirLLVMDialect run all passes to lower MLIR /// dialects to MLIR LLVM dialect. The given module MLIR operation would be /// modified. - static mlir::LogicalResult lowerMlirStdsDialectToMlirLLVMDialect( - mlir::MLIRContext &context, mlir::Operation *module, - llvm::function_ref enablePass = [](std::string pass) { - return true; - }); + static mlir::LogicalResult + lowerMlirStdsDialectToMlirLLVMDialect(mlir::MLIRContext &context, + mlir::Operation *module, + LowerOptions options = LowerOptions()); static llvm::Expected> toLLVMModule(llvm::LLVMContext &llvmContext, mlir::ModuleOp &module, diff --git a/compiler/lib/Support/CompilerTools.cpp b/compiler/lib/Support/CompilerTools.cpp index a838a2880..5a8626217 100644 --- a/compiler/lib/Support/CompilerTools.cpp +++ b/compiler/lib/Support/CompilerTools.cpp @@ -36,25 +36,37 @@ void addFilteredPassToPassManager( mlir::LogicalResult CompilerTools::lowerHLFHEToMlirStdsDialect( mlir::MLIRContext &context, mlir::Operation *module, - V0FHEContext &fheContext, - llvm::function_ref enablePass) { + V0FHEContext &fheContext, LowerOptions options) { mlir::PassManager pm(&context); + if (options.verbose) { + llvm::errs() << "##################################################\n"; + llvm::errs() << "### HLFHEToMlirStdsDialect pipeline\n"; + context.disableMultithreading(); + pm.enableIRPrinting(); + pm.enableStatistics(); + pm.enableTiming(); + pm.enableVerifier(); + } fheContext.constraint = defaultGlobalFHECircuitConstraint; fheContext.parameter = *getV0Parameter(fheContext.constraint); // Add all passes to lower from HLFHE to LLVM Dialect addFilteredPassToPassManager( - pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), enablePass); + pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), + options.enablePass); addFilteredPassToPassManager( - pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass); + pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), + options.enablePass); addFilteredPassToPassManager( pm, mlir::zamalang::createConvertMidLFHEGlobalParametrizationPass(fheContext), - enablePass); + options.enablePass); addFilteredPassToPassManager( - pm, mlir::zamalang::createConvertMidLFHEToLowLFHEPass(), enablePass); + pm, mlir::zamalang::createConvertMidLFHEToLowLFHEPass(), + options.enablePass); addFilteredPassToPassManager( - pm, mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass(), enablePass); + pm, mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass(), + options.enablePass); // Run the passes if (pm.run(module).failed()) { @@ -65,26 +77,37 @@ mlir::LogicalResult CompilerTools::lowerHLFHEToMlirStdsDialect( } mlir::LogicalResult CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect( - mlir::MLIRContext &context, mlir::Operation *module, - llvm::function_ref enablePass) { + mlir::MLIRContext &context, mlir::Operation *module, LowerOptions options) { mlir::PassManager pm(&context); + if (options.verbose) { + llvm::errs() << "##################################################\n"; + llvm::errs() << "### MlirStdsDialectToMlirLLVMDialect pipeline\n"; + context.disableMultithreading(); + pm.enableIRPrinting(); + pm.enableStatistics(); + pm.enableTiming(); + pm.enableVerifier(); + } // Bufferize addFilteredPassToPassManager(pm, mlir::createTensorConstantBufferizePass(), - enablePass); - addFilteredPassToPassManager(pm, mlir::createStdBufferizePass(), enablePass); + options.enablePass); + addFilteredPassToPassManager(pm, mlir::createStdBufferizePass(), + options.enablePass); addFilteredPassToPassManager(pm, mlir::createTensorBufferizePass(), - enablePass); + options.enablePass); addFilteredPassToPassManager(pm, mlir::createLinalgBufferizePass(), - enablePass); + options.enablePass); addFilteredPassToPassManager(pm, mlir::createConvertLinalgToLoopsPass(), - enablePass); - addFilteredPassToPassManager(pm, mlir::createFuncBufferizePass(), enablePass); + options.enablePass); + addFilteredPassToPassManager(pm, mlir::createFuncBufferizePass(), + options.enablePass); addFilteredPassToPassManager(pm, mlir::createFinalizingBufferizePass(), - enablePass); + options.enablePass); + addFilteredPassToPassManager( pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass(), - enablePass); + options.enablePass); if (pm.run(module).failed()) { return mlir::failure(); diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 550ff4ae1..ed602ed4b 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -175,10 +175,13 @@ processInputBuffer(mlir::MLIRContext &context, }; // Lower to MLIR Stds Dialects and compute the constraint on the FHE Circuit. + mlir::zamalang::CompilerTools::LowerOptions lowerOptions; + lowerOptions.enablePass = enablePass; + lowerOptions.verbose = cmdline::verbose; + mlir::zamalang::V0FHEContext fheContext; - LOG_VERBOSE("### Lower from HLFHE to MLIR standards \n"); if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect( - context, *module, fheContext, enablePass) + context, *module, fheContext, lowerOptions) .failed()) { return mlir::failure(); } @@ -217,9 +220,8 @@ processInputBuffer(mlir::MLIRContext &context, } // Lower to MLIR LLVM Dialect - LOG_VERBOSE("### Lower from MLIR standards to LLVM\n"); if (mlir::zamalang::CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect( - context, *module, enablePass) + context, *module, lowerOptions) .failed()) { return mlir::failure(); }