From 27ca5122bcfe1e4f5fb020a586a61bf92e55889c Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Fri, 17 Dec 2021 15:25:45 +0100 Subject: [PATCH] enhance(compiler): Use named constant for the default pattern rewriting benefit This introduces a new header file `zamalang/Support/Constants.h` for constants, currently only populated with a constant for the default pattern rewriting benefit of 1. --- compiler/include/zamalang/Support/Constants.h | 10 ++++ .../TensorOpsToLinalg.cpp | 30 +++++++----- .../LowLFHEToConcreteCAPI.cpp | 46 +++++++++++-------- .../LowLFHEUnparametrize.cpp | 6 ++- .../MidLFHEGlobalParametrization.cpp | 14 +++--- 5 files changed, 67 insertions(+), 39 deletions(-) create mode 100644 compiler/include/zamalang/Support/Constants.h diff --git a/compiler/include/zamalang/Support/Constants.h b/compiler/include/zamalang/Support/Constants.h new file mode 100644 index 000000000..554e78598 --- /dev/null +++ b/compiler/include/zamalang/Support/Constants.h @@ -0,0 +1,10 @@ +#ifndef ZAMALANG_SUPPORT_CONSTANTS_H_ +#define ZAMALANG_SUPPORT_CONSTANTS_H_ + +namespace mlir { +namespace zamalang { +constexpr unsigned DEFAULT_PATTERN_BENEFIT = 1; +} // namespace zamalang +} // namespace mlir + +#endif diff --git a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp index 0fb1d8433..b1103504a 100644 --- a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp +++ b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp @@ -15,12 +15,13 @@ #include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h" #include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h" #include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h" +#include "zamalang/Support/Constants.h" struct DotToLinalgGeneric : public ::mlir::OpRewritePattern { DotToLinalgGeneric(::mlir::MLIRContext *context) - : ::mlir::OpRewritePattern<::mlir::zamalang::HLFHELinalg::Dot>(context, - 1) {} + : ::mlir::OpRewritePattern<::mlir::zamalang::HLFHELinalg::Dot>( + context, mlir::zamalang::DEFAULT_PATTERN_BENEFIT) {} // This rewrite pattern transforms any instance of // `HLFHELinalg.dot_eint_int` to an instance of `linalg.generic` with an @@ -216,8 +217,9 @@ getBroadcastedAffineMapMultiLUT(const mlir::RankedTensorType &resultType, template struct HLFHELinalgOpToLinalgGeneric : public mlir::OpRewritePattern { - HLFHELinalgOpToLinalgGeneric(::mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) + HLFHELinalgOpToLinalgGeneric( + ::mlir::MLIRContext *context, + mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT) : ::mlir::OpRewritePattern(context, benefit) {} ::mlir::LogicalResult @@ -322,7 +324,8 @@ struct HLFHELinalgApplyMultiLookupTableToLinalgGeneric : public mlir::OpRewritePattern< mlir::zamalang::HLFHELinalg::ApplyMultiLookupTableEintOp> { HLFHELinalgApplyMultiLookupTableToLinalgGeneric( - ::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) + ::mlir::MLIRContext *context, + mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT) : ::mlir::OpRewritePattern< mlir::zamalang::HLFHELinalg::ApplyMultiLookupTableEintOp>(context, benefit) { @@ -441,8 +444,9 @@ struct HLFHELinalgApplyMultiLookupTableToLinalgGeneric struct HLFHELinalgApplyLookupTableToLinalgGeneric : public mlir::OpRewritePattern< mlir::zamalang::HLFHELinalg::ApplyLookupTableEintOp> { - HLFHELinalgApplyLookupTableToLinalgGeneric(::mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) + HLFHELinalgApplyLookupTableToLinalgGeneric( + ::mlir::MLIRContext *context, + mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT) : ::mlir::OpRewritePattern< mlir::zamalang::HLFHELinalg::ApplyLookupTableEintOp>(context, benefit) {} @@ -538,8 +542,9 @@ struct HLFHELinalgApplyLookupTableToLinalgGeneric // struct HLFHELinalgNegEintToLinalgGeneric : public mlir::OpRewritePattern { - HLFHELinalgNegEintToLinalgGeneric(::mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) + HLFHELinalgNegEintToLinalgGeneric( + ::mlir::MLIRContext *context, + mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT) : ::mlir::OpRewritePattern( context, benefit) {} @@ -649,7 +654,7 @@ struct HLFHELinalgMatmulToLinalgGeneric mlir::OpBuilder &, mlir::Location, mlir::Type, mlir::Value, mlir::Value)> createMulOp, - mlir::PatternBenefit benefit = 1) + mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT) : ::mlir::OpRewritePattern(context, benefit), createMulOp(createMulOp) {} @@ -754,8 +759,9 @@ private: // struct HLFHELinalgZeroToLinalgGenerate : public mlir::OpRewritePattern { - HLFHELinalgZeroToLinalgGenerate(::mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) + HLFHELinalgZeroToLinalgGenerate( + ::mlir::MLIRContext *context, + mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT) : ::mlir::OpRewritePattern(context, benefit) { } diff --git a/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp b/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp index e4a6e1e4c..eea3a0551 100644 --- a/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp +++ b/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp @@ -10,6 +10,7 @@ #include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h" #include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h" #include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h" +#include "zamalang/Support/Constants.h" class LowLFHEToConcreteCAPITypeConverter : public mlir::TypeConverter { @@ -333,10 +334,10 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op, /// ``` template struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern { - LowLFHEOpToConcreteCAPICallPattern(mlir::MLIRContext *context, - mlir::StringRef funcName, - mlir::StringRef allocName, - mlir::PatternBenefit benefit = 1) + LowLFHEOpToConcreteCAPICallPattern( + mlir::MLIRContext *context, mlir::StringRef funcName, + mlir::StringRef allocName, + mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT) : mlir::OpRewritePattern(context, benefit), funcName(funcName), allocName(allocName) {} @@ -400,8 +401,9 @@ private: struct LowLFHEZeroOpPattern : public mlir::OpRewritePattern { - LowLFHEZeroOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) + LowLFHEZeroOpPattern( + mlir::MLIRContext *context, + mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT) : mlir::OpRewritePattern(context, benefit) {} @@ -435,8 +437,9 @@ struct LowLFHEZeroOpPattern struct LowLFHEEncodeIntOpPattern : public mlir::OpRewritePattern { - LowLFHEEncodeIntOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) + LowLFHEEncodeIntOpPattern( + mlir::MLIRContext *context, + mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT) : mlir::OpRewritePattern(context, benefit) {} @@ -459,8 +462,9 @@ struct LowLFHEEncodeIntOpPattern struct LowLFHEIntToCleartextOpPattern : public mlir::OpRewritePattern { - LowLFHEIntToCleartextOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) + LowLFHEIntToCleartextOpPattern( + mlir::MLIRContext *context, + mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT) : mlir::OpRewritePattern( context, benefit) {} @@ -483,8 +487,9 @@ struct LowLFHEIntToCleartextOpPattern // allocated GLWE struct GlweFromTableOpPattern : public mlir::OpRewritePattern { - GlweFromTableOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) + GlweFromTableOpPattern( + mlir::MLIRContext *context, + mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT) : mlir::OpRewritePattern( context, benefit) {} @@ -600,8 +605,9 @@ mlir::Value getContextArgument(mlir::Operation *op) { // ciphertext struct LowLFHEBootstrapLweOpPattern : public mlir::OpRewritePattern { - LowLFHEBootstrapLweOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) + LowLFHEBootstrapLweOpPattern( + mlir::MLIRContext *context, + mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT) : mlir::OpRewritePattern( context, benefit) {} @@ -663,8 +669,9 @@ struct LowLFHEBootstrapLweOpPattern // - use the key to keyswitch the input ciphertext struct LowLFHEKeySwitchLweOpPattern : public mlir::OpRewritePattern { - LowLFHEKeySwitchLweOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) + LowLFHEKeySwitchLweOpPattern( + mlir::MLIRContext *context, + mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT) : mlir::OpRewritePattern( context, benefit) {} @@ -743,8 +750,9 @@ void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns) { struct AddRuntimeContextToFuncOpPattern : public mlir::OpRewritePattern { - AddRuntimeContextToFuncOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) + AddRuntimeContextToFuncOpPattern( + mlir::MLIRContext *context, + mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT) : mlir::OpRewritePattern(context, benefit) {} mlir::LogicalResult @@ -869,4 +877,4 @@ createConvertLowLFHEToConcreteCAPIPass() { return std::make_unique(); } } // namespace zamalang -} // namespace mlir \ No newline at end of file +} // namespace mlir diff --git a/compiler/lib/Conversion/LowLFHEUnparametrize/LowLFHEUnparametrize.cpp b/compiler/lib/Conversion/LowLFHEUnparametrize/LowLFHEUnparametrize.cpp index 4fb6638f0..cb8d63899 100644 --- a/compiler/lib/Conversion/LowLFHEUnparametrize/LowLFHEUnparametrize.cpp +++ b/compiler/lib/Conversion/LowLFHEUnparametrize/LowLFHEUnparametrize.cpp @@ -7,6 +7,7 @@ #include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h" #include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h" #include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h" +#include "zamalang/Support/Constants.h" /// LowLFHEUnparametrizeTypeConverter is a type converter that unparametrize /// LowLFHE types @@ -46,8 +47,9 @@ public: /// t1 are a LowLFHE type. struct LowLFHEUnrealizedCastReplacementPattern : public mlir::OpRewritePattern { - LowLFHEUnrealizedCastReplacementPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) + LowLFHEUnrealizedCastReplacementPattern( + mlir::MLIRContext *context, + mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT) : mlir::OpRewritePattern(context, benefit) {} diff --git a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp index 86f22a7d2..c203c67cc 100644 --- a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp +++ b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp @@ -7,6 +7,7 @@ #include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h" +#include "zamalang/Support/Constants.h" namespace { struct MidLFHEGlobalParametrizationPass @@ -60,9 +61,9 @@ public: template struct MidLFHEOpTypeConversionPattern : public mlir::OpRewritePattern { - MidLFHEOpTypeConversionPattern(mlir::MLIRContext *context, - mlir::TypeConverter &typeConverter, - mlir::PatternBenefit benefit = 1) + MidLFHEOpTypeConversionPattern( + mlir::MLIRContext *context, mlir::TypeConverter &typeConverter, + mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT) : mlir::OpRewritePattern(context, benefit), typeConverter(typeConverter) {} @@ -86,7 +87,7 @@ struct MidLFHEApplyLookupTableParametrizationPattern MidLFHEApplyLookupTableParametrizationPattern( mlir::MLIRContext *context, mlir::TypeConverter &typeConverter, mlir::zamalang::V0Parameter &v0Parameter, - mlir::PatternBenefit benefit = 1) + mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT) : mlir::OpRewritePattern( context, benefit), typeConverter(typeConverter), v0Parameter(v0Parameter) {} @@ -133,8 +134,9 @@ private: struct MidLFHEApplyLookupTablePaddingPattern : public mlir::OpRewritePattern { - MidLFHEApplyLookupTablePaddingPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) + MidLFHEApplyLookupTablePaddingPattern( + mlir::MLIRContext *context, + mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT) : mlir::OpRewritePattern( context, benefit), typeConverter(typeConverter), v0Parameter(v0Parameter) {}