mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
enhance(compiler): Use named constant for the default pattern rewriting benefit
This introduces a new header file `zamalang/Support/Constants.h` for constants, currently only populated with a constant for the default pattern rewriting benefit of 1.
This commit is contained in:
10
compiler/include/zamalang/Support/Constants.h
Normal file
10
compiler/include/zamalang/Support/Constants.h
Normal file
@@ -0,0 +1,10 @@
|
||||
#ifndef ZAMALANG_SUPPORT_CONSTANTS_H_
|
||||
#define ZAMALANG_SUPPORT_CONSTANTS_H_
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
constexpr unsigned DEFAULT_PATTERN_BENEFIT = 1;
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -15,12 +15,13 @@
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h"
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h"
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h"
|
||||
#include "zamalang/Support/Constants.h"
|
||||
|
||||
struct DotToLinalgGeneric
|
||||
: public ::mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::Dot> {
|
||||
DotToLinalgGeneric(::mlir::MLIRContext *context)
|
||||
: ::mlir::OpRewritePattern<::mlir::zamalang::HLFHELinalg::Dot>(context,
|
||||
1) {}
|
||||
: ::mlir::OpRewritePattern<::mlir::zamalang::HLFHELinalg::Dot>(
|
||||
context, mlir::zamalang::DEFAULT_PATTERN_BENEFIT) {}
|
||||
|
||||
// This rewrite pattern transforms any instance of
|
||||
// `HLFHELinalg.dot_eint_int` to an instance of `linalg.generic` with an
|
||||
@@ -216,8 +217,9 @@ getBroadcastedAffineMapMultiLUT(const mlir::RankedTensorType &resultType,
|
||||
template <typename HLFHELinalgOp, typename HLFHEOp>
|
||||
struct HLFHELinalgOpToLinalgGeneric
|
||||
: public mlir::OpRewritePattern<HLFHELinalgOp> {
|
||||
HLFHELinalgOpToLinalgGeneric(::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
HLFHELinalgOpToLinalgGeneric(
|
||||
::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
|
||||
: ::mlir::OpRewritePattern<HLFHELinalgOp>(context, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
@@ -322,7 +324,8 @@ struct HLFHELinalgApplyMultiLookupTableToLinalgGeneric
|
||||
: public mlir::OpRewritePattern<
|
||||
mlir::zamalang::HLFHELinalg::ApplyMultiLookupTableEintOp> {
|
||||
HLFHELinalgApplyMultiLookupTableToLinalgGeneric(
|
||||
::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
|
||||
: ::mlir::OpRewritePattern<
|
||||
mlir::zamalang::HLFHELinalg::ApplyMultiLookupTableEintOp>(context,
|
||||
benefit) {
|
||||
@@ -441,8 +444,9 @@ struct HLFHELinalgApplyMultiLookupTableToLinalgGeneric
|
||||
struct HLFHELinalgApplyLookupTableToLinalgGeneric
|
||||
: public mlir::OpRewritePattern<
|
||||
mlir::zamalang::HLFHELinalg::ApplyLookupTableEintOp> {
|
||||
HLFHELinalgApplyLookupTableToLinalgGeneric(::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
HLFHELinalgApplyLookupTableToLinalgGeneric(
|
||||
::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
|
||||
: ::mlir::OpRewritePattern<
|
||||
mlir::zamalang::HLFHELinalg::ApplyLookupTableEintOp>(context,
|
||||
benefit) {}
|
||||
@@ -538,8 +542,9 @@ struct HLFHELinalgApplyLookupTableToLinalgGeneric
|
||||
//
|
||||
struct HLFHELinalgNegEintToLinalgGeneric
|
||||
: public mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::NegEintOp> {
|
||||
HLFHELinalgNegEintToLinalgGeneric(::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
HLFHELinalgNegEintToLinalgGeneric(
|
||||
::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
|
||||
: ::mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::NegEintOp>(
|
||||
context, benefit) {}
|
||||
|
||||
@@ -649,7 +654,7 @@ struct HLFHELinalgMatmulToLinalgGeneric
|
||||
mlir::OpBuilder &, mlir::Location, mlir::Type, mlir::Value,
|
||||
mlir::Value)>
|
||||
createMulOp,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
|
||||
: ::mlir::OpRewritePattern<HLFHELinalgMatmulOp>(context, benefit),
|
||||
createMulOp(createMulOp) {}
|
||||
|
||||
@@ -754,8 +759,9 @@ private:
|
||||
//
|
||||
struct HLFHELinalgZeroToLinalgGenerate
|
||||
: public mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::ZeroOp> {
|
||||
HLFHELinalgZeroToLinalgGenerate(::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
HLFHELinalgZeroToLinalgGenerate(
|
||||
::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
|
||||
: ::mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::ZeroOp>(context,
|
||||
benefit) {
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
|
||||
#include "zamalang/Support/Constants.h"
|
||||
|
||||
class LowLFHEToConcreteCAPITypeConverter : public mlir::TypeConverter {
|
||||
|
||||
@@ -333,10 +334,10 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
|
||||
/// ```
|
||||
template <typename Op>
|
||||
struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
|
||||
LowLFHEOpToConcreteCAPICallPattern(mlir::MLIRContext *context,
|
||||
mlir::StringRef funcName,
|
||||
mlir::StringRef allocName,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
LowLFHEOpToConcreteCAPICallPattern(
|
||||
mlir::MLIRContext *context, mlir::StringRef funcName,
|
||||
mlir::StringRef allocName,
|
||||
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<Op>(context, benefit), funcName(funcName),
|
||||
allocName(allocName) {}
|
||||
|
||||
@@ -400,8 +401,9 @@ private:
|
||||
|
||||
struct LowLFHEZeroOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::ZeroLWEOp> {
|
||||
LowLFHEZeroOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
LowLFHEZeroOpPattern(
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::zamalang::LowLFHE::ZeroLWEOp>(context,
|
||||
benefit) {}
|
||||
|
||||
@@ -435,8 +437,9 @@ struct LowLFHEZeroOpPattern
|
||||
|
||||
struct LowLFHEEncodeIntOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::EncodeIntOp> {
|
||||
LowLFHEEncodeIntOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
LowLFHEEncodeIntOpPattern(
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::zamalang::LowLFHE::EncodeIntOp>(context,
|
||||
benefit) {}
|
||||
|
||||
@@ -459,8 +462,9 @@ struct LowLFHEEncodeIntOpPattern
|
||||
|
||||
struct LowLFHEIntToCleartextOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::IntToCleartextOp> {
|
||||
LowLFHEIntToCleartextOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
LowLFHEIntToCleartextOpPattern(
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::zamalang::LowLFHE::IntToCleartextOp>(
|
||||
context, benefit) {}
|
||||
|
||||
@@ -483,8 +487,9 @@ struct LowLFHEIntToCleartextOpPattern
|
||||
// allocated GLWE
|
||||
struct GlweFromTableOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::GlweFromTable> {
|
||||
GlweFromTableOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
GlweFromTableOpPattern(
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::zamalang::LowLFHE::GlweFromTable>(
|
||||
context, benefit) {}
|
||||
|
||||
@@ -600,8 +605,9 @@ mlir::Value getContextArgument(mlir::Operation *op) {
|
||||
// ciphertext
|
||||
struct LowLFHEBootstrapLweOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::BootstrapLweOp> {
|
||||
LowLFHEBootstrapLweOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
LowLFHEBootstrapLweOpPattern(
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::zamalang::LowLFHE::BootstrapLweOp>(
|
||||
context, benefit) {}
|
||||
|
||||
@@ -663,8 +669,9 @@ struct LowLFHEBootstrapLweOpPattern
|
||||
// - use the key to keyswitch the input ciphertext
|
||||
struct LowLFHEKeySwitchLweOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::KeySwitchLweOp> {
|
||||
LowLFHEKeySwitchLweOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
LowLFHEKeySwitchLweOpPattern(
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::zamalang::LowLFHE::KeySwitchLweOp>(
|
||||
context, benefit) {}
|
||||
|
||||
@@ -743,8 +750,9 @@ void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns) {
|
||||
|
||||
struct AddRuntimeContextToFuncOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::FuncOp> {
|
||||
AddRuntimeContextToFuncOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
AddRuntimeContextToFuncOpPattern(
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::FuncOp>(context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
@@ -869,4 +877,4 @@ createConvertLowLFHEToConcreteCAPIPass() {
|
||||
return std::make_unique<LowLFHEToConcreteCAPIPass>();
|
||||
}
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
} // namespace mlir
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
|
||||
#include "zamalang/Support/Constants.h"
|
||||
|
||||
/// LowLFHEUnparametrizeTypeConverter is a type converter that unparametrize
|
||||
/// LowLFHE types
|
||||
@@ -46,8 +47,9 @@ public:
|
||||
/// t1 are a LowLFHE type.
|
||||
struct LowLFHEUnrealizedCastReplacementPattern
|
||||
: public mlir::OpRewritePattern<mlir::UnrealizedConversionCastOp> {
|
||||
LowLFHEUnrealizedCastReplacementPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
LowLFHEUnrealizedCastReplacementPattern(
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::UnrealizedConversionCastOp>(context,
|
||||
benefit) {}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
|
||||
#include "zamalang/Support/Constants.h"
|
||||
|
||||
namespace {
|
||||
struct MidLFHEGlobalParametrizationPass
|
||||
@@ -60,9 +61,9 @@ public:
|
||||
|
||||
template <typename Op>
|
||||
struct MidLFHEOpTypeConversionPattern : public mlir::OpRewritePattern<Op> {
|
||||
MidLFHEOpTypeConversionPattern(mlir::MLIRContext *context,
|
||||
mlir::TypeConverter &typeConverter,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
MidLFHEOpTypeConversionPattern(
|
||||
mlir::MLIRContext *context, mlir::TypeConverter &typeConverter,
|
||||
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<Op>(context, benefit),
|
||||
typeConverter(typeConverter) {}
|
||||
|
||||
@@ -86,7 +87,7 @@ struct MidLFHEApplyLookupTableParametrizationPattern
|
||||
MidLFHEApplyLookupTableParametrizationPattern(
|
||||
mlir::MLIRContext *context, mlir::TypeConverter &typeConverter,
|
||||
mlir::zamalang::V0Parameter &v0Parameter,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::zamalang::MidLFHE::ApplyLookupTable>(
|
||||
context, benefit),
|
||||
typeConverter(typeConverter), v0Parameter(v0Parameter) {}
|
||||
@@ -133,8 +134,9 @@ private:
|
||||
|
||||
struct MidLFHEApplyLookupTablePaddingPattern
|
||||
: public mlir::OpRewritePattern<mlir::zamalang::MidLFHE::ApplyLookupTable> {
|
||||
MidLFHEApplyLookupTablePaddingPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
MidLFHEApplyLookupTablePaddingPattern(
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::zamalang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::zamalang::MidLFHE::ApplyLookupTable>(
|
||||
context, benefit),
|
||||
typeConverter(typeConverter), v0Parameter(v0Parameter) {}
|
||||
|
||||
Reference in New Issue
Block a user