enhance(compiler): Use named constant for the default pattern rewriting benefit

This introduces a new header file `zamalang/Support/Constants.h` for
constants, currently only populated with a constant for the default
pattern rewriting benefit of 1.
This commit is contained in:
Andi Drebes
2021-12-17 15:25:45 +01:00
parent 2e5bff93fd
commit 27ca5122bc
5 changed files with 67 additions and 39 deletions

View File

@@ -0,0 +1,10 @@
#ifndef ZAMALANG_SUPPORT_CONSTANTS_H_
#define ZAMALANG_SUPPORT_CONSTANTS_H_
namespace mlir {
namespace zamalang {
constexpr unsigned DEFAULT_PATTERN_BENEFIT = 1;
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -15,12 +15,13 @@
#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h"
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h"
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h"
#include "zamalang/Support/Constants.h"
struct DotToLinalgGeneric
: public ::mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::Dot> {
DotToLinalgGeneric(::mlir::MLIRContext *context)
: ::mlir::OpRewritePattern<::mlir::zamalang::HLFHELinalg::Dot>(context,
1) {}
: ::mlir::OpRewritePattern<::mlir::zamalang::HLFHELinalg::Dot>(
context, mlir::zamalang::DEFAULT_PATTERN_BENEFIT) {}
// This rewrite pattern transforms any instance of
// `HLFHELinalg.dot_eint_int` to an instance of `linalg.generic` with an
@@ -216,8 +217,9 @@ getBroadcastedAffineMapMultiLUT(const mlir::RankedTensorType &resultType,
template <typename HLFHELinalgOp, typename HLFHEOp>
struct HLFHELinalgOpToLinalgGeneric
: public mlir::OpRewritePattern<HLFHELinalgOp> {
HLFHELinalgOpToLinalgGeneric(::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
HLFHELinalgOpToLinalgGeneric(
::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
: ::mlir::OpRewritePattern<HLFHELinalgOp>(context, benefit) {}
::mlir::LogicalResult
@@ -322,7 +324,8 @@ struct HLFHELinalgApplyMultiLookupTableToLinalgGeneric
: public mlir::OpRewritePattern<
mlir::zamalang::HLFHELinalg::ApplyMultiLookupTableEintOp> {
HLFHELinalgApplyMultiLookupTableToLinalgGeneric(
::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
: ::mlir::OpRewritePattern<
mlir::zamalang::HLFHELinalg::ApplyMultiLookupTableEintOp>(context,
benefit) {
@@ -441,8 +444,9 @@ struct HLFHELinalgApplyMultiLookupTableToLinalgGeneric
struct HLFHELinalgApplyLookupTableToLinalgGeneric
: public mlir::OpRewritePattern<
mlir::zamalang::HLFHELinalg::ApplyLookupTableEintOp> {
HLFHELinalgApplyLookupTableToLinalgGeneric(::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
HLFHELinalgApplyLookupTableToLinalgGeneric(
::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
: ::mlir::OpRewritePattern<
mlir::zamalang::HLFHELinalg::ApplyLookupTableEintOp>(context,
benefit) {}
@@ -538,8 +542,9 @@ struct HLFHELinalgApplyLookupTableToLinalgGeneric
//
struct HLFHELinalgNegEintToLinalgGeneric
: public mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::NegEintOp> {
HLFHELinalgNegEintToLinalgGeneric(::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
HLFHELinalgNegEintToLinalgGeneric(
::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
: ::mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::NegEintOp>(
context, benefit) {}
@@ -649,7 +654,7 @@ struct HLFHELinalgMatmulToLinalgGeneric
mlir::OpBuilder &, mlir::Location, mlir::Type, mlir::Value,
mlir::Value)>
createMulOp,
mlir::PatternBenefit benefit = 1)
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
: ::mlir::OpRewritePattern<HLFHELinalgMatmulOp>(context, benefit),
createMulOp(createMulOp) {}
@@ -754,8 +759,9 @@ private:
//
struct HLFHELinalgZeroToLinalgGenerate
: public mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::ZeroOp> {
HLFHELinalgZeroToLinalgGenerate(::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
HLFHELinalgZeroToLinalgGenerate(
::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
: ::mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::ZeroOp>(context,
benefit) {
}

View File

@@ -10,6 +10,7 @@
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "zamalang/Support/Constants.h"
class LowLFHEToConcreteCAPITypeConverter : public mlir::TypeConverter {
@@ -333,10 +334,10 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
/// ```
template <typename Op>
struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
LowLFHEOpToConcreteCAPICallPattern(mlir::MLIRContext *context,
mlir::StringRef funcName,
mlir::StringRef allocName,
mlir::PatternBenefit benefit = 1)
LowLFHEOpToConcreteCAPICallPattern(
mlir::MLIRContext *context, mlir::StringRef funcName,
mlir::StringRef allocName,
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<Op>(context, benefit), funcName(funcName),
allocName(allocName) {}
@@ -400,8 +401,9 @@ private:
struct LowLFHEZeroOpPattern
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::ZeroLWEOp> {
LowLFHEZeroOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
LowLFHEZeroOpPattern(
mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::zamalang::LowLFHE::ZeroLWEOp>(context,
benefit) {}
@@ -435,8 +437,9 @@ struct LowLFHEZeroOpPattern
struct LowLFHEEncodeIntOpPattern
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::EncodeIntOp> {
LowLFHEEncodeIntOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
LowLFHEEncodeIntOpPattern(
mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::zamalang::LowLFHE::EncodeIntOp>(context,
benefit) {}
@@ -459,8 +462,9 @@ struct LowLFHEEncodeIntOpPattern
struct LowLFHEIntToCleartextOpPattern
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::IntToCleartextOp> {
LowLFHEIntToCleartextOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
LowLFHEIntToCleartextOpPattern(
mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::zamalang::LowLFHE::IntToCleartextOp>(
context, benefit) {}
@@ -483,8 +487,9 @@ struct LowLFHEIntToCleartextOpPattern
// allocated GLWE
struct GlweFromTableOpPattern
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::GlweFromTable> {
GlweFromTableOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
GlweFromTableOpPattern(
mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::zamalang::LowLFHE::GlweFromTable>(
context, benefit) {}
@@ -600,8 +605,9 @@ mlir::Value getContextArgument(mlir::Operation *op) {
// ciphertext
struct LowLFHEBootstrapLweOpPattern
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::BootstrapLweOp> {
LowLFHEBootstrapLweOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
LowLFHEBootstrapLweOpPattern(
mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::zamalang::LowLFHE::BootstrapLweOp>(
context, benefit) {}
@@ -663,8 +669,9 @@ struct LowLFHEBootstrapLweOpPattern
// - use the key to keyswitch the input ciphertext
struct LowLFHEKeySwitchLweOpPattern
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::KeySwitchLweOp> {
LowLFHEKeySwitchLweOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
LowLFHEKeySwitchLweOpPattern(
mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::zamalang::LowLFHE::KeySwitchLweOp>(
context, benefit) {}
@@ -743,8 +750,9 @@ void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns) {
struct AddRuntimeContextToFuncOpPattern
: public mlir::OpRewritePattern<mlir::FuncOp> {
AddRuntimeContextToFuncOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
AddRuntimeContextToFuncOpPattern(
mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::FuncOp>(context, benefit) {}
mlir::LogicalResult
@@ -869,4 +877,4 @@ createConvertLowLFHEToConcreteCAPIPass() {
return std::make_unique<LowLFHEToConcreteCAPIPass>();
}
} // namespace zamalang
} // namespace mlir
} // namespace mlir

View File

@@ -7,6 +7,7 @@
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "zamalang/Support/Constants.h"
/// LowLFHEUnparametrizeTypeConverter is a type converter that unparametrize
/// LowLFHE types
@@ -46,8 +47,9 @@ public:
/// t1 are a LowLFHE type.
struct LowLFHEUnrealizedCastReplacementPattern
: public mlir::OpRewritePattern<mlir::UnrealizedConversionCastOp> {
LowLFHEUnrealizedCastReplacementPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
LowLFHEUnrealizedCastReplacementPattern(
mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::UnrealizedConversionCastOp>(context,
benefit) {}

View File

@@ -7,6 +7,7 @@
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
#include "zamalang/Support/Constants.h"
namespace {
struct MidLFHEGlobalParametrizationPass
@@ -60,9 +61,9 @@ public:
template <typename Op>
struct MidLFHEOpTypeConversionPattern : public mlir::OpRewritePattern<Op> {
MidLFHEOpTypeConversionPattern(mlir::MLIRContext *context,
mlir::TypeConverter &typeConverter,
mlir::PatternBenefit benefit = 1)
MidLFHEOpTypeConversionPattern(
mlir::MLIRContext *context, mlir::TypeConverter &typeConverter,
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<Op>(context, benefit),
typeConverter(typeConverter) {}
@@ -86,7 +87,7 @@ struct MidLFHEApplyLookupTableParametrizationPattern
MidLFHEApplyLookupTableParametrizationPattern(
mlir::MLIRContext *context, mlir::TypeConverter &typeConverter,
mlir::zamalang::V0Parameter &v0Parameter,
mlir::PatternBenefit benefit = 1)
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::zamalang::MidLFHE::ApplyLookupTable>(
context, benefit),
typeConverter(typeConverter), v0Parameter(v0Parameter) {}
@@ -133,8 +134,9 @@ private:
struct MidLFHEApplyLookupTablePaddingPattern
: public mlir::OpRewritePattern<mlir::zamalang::MidLFHE::ApplyLookupTable> {
MidLFHEApplyLookupTablePaddingPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
MidLFHEApplyLookupTablePaddingPattern(
mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::zamalang::MidLFHE::ApplyLookupTable>(
context, benefit),
typeConverter(typeConverter), v0Parameter(v0Parameter) {}