diff --git a/compiler/include/zamalang/Conversion/MidLFHEGlobalParametrization/Pass.h b/compiler/include/zamalang/Conversion/MidLFHEGlobalParametrization/Pass.h new file mode 100644 index 000000000..d2435751c --- /dev/null +++ b/compiler/include/zamalang/Conversion/MidLFHEGlobalParametrization/Pass.h @@ -0,0 +1,18 @@ + +#ifndef ZAMALANG_CONVERSION_MIDLFHEGLOBALPARAMETRIZATION_PASS_H_ +#define ZAMALANG_CONVERSION_MIDLFHEGLOBALPARAMETRIZATION_PASS_H_ + +#include "mlir/Pass/Pass.h" + +#include "zamalang/Conversion/Utils/GlobalFHEContext.h" + +namespace mlir { +namespace zamalang { +/// Create a pass to inject fhe parameters to the MidLFHE types and operators. +std::unique_ptr> +createConvertMidLFHEGlobalParametrizationPass( + mlir::zamalang::V0FHEContext &fheContext); +} // namespace zamalang +} // namespace mlir + +#endif \ No newline at end of file diff --git a/compiler/include/zamalang/Conversion/Passes.h b/compiler/include/zamalang/Conversion/Passes.h index 822ffd108..524fc8985 100644 --- a/compiler/include/zamalang/Conversion/Passes.h +++ b/compiler/include/zamalang/Conversion/Passes.h @@ -10,9 +10,11 @@ #include "zamalang/Conversion/HLFHEToMidLFHE/Pass.h" #include "zamalang/Conversion/LowLFHEToConcreteCAPI/Pass.h" #include "zamalang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h" +#include "zamalang/Conversion/MidLFHEGlobalParametrization/Pass.h" #include "zamalang/Conversion/MidLFHEToLowLFHE/Pass.h" #include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h" #include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h" +#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" #define GEN_PASS_CLASSES #include "zamalang/Conversion/Passes.h.inc" diff --git a/compiler/include/zamalang/Conversion/Passes.td b/compiler/include/zamalang/Conversion/Passes.td index 1781abcf6..f4c7f6cd5 100644 --- a/compiler/include/zamalang/Conversion/Passes.td +++ b/compiler/include/zamalang/Conversion/Passes.td @@ -17,6 +17,13 @@ def HLFHEToMidLFHE : Pass<"hlfhe-to-midlfhe", "mlir::ModuleOp"> { let dependentDialects = ["mlir::linalg::LinalgDialect"]; } +def MidLFHEGlobalParametrization : Pass<"midlfhe-global-parametrization", "mlir::ModuleOp"> { + let summary = "Inject global fhe parameters to the MidLFHE dialect"; + let constructor = "mlir::zamalang::createConvertMidLFHEToLowLFHEPass()"; + let options = []; + let dependentDialects = ["mlir::zamalang::MidLFHE::MidLFHEDialect"]; +} + def MidLFHEToLowLFHE : Pass<"midlfhe-to-lowlfhe", "mlir::ModuleOp"> { let summary = "Lowers operations from the MidLFHE dialect to LowLFHE"; let description = [{ Lowers operations from the MidLFHE dialect to LowLFHE }]; diff --git a/compiler/lib/Conversion/CMakeLists.txt b/compiler/lib/Conversion/CMakeLists.txt index 070482fac..c6646bdc9 100644 --- a/compiler/lib/Conversion/CMakeLists.txt +++ b/compiler/lib/Conversion/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(HLFHEToMidLFHE) +add_subdirectory(MidLFHEGlobalParametrization) add_subdirectory(MidLFHEToLowLFHE) add_subdirectory(HLFHETensorOpsToLinalg) add_subdirectory(LowLFHEToConcreteCAPI) diff --git a/compiler/lib/Conversion/MidLFHEGlobalParametrization/CMakeLists.txt b/compiler/lib/Conversion/MidLFHEGlobalParametrization/CMakeLists.txt new file mode 100644 index 000000000..1094979d8 --- /dev/null +++ b/compiler/lib/Conversion/MidLFHEGlobalParametrization/CMakeLists.txt @@ -0,0 +1,15 @@ +add_mlir_dialect_library(MidLFHEGlobalParametrization +MidLFHEGlobalParametrization.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/zamalang/Dialect/MidLFHE + + DEPENDS + MidLFHEDialect + + LINK_LIBS PUBLIC + MLIRIR + MLIRTransforms +) + +target_link_libraries(MidLFHEGlobalParametrization PUBLIC MLIRIR) diff --git a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp new file mode 100644 index 000000000..bb721c20a --- /dev/null +++ b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp @@ -0,0 +1,198 @@ +#include + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "zamalang/Conversion/Passes.h" +#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" +#include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.h" +#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h" + +namespace { +struct MidLFHEGlobalParametrizationPass + : public MidLFHEGlobalParametrizationBase< + MidLFHEGlobalParametrizationPass> { + MidLFHEGlobalParametrizationPass(mlir::zamalang::V0FHEContext &fheContext) + : fheContext(fheContext){}; + void runOnOperation() final; + mlir::zamalang::V0FHEContext &fheContext; +}; +} // namespace + +using mlir::zamalang::MidLFHE::GLWECipherTextType; + +/// MidLFHEGlobalParametrizationTypeConverter is a TypeConverter that transform +/// `MidLFHE.gwle<{_,_,_}{p}>` to +/// `MidLFHE.gwle<{glweSize,polynomialSize,bits}{p'}>` +class MidLFHEGlobalParametrizationTypeConverter : public mlir::TypeConverter { + +public: + MidLFHEGlobalParametrizationTypeConverter( + mlir::zamalang::V0FHEContext &fheContext) { + addConversion([](mlir::Type type) { return type; }); + addConversion([&](GLWECipherTextType type) { + auto glweSize = fheContext.parameter.getNBigGlweSize(); + auto p = fheContext.constraint.p; + if (type.getDimension() == glweSize && type.getP() == p) { + return type; + } + return GLWECipherTextType::get(type.getContext(), glweSize, + 1 /*for the v0, is always lwe ciphertext*/, + 64 /*for the v0 we handle only q=64*/, p); + }); + } +}; + +template +struct MidLFHEOpTypeConversionPattern : public mlir::OpRewritePattern { + MidLFHEOpTypeConversionPattern(mlir::MLIRContext *context, + mlir::TypeConverter &typeConverter, + mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern(context, benefit), + typeConverter(typeConverter) {} + + mlir::LogicalResult + matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override { + mlir::SmallVector newResultTypes; + if (typeConverter.convertTypes(op->getResultTypes(), newResultTypes) + .failed()) { + return mlir::failure(); + } + rewriter.replaceOpWithNewOp(op, newResultTypes, op->getOperands()); + return mlir::success(); + }; + +private: + mlir::TypeConverter &typeConverter; +}; + +struct MidLFHEApplyLookupTableParametrizationPattern + : public mlir::OpRewritePattern { + MidLFHEApplyLookupTableParametrizationPattern( + mlir::MLIRContext *context, mlir::TypeConverter &typeConverter, + mlir::zamalang::V0Parameter &v0Parameter, + mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern( + context, benefit), + typeConverter(typeConverter), v0Parameter(v0Parameter) {} + + mlir::LogicalResult + matchAndRewrite(mlir::zamalang::MidLFHE::ApplyLookupTable op, + mlir::PatternRewriter &rewriter) const override { + mlir::SmallVector newResultTypes; + if (typeConverter.convertTypes(op->getResultTypes(), newResultTypes) + .failed()) { + return mlir::failure(); + } + + mlir::SmallVector newAttributes{ + mlir::NamedAttribute(rewriter.getIdentifier("k"), + rewriter.getI32IntegerAttr(v0Parameter.k)), + mlir::NamedAttribute( + rewriter.getIdentifier("polynomialSize"), + rewriter.getI32IntegerAttr(v0Parameter.polynomialSize)), + mlir::NamedAttribute(rewriter.getIdentifier("levelKS"), + rewriter.getI32IntegerAttr(v0Parameter.ksLevel)), + mlir::NamedAttribute(rewriter.getIdentifier("baseLogKS"), + rewriter.getI32IntegerAttr(v0Parameter.ksLogBase)), + mlir::NamedAttribute(rewriter.getIdentifier("levelBS"), + rewriter.getI32IntegerAttr(v0Parameter.brLevel)), + mlir::NamedAttribute(rewriter.getIdentifier("baseLogBS"), + rewriter.getI32IntegerAttr(v0Parameter.brLogBase)), + }; + + rewriter.replaceOpWithNewOp( + op, newResultTypes, op->getOperands(), newAttributes); + + return mlir::success(); + }; + +private: + mlir::TypeConverter &typeConverter; + mlir::zamalang::V0Parameter &v0Parameter; +}; + +template +void populateWithMidLFHEOpTypeConversionPattern( + mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target, + mlir::TypeConverter &typeConverter) { + patterns.add>(patterns.getContext(), + typeConverter); + target.addDynamicallyLegalOp( + [&](Op op) { return typeConverter.isLegal(op->getResultTypes()); }); +} + +void populateWithMidLFHEApplyLookupTableParametrizationPattern( + mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target, + mlir::TypeConverter &typeConverter, + mlir::zamalang::V0Parameter &v0Parameter) { + patterns.add( + patterns.getContext(), typeConverter, v0Parameter); + target.addDynamicallyLegalOp( + [&](mlir::zamalang::MidLFHE::ApplyLookupTable op) { + if (op.k() != v0Parameter.k || + op.polynomialSize() != v0Parameter.polynomialSize || + op.levelKS() != v0Parameter.ksLevel || + op.baseLogKS() != v0Parameter.ksLogBase || + op.levelBS() != v0Parameter.brLevel || + op.baseLogBS() != v0Parameter.brLogBase) { + return false; + } + return typeConverter.isLegal(op->getResultTypes()); + }); +} + +/// Populate the RewritePatternSet with all patterns that rewrite LowLFHE +/// operators to the corresponding function call to the `Concrete C API`. +void populateWithMidLFHEOpTypeConversionPatterns( + mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target, + mlir::TypeConverter &typeConverter, + mlir::zamalang::V0Parameter &v0Parameter) { + populateWithMidLFHEOpTypeConversionPattern< + mlir::zamalang::MidLFHE::AddGLWEIntOp>(patterns, target, typeConverter); + populateWithMidLFHEOpTypeConversionPattern< + mlir::zamalang::MidLFHE::AddGLWEOp>(patterns, target, typeConverter); + populateWithMidLFHEOpTypeConversionPattern< + mlir::zamalang::MidLFHE::SubIntGLWEOp>(patterns, target, typeConverter); + populateWithMidLFHEOpTypeConversionPattern< + mlir::zamalang::MidLFHE::MulGLWEIntOp>(patterns, target, typeConverter); + populateWithMidLFHEApplyLookupTableParametrizationPattern( + patterns, target, typeConverter, v0Parameter); +} + +void MidLFHEGlobalParametrizationPass::runOnOperation() { + auto op = this->getOperation(); + + mlir::ConversionTarget target(getContext()); + MidLFHEGlobalParametrizationTypeConverter converter(fheContext); + + // Make sure that no ops from `MidLFHE` remain after the lowering + target.addIllegalDialect(); + + // Make sure func has legal signature + target.addDynamicallyLegalOp([&](mlir::FuncOp funcOp) { + return converter.isSignatureLegal(funcOp.getType()) && + converter.isLegal(&funcOp.getBody()); + }); + // Add all patterns required to lower all ops from `MidLFHE` to + // `LowLFHE` + mlir::OwningRewritePatternList patterns(&getContext()); + populateWithMidLFHEOpTypeConversionPatterns(patterns, target, converter, + fheContext.parameter); + mlir::populateFuncOpTypeConversionPattern(patterns, converter); + + // Apply conversion + if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) { + this->signalPassFailure(); + } +} + +namespace mlir { +namespace zamalang { +std::unique_ptr> +createConvertMidLFHEGlobalParametrizationPass( + mlir::zamalang::V0FHEContext &fheContext) { + return std::make_unique(fheContext); +} +} // namespace zamalang +} // namespace mlir diff --git a/compiler/lib/Support/CompilerTools.cpp b/compiler/lib/Support/CompilerTools.cpp index 9334b1186..f7d5853b8 100644 --- a/compiler/lib/Support/CompilerTools.cpp +++ b/compiler/lib/Support/CompilerTools.cpp @@ -44,7 +44,9 @@ mlir::LogicalResult CompilerTools::lowerHLFHEToMlirStdsDialect( addFilteredPassToPassManager( pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass); addFilteredPassToPassManager( - pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass); + pm, + mlir::zamalang::createConvertMidLFHEGlobalParametrizationPass(fheContext), + enablePass); addFilteredPassToPassManager( pm, mlir::zamalang::createConvertMidLFHEToLowLFHEPass(), enablePass); addFilteredPassToPassManager( diff --git a/compiler/tests/RunJit/hlfhe_add_eint_int_6.mlir b/compiler/tests/RunJit/hlfhe_add_eint_int_6.mlir new file mode 100644 index 000000000..88f86fc7c --- /dev/null +++ b/compiler/tests/RunJit/hlfhe_add_eint_int_6.mlir @@ -0,0 +1,7 @@ +// RUN: zamacompiler %s --run-jit --jit-args 10 --jit-args 54 2>&1| FileCheck %s + +// CHECK-LABEL: 64 +func @main(%arg0: !HLFHE.eint<6>, %arg1: i7) -> !HLFHE.eint<6> { + %1 = "HLFHE.add_eint_int"(%arg0, %arg1): (!HLFHE.eint<6>, i7) -> (!HLFHE.eint<6>) + return %1: !HLFHE.eint<6> +} \ No newline at end of file diff --git a/compiler/tests/RunJit/hlfhe_add_eint_int.mlir b/compiler/tests/RunJit/hlfhe_add_eint_int_7.mlir similarity index 63% rename from compiler/tests/RunJit/hlfhe_add_eint_int.mlir rename to compiler/tests/RunJit/hlfhe_add_eint_int_7.mlir index 0898b77bd..719302f5b 100644 --- a/compiler/tests/RunJit/hlfhe_add_eint_int.mlir +++ b/compiler/tests/RunJit/hlfhe_add_eint_int_7.mlir @@ -1,6 +1,6 @@ -// RUN: zamacompiler %s --run-jit --jit-args 12 --jit-args 30 2>&1| FileCheck %s +// RUN: zamacompiler %s --run-jit --jit-args 100 --jit-args 27 2>&1| FileCheck %s -// CHECK-LABEL: 42 +// CHECK-LABEL: 127 func @main(%arg0: !HLFHE.eint<7>, %arg1: i8) -> !HLFHE.eint<7> { %1 = "HLFHE.add_eint_int"(%arg0, %arg1): (!HLFHE.eint<7>, i8) -> (!HLFHE.eint<7>) return %1: !HLFHE.eint<7>