refactor(compiler): FHE to TFHE: Use OpConversionPattern for dialect conversion

Use `OpConversionPattern` instead of `OpRewritePattern` for operation
conversion during dialect conversion. This makes explicit and in-place
type conversions unnecessary, since `OpConversionPattern` already
properly converts operand types and provides them to the rewrite rule
through an operation adaptor.

The main contributions of this commit are the two class templates
`TypeConvertingReinstantiationPattern` and
`GenericOneToOneOpConversionPattern`.

The former allows for the definition of a simple replacement rule that
re-instantiates an operation after the types of its operands have been
converted. This is especially useful for type-polymorphic operations
during dialect conversion.
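
For example, with these utilities in place, handling a type-polymorphic
operation such as `tensor.extract` reduces to registering the pattern and
the matching legality rule. A minimal sketch, mirroring the updated tensor
patterns further down (the wrapper function is only for illustration):

void populateExampleExtractPattern(mlir::RewritePatternSet &patterns,
                                   mlir::ConversionTarget &target,
                                   mlir::TypeConverter &typeConverter) {
  // Re-instantiate tensor.extract with converted operand and result types;
  // the second template parameter (default `false`) controls attribute copying.
  patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
      mlir::tensor::ExtractOp>>(patterns.getContext(), typeConverter);
  mlir::concretelang::addDynamicallyLegalTypeOp<mlir::tensor::ExtractOp>(
      target, typeConverter);
}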

The latter allows for the definition of patterns in which one
operation is replaced with a different operation after its operands
have been converted.
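
For example, the scalar and CRT passes below lower `FHE::zero_eint` to
`TFHE::ZeroGLWEOp` with a single registration (sketch; `patterns` and
`converter` as in the pass code):

// One-to-one replacement: result types are converted and operands are taken
// from the adaptor; a third template argument of `true` would copy attributes.
patterns.add<mlir::concretelang::GenericOneToOneOpConversionPattern<
    FHE::ZeroEintOp, TFHE::ZeroGLWEOp>>(patterns.getContext(), converter);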

The default implementations of the class templates provide conversion
rules for operations that have a generic builder method taking the
desired return type(s), the operands, and (optionally) a set of
attributes. How attributes are discarded during a conversion (either
by omitting the builder argument or by passing an empty set of
attributes) can be defined through specialization of
`ReinstantiationAttributeDismissalStrategy`.
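
A hypothetical specialization for an operation `mydialect::FooOp` (both names
made up for illustration) that prefers an empty attribute vector over omitting
the builder argument would look like this:

namespace mlir {
namespace concretelang {
// Hypothetical example: dismiss FooOp's attributes by passing an empty
// vector of attributes to its builder instead of leaving the argument out.
template <>
struct ReinstantiationAttributeDismissalStrategy<mydialect::FooOp> {
  typedef ReinstantiationAttributeHandling::pass_empty_vector strategy;
};
} // namespace concretelang
} // namespace mlir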

Custom replacement rules that deviate from the scheme above should be
implemented by specializing
`TypeConvertingReinstantiationPattern::matchAndRewrite()` or
`GenericOneToOneOpConversionPattern::matchAndRewrite()`.
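
This commit does exactly that for `scf::ForOp` and several tensor ops (see
`Utils/Dialects/SCF.{h,cpp}` and `Utils/Dialects/Tensor.{h,cpp}` below); the
general shape of such an out-of-class specialization is:

template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<scf::ForOp, false>::matchAndRewrite(
    scf::ForOp oldOp, mlir::OpConversionPattern<scf::ForOp>::OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  // Custom rewrite using the already-converted operands from `adaptor` ...
  return mlir::success();
}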
Andi Drebes
2023-01-26 14:54:32 +01:00
parent 49b8bf484c
commit 73fd6c5fe7
19 changed files with 781 additions and 335 deletions

View File

@@ -0,0 +1,29 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CONVERSION_UTILS_DIALECTS_SCF_H_
#define CONCRETELANG_CONVERSION_UTILS_DIALECTS_SCF_H_
#include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
namespace mlir {
namespace concretelang {
//
// Specializations for ForOp
//
// The specialization that copies attributes is omitted
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<scf::ForOp, false>::matchAndRewrite(
scf::ForOp oldOp, mlir::OpConversionPattern<scf::ForOp>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const;
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -0,0 +1,70 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CONVERSION_UTILS_DIALECTS_TENSOR_H_
#define CONCRETELANG_CONVERSION_UTILS_DIALECTS_TENSOR_H_
#include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
namespace mlir {
namespace concretelang {
//
// Specializations for CollapseShapeOp
//
// A specialization that copies attributes is not necessary, as the
// base template works correctly
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<tensor::CollapseShapeOp, false>::
matchAndRewrite(
tensor::CollapseShapeOp oldOp,
mlir::OpConversionPattern<tensor::CollapseShapeOp>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const;
//
// Specializations for FromElementsOp
//
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<mlir::tensor::FromElementsOp, false>::
matchAndRewrite(
tensor::FromElementsOp oldOp,
mlir::OpConversionPattern<mlir::tensor::FromElementsOp>::OpAdaptor
adaptor,
mlir::ConversionPatternRewriter &rewriter) const;
//
// Specializations for ExpandShapeOp
//
// A specialization that copies attributes is not necessary, as the
// base template works correctly
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<tensor::ExpandShapeOp, false>::
matchAndRewrite(
tensor::ExpandShapeOp oldOp,
mlir::OpConversionPattern<tensor::ExpandShapeOp>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const;
//
// Specializations for GenerateOp
//
// The specialization that does NOT copy attributes is omitted
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<tensor::GenerateOp, true>::matchAndRewrite(
tensor::GenerateOp oldOp,
mlir::OpConversionPattern<tensor::GenerateOp>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const;
} // namespace concretelang
} // namespace mlir
#endif // CONCRETELANG_CONVERSION_UTILS_DIALECTS_TENSOR_H_

View File

@@ -96,16 +96,6 @@ struct GenericTypeAndOpConverterPattern : public mlir::OpRewritePattern<OldOp> {
private:
mlir::TypeConverter &converter;
};
template <typename Op>
void addDynamicallyLegalTypeOp(mlir::ConversionTarget &target,
mlir::TypeConverter &typeConverter) {
target.addDynamicallyLegalOp<Op>([&](Op op) {
return typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes());
});
}
} // namespace concretelang
} // namespace mlir

View File

@@ -0,0 +1,26 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CONVERSION_UTILS_LEGALITY_H_
#define CONCRETELANG_CONVERSION_UTILS_LEGALITY_H_
#include <mlir/Transforms/DialectConversion.h>
namespace mlir {
namespace concretelang {
template <typename Op>
void addDynamicallyLegalTypeOp(mlir::ConversionTarget &target,
mlir::TypeConverter &typeConverter) {
target.addDynamicallyLegalOp<Op>([&](Op op) {
return typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes());
});
}
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -0,0 +1,216 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CONVERSION_UTILS_REINSTANTIATINGOPTYPECONVERSION_H_
#define CONCRETELANG_CONVERSION_UTILS_REINSTANTIATINGOPTYPECONVERSION_H_
#include <mlir/Transforms/DialectConversion.h>
namespace mlir {
namespace concretelang {
// Set of types defining how attributes should be handled when
// invoking the build method of an operation upon reinstantiation
struct ReinstantiationAttributeHandling {
// Copy attributes
struct copy {};
// Completely dismiss attributes by not passing a set of attributes
// to the builder at all
struct dismiss {};
// Dismiss attributes by passing an empty set of attributes to the
// builder
struct pass_empty_vector {};
};
// Template defining how attributes should be dismissed when invoking
// the build method of an operation upon reinstantiation. In the
// default case, the argument for attributes is simply omitted.
template <typename T> struct ReinstantiationAttributeDismissalStrategy {
typedef ReinstantiationAttributeHandling::dismiss strategy;
};
// Template defining how attributes should be copied when invoking the
// build method of an operation upon reinstantiation. In the default
// case, the argument for attributes is forwarded to the build method.
template <typename T> struct ReinstantiationAttributeCopyStrategy {
typedef ReinstantiationAttributeHandling::copy strategy;
};
namespace {
// Class template that defines the attribute handling strategy for
// either dismissal of attributes (if `copyAttrsSwitch` is `false`) or copying
// attributes (if `copyAttrsSwitch` is `true`).
template <typename T, bool copyAttrsSwitch> struct AttributeHandlingSwitch {};
template <typename T> struct AttributeHandlingSwitch<T, true> {
typedef typename ReinstantiationAttributeCopyStrategy<T>::strategy strategy;
};
template <typename T> struct AttributeHandlingSwitch<T, false> {
typedef
typename ReinstantiationAttributeDismissalStrategy<T>::strategy strategy;
};
// Simple functor-like template invoking a rewriter with a variable
// set of arguments and an op's attributes as the last argument.
template <typename NewOpTy, typename... Args>
struct ReplaceOpWithNewOpCopyAttrs {
static NewOpTy replace(mlir::ConversionPatternRewriter &rewriter,
mlir::Operation *op, mlir::TypeRange resultTypes,
mlir::ValueRange operands) {
return rewriter.replaceOpWithNewOp<NewOpTy>(op, resultTypes, operands,
op->getAttrs());
}
};
// Simple functor-like template invoking a rewriter with a variable
// set of arguments, dismissing the op's attributes by not passing
// them to the builder.
template <typename NewOpTy, typename... Args>
struct ReplaceOpWithNewOpDismissAttrs {
static NewOpTy replace(mlir::ConversionPatternRewriter &rewriter,
mlir::Operation *op, mlir::TypeRange resultTypes,
mlir::ValueRange operands) {
return rewriter.replaceOpWithNewOp<NewOpTy>(op, resultTypes, operands);
}
};
// Simple functor-like template invoking a rewriter with a variable
// set of arguments, dismissing the attributes by passing an empty
// set of attributes to the builder.
template <typename NewOpTy, typename... Args>
struct ReplaceOpWithNewOpEmptyAttrs {
static NewOpTy replace(mlir::ConversionPatternRewriter &rewriter,
mlir::Operation *op, mlir::TypeRange resultTypes,
mlir::ValueRange operands) {
llvm::SmallVector<mlir::NamedAttribute> attrs{};
return rewriter.replaceOpWithNewOp<NewOpTy>(op, resultTypes, operands,
attrs);
}
};
// Functor-like template that forwards to `ReplaceOpWithNewOpCopyAttrs`,
// `ReplaceOpWithNewOpDismissAttrs` or `ReplaceOpWithNewOpEmptyAttrs`
// depending on the strategy type `copyAttrsSwitch`.
template <typename copyAttrsSwitch, typename OpTy, typename... Args>
struct ReplaceOpWithNewOpAttrSwitch {};
// Specialization of `ReplaceOpWithNewOpAttrSwitch` that does copy
// attributes.
template <typename OpTy, typename... Args>
struct ReplaceOpWithNewOpAttrSwitch<ReinstantiationAttributeHandling::copy,
OpTy, Args...> {
typedef ReplaceOpWithNewOpCopyAttrs<OpTy, Args...> instantiator;
};
// Specialization of `ReplaceOpWithNewOpAttrSwitch` that does NOT copy
// attributes by not passing attributes to the builder at all.
template <typename OpTy, typename... Args>
struct ReplaceOpWithNewOpAttrSwitch<ReinstantiationAttributeHandling::dismiss,
OpTy, Args...> {
typedef ReplaceOpWithNewOpDismissAttrs<OpTy, Args...> instantiator;
};
// Specialization of `ReplaceOpWithNewOpAttrSwitch` that does NOT copy
// attributes by passing an empty set of attributes to the builder.
template <typename OpTy, typename... Args>
struct ReplaceOpWithNewOpAttrSwitch<
ReinstantiationAttributeHandling::pass_empty_vector, OpTy, Args...> {
typedef ReplaceOpWithNewOpEmptyAttrs<OpTy, Args...> instantiator;
};
} // namespace
template <typename OldOp, typename NewOp, bool copyAttrs = false>
struct GenericOneToOneOpConversionPatternBase
: public mlir::OpConversionPattern<OldOp> {
GenericOneToOneOpConversionPatternBase(mlir::MLIRContext *context,
mlir::TypeConverter &converter,
mlir::PatternBenefit benefit = 100)
: mlir::OpConversionPattern<OldOp>(converter, context, benefit) {}
mlir::SmallVector<mlir::Type> convertResultTypes(OldOp oldOp) const {
mlir::TypeConverter *converter = this->getTypeConverter();
// Convert result types
mlir::SmallVector<mlir::Type> resultTypes(oldOp->getNumResults());
for (unsigned i = 0; i < oldOp->getNumResults(); i++) {
auto result = oldOp->getResult(i);
resultTypes[i] = converter->convertType(result.getType());
}
return resultTypes;
}
mlir::Type convertResultType(OldOp oldOp) const {
mlir::TypeConverter *converter = this->getTypeConverter();
return converter->convertType(oldOp->getResult(0).getType());
}
};
// Conversion pattern that replaces an instance of an operation of the type
// `OldOp` with an instance of the type `NewOp`, taking into account operands,
// return types and possibly copying attributes (iff `copyAttrs` is `true`).
template <typename OldOp, typename NewOp, bool copyAttrs = false>
struct GenericOneToOneOpConversionPattern
: public GenericOneToOneOpConversionPatternBase<OldOp, NewOp, copyAttrs> {
GenericOneToOneOpConversionPattern(mlir::MLIRContext *context,
mlir::TypeConverter &converter,
mlir::PatternBenefit benefit = 100)
: GenericOneToOneOpConversionPatternBase<OldOp, NewOp, copyAttrs>(
context, converter, benefit) {}
virtual mlir::LogicalResult
matchAndRewrite(OldOp oldOp,
typename mlir::OpConversionPattern<OldOp>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::SmallVector<mlir::Type> resultTypes = this->convertResultTypes(oldOp);
ReplaceOpWithNewOpAttrSwitch<
typename AttributeHandlingSwitch<NewOp, copyAttrs>::strategy,
NewOp>::instantiator::replace(rewriter, oldOp,
mlir::TypeRange{resultTypes},
mlir::ValueRange{adaptor.getOperands()});
return mlir::success();
}
};
// Conversion pattern that retrieves the converted operands of an
// operation of the type `Op`, converts the types of the results of
// the operation and re-instantiates the operation with the converted
// operands and result types.
template <typename Op, bool copyAttrs = false>
struct TypeConvertingReinstantiationPattern
: public GenericOneToOneOpConversionPatternBase<Op, Op, copyAttrs> {
TypeConvertingReinstantiationPattern(mlir::MLIRContext *context,
mlir::TypeConverter &converter,
mlir::PatternBenefit benefit = 100)
: GenericOneToOneOpConversionPatternBase<Op, Op, copyAttrs>(
context, converter, benefit) {}
// Simple forwarder that makes the method specializable out of class
// directly for this class rather than for its base
virtual mlir::LogicalResult
matchAndRewrite(Op op,
typename mlir::OpConversionPattern<Op>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::SmallVector<mlir::Type> resultTypes = this->convertResultTypes(op);
ReplaceOpWithNewOpAttrSwitch<
typename AttributeHandlingSwitch<Op, copyAttrs>::strategy,
Op>::instantiator::replace(rewriter, op, mlir::TypeRange{resultTypes},
mlir::ValueRange{adaptor.getOperands()});
return mlir::success();
}
};
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -10,7 +10,8 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h"
#include "concretelang/Conversion/Utils/Dialects/Tensor.h"
#include "concretelang/Conversion/Utils/Legality.h"
namespace mlir {
namespace concretelang {
@@ -20,37 +21,43 @@ populateWithTensorTypeConverterPatterns(mlir::RewritePatternSet &patterns,
mlir::ConversionTarget &target,
mlir::TypeConverter &typeConverter) {
// ExtractOp
patterns.add<GenericTypeConverterPattern<mlir::tensor::ExtractOp>>(
patterns.add<TypeConvertingReinstantiationPattern<mlir::tensor::ExtractOp>>(
patterns.getContext(), typeConverter);
addDynamicallyLegalTypeOp<mlir::tensor::ExtractOp>(target, typeConverter);
// ExtractSliceOp
patterns.add<GenericTypeConverterPattern<mlir::tensor::ExtractSliceOp>>(
patterns.add<
TypeConvertingReinstantiationPattern<mlir::tensor::ExtractSliceOp, true>>(
patterns.getContext(), typeConverter);
addDynamicallyLegalTypeOp<mlir::tensor::ExtractSliceOp>(target,
typeConverter);
// InsertOp
patterns.add<GenericTypeConverterPattern<mlir::tensor::InsertOp>>(
patterns.add<TypeConvertingReinstantiationPattern<mlir::tensor::InsertOp>>(
patterns.getContext(), typeConverter);
addDynamicallyLegalTypeOp<mlir::tensor::InsertOp>(target, typeConverter);
// InsertSliceOp
patterns.add<GenericTypeConverterPattern<mlir::tensor::InsertSliceOp>>(
patterns.add<
TypeConvertingReinstantiationPattern<mlir::tensor::InsertSliceOp, true>>(
patterns.getContext(), typeConverter);
addDynamicallyLegalTypeOp<mlir::tensor::InsertSliceOp>(target, typeConverter);
// FromElementsOp
patterns.add<GenericTypeConverterPattern<mlir::tensor::FromElementsOp>>(
patterns.getContext(), typeConverter);
patterns
.add<TypeConvertingReinstantiationPattern<mlir::tensor::FromElementsOp>>(
patterns.getContext(), typeConverter);
addDynamicallyLegalTypeOp<mlir::tensor::FromElementsOp>(target,
typeConverter);
// TensorCollapseShapeOp
patterns.add<GenericTypeConverterPattern<mlir::tensor::CollapseShapeOp>>(
patterns.getContext(), typeConverter);
patterns
.add<TypeConvertingReinstantiationPattern<mlir::tensor::CollapseShapeOp>>(
patterns.getContext(), typeConverter);
addDynamicallyLegalTypeOp<mlir::tensor::CollapseShapeOp>(target,
typeConverter);
// TensorExpandShapeOp
patterns.add<GenericTypeConverterPattern<mlir::tensor::ExpandShapeOp>>(
patterns.getContext(), typeConverter);
patterns
.add<TypeConvertingReinstantiationPattern<mlir::tensor::ExpandShapeOp>>(
patterns.getContext(), typeConverter);
addDynamicallyLegalTypeOp<mlir::tensor::ExpandShapeOp>(target, typeConverter);
}
} // namespace concretelang

View File

@@ -9,5 +9,13 @@ add_subdirectory(SDFGToStreamEmulator)
add_subdirectory(MLIRLowerableDialectsToLLVM)
add_subdirectory(LinalgExtras)
add_subdirectory(ExtractSDFGOps)
add_subdirectory(Utils)
add_mlir_library(ConcretelangConversion Tools.cpp LINK_LIBS PUBLIC MLIRIR)
add_mlir_library(
ConcretelangConversion
Tools.cpp
Utils/Dialects/SCF.cpp
Utils/Dialects/Tensor.cpp
LINK_LIBS
PUBLIC
MLIRIR)

View File

@@ -27,6 +27,7 @@
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"

View File

@@ -10,13 +10,15 @@
#include <mlir/Dialect/Linalg/IR/Linalg.h>
#include <mlir/IR/Operation.h>
#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h"
#include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "concretelang/Conversion/FHEToTFHECrt/Pass.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Utils/Dialects/SCF.h"
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
@@ -112,7 +114,8 @@ namespace lowering {
/// A pattern rewriter superclass used by most op rewriters during the
/// conversion.
template <typename T> struct CrtOpPattern : public mlir::OpRewritePattern<T> {
template <typename T>
struct CrtOpPattern : public mlir::OpConversionPattern<T> {
/// The lowering parameters are bound to the op rewriter.
concretelang::CrtLoweringParameters loweringParameters;
@@ -120,8 +123,8 @@ template <typename T> struct CrtOpPattern : public mlir::OpRewritePattern<T> {
CrtOpPattern(mlir::MLIRContext *context,
concretelang::CrtLoweringParameters params,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<T>(context, benefit),
loweringParameters(params) {}
: mlir::OpConversionPattern<T>(typeConverter, context, benefit),
loweringParameters(params), typeConverter(params) {}
/// Writes an `scf::for` that loops over the crt dimension of one tensor and
/// executes the input lambda to write the loop body. Returns the first result
@@ -154,11 +157,6 @@ template <typename T> struct CrtOpPattern : public mlir::OpRewritePattern<T> {
location, zeroConstantOp, crtSizeConstantOp, oneConstantOp, tensor,
body);
// Convert the types of the new operation
typing::TypeConverter converter(loweringParameters);
concretelang::convertOperandAndResultTypes(rewriter, newOp,
converter.getConversionLambda());
return newOp.getResult(0);
}
@@ -176,6 +174,9 @@ template <typename T> struct CrtOpPattern : public mlir::OpRewritePattern<T> {
castedPlaintext, rewriter.getI64ArrayAttr(loweringParameters.mods),
rewriter.getI64IntegerAttr(loweringParameters.modsProd));
}
protected:
typing::TypeConverter typeConverter;
};
/// Rewriter for the `FHE::add_eint_int` operation.
@@ -187,17 +188,12 @@ struct AddEintIntOpPattern : public CrtOpPattern<FHE::AddEintIntOp> {
: CrtOpPattern<FHE::AddEintIntOp>(context, params, benefit) {}
::mlir::LogicalResult
matchAndRewrite(FHE::AddEintIntOp op,
mlir::PatternRewriter &rewriter) const override {
matchAndRewrite(FHE::AddEintIntOp op, FHE::AddEintIntOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::TypeConverter *converter = this->getTypeConverter();
mlir::Location location = op.getLoc();
mlir::Value eintOperand = op.a();
mlir::Value intOperand = op.b();
// Convert operand type to glwe tensor.
typing::TypeConverter converter(loweringParameters);
intOperand.setType(converter.convertType(intOperand.getType()));
eintOperand.setType(converter.convertType(eintOperand.getType()));
mlir::Value eintOperand = adaptor.a();
mlir::Value intOperand = adaptor.b();
// Write plaintext encoding
mlir::Value encodedPlaintextTensor =
@@ -205,7 +201,7 @@ struct AddEintIntOpPattern : public CrtOpPattern<FHE::AddEintIntOp> {
// Write add loop.
mlir::Type ciphertextScalarType =
converter.convertType(eintOperand.getType())
converter->convertType(eintOperand.getType())
.cast<mlir::RankedTensorType>()
.getElementType();
mlir::Value output = writeUnaryTensorLoop(
@@ -239,17 +235,12 @@ struct SubIntEintOpPattern : public CrtOpPattern<FHE::SubIntEintOp> {
: CrtOpPattern<FHE::SubIntEintOp>(context, params, benefit) {}
::mlir::LogicalResult
matchAndRewrite(FHE::SubIntEintOp op,
mlir::PatternRewriter &rewriter) const override {
matchAndRewrite(FHE::SubIntEintOp op, FHE::SubIntEintOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::TypeConverter *converter = this->getTypeConverter();
mlir::Location location = op.getLoc();
mlir::Value intOperand = op.a();
mlir::Value eintOperand = op.b();
// Convert operand type to glwe tensor.
typing::TypeConverter converter(loweringParameters);
intOperand.setType(converter.convertType(intOperand.getType()));
eintOperand.setType(converter.convertType(eintOperand.getType()));
mlir::Value intOperand = adaptor.a();
mlir::Value eintOperand = adaptor.b();
// Write plaintext encoding
mlir::Value encodedPlaintextTensor =
@@ -257,7 +248,7 @@ struct SubIntEintOpPattern : public CrtOpPattern<FHE::SubIntEintOp> {
// Write add loop.
mlir::Type ciphertextScalarType =
converter.convertType(eintOperand.getType())
converter->convertType(eintOperand.getType())
.cast<mlir::RankedTensorType>()
.getElementType();
mlir::Value output = writeUnaryTensorLoop(
@@ -291,17 +282,12 @@ struct SubEintIntOpPattern : public CrtOpPattern<FHE::SubEintIntOp> {
: CrtOpPattern<FHE::SubEintIntOp>(context, params, benefit) {}
::mlir::LogicalResult
matchAndRewrite(FHE::SubEintIntOp op,
mlir::PatternRewriter &rewriter) const override {
matchAndRewrite(FHE::SubEintIntOp op, FHE::SubEintIntOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::TypeConverter *converter = this->getTypeConverter();
mlir::Location location = op.getLoc();
mlir::Value eintOperand = op.a();
mlir::Value intOperand = op.b();
// Convert operand type to glwe tensor.
typing::TypeConverter converter(loweringParameters);
intOperand.setType(converter.convertType(intOperand.getType()));
eintOperand.setType(converter.convertType(eintOperand.getType()));
mlir::Value eintOperand = adaptor.a();
mlir::Value intOperand = adaptor.b();
// Write plaintext negation
mlir::Type intType = intOperand.getType();
@@ -319,7 +305,7 @@ struct SubEintIntOpPattern : public CrtOpPattern<FHE::SubEintIntOp> {
// Write add loop.
mlir::Type ciphertextScalarType =
converter.convertType(eintOperand.getType())
converter->convertType(eintOperand.getType())
.cast<mlir::RankedTensorType>()
.getElementType();
mlir::Value output = writeUnaryTensorLoop(
@@ -353,21 +339,16 @@ struct AddEintOpPattern : CrtOpPattern<FHE::AddEintOp> {
: CrtOpPattern<FHE::AddEintOp>(context, params, benefit) {}
::mlir::LogicalResult
matchAndRewrite(FHE::AddEintOp op,
mlir::PatternRewriter &rewriter) const override {
matchAndRewrite(FHE::AddEintOp op, FHE::AddEintOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::TypeConverter *converter = this->getTypeConverter();
mlir::Location location = op.getLoc();
mlir::Value lhsOperand = op.a();
mlir::Value rhsOperand = op.b();
// Convert operand type to glwe tensor.
typing::TypeConverter converter(loweringParameters);
lhsOperand.setType(converter.convertType(lhsOperand.getType()));
rhsOperand.setType(converter.convertType(rhsOperand.getType()));
mlir::Value lhsOperand = adaptor.a();
mlir::Value rhsOperand = adaptor.b();
// Write add loop.
mlir::Type ciphertextScalarType =
converter.convertType(lhsOperand.getType())
converter->convertType(lhsOperand.getType())
.cast<mlir::RankedTensorType>()
.getElementType();
mlir::Value output = writeUnaryTensorLoop(
@@ -401,21 +382,16 @@ struct SubEintOpPattern : CrtOpPattern<FHE::SubEintOp> {
: CrtOpPattern<FHE::SubEintOp>(context, params, benefit) {}
::mlir::LogicalResult
matchAndRewrite(FHE::SubEintOp op,
mlir::PatternRewriter &rewriter) const override {
matchAndRewrite(FHE::SubEintOp op, FHE::SubEintOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::TypeConverter *converter = this->getTypeConverter();
mlir::Location location = op.getLoc();
mlir::Value lhsOperand = op.a();
mlir::Value rhsOperand = op.b();
// Convert operand type to glwe tensor.
typing::TypeConverter converter(loweringParameters);
lhsOperand.setType(converter.convertType(lhsOperand.getType()));
rhsOperand.setType(converter.convertType(rhsOperand.getType()));
mlir::Value lhsOperand = adaptor.a();
mlir::Value rhsOperand = adaptor.b();
// Write sub loop.
mlir::Type ciphertextScalarType =
converter.convertType(lhsOperand.getType())
converter->convertType(lhsOperand.getType())
.cast<mlir::RankedTensorType>()
.getElementType();
mlir::Value output = writeUnaryTensorLoop(
@@ -451,18 +427,14 @@ struct NegEintOpPattern : CrtOpPattern<FHE::NegEintOp> {
: CrtOpPattern<FHE::NegEintOp>(context, params, benefit) {}
::mlir::LogicalResult
matchAndRewrite(FHE::NegEintOp op,
mlir::PatternRewriter &rewriter) const override {
matchAndRewrite(FHE::NegEintOp op, FHE::NegEintOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::TypeConverter *converter = this->getTypeConverter();
mlir::Location location = op.getLoc();
mlir::Value operand = op.a();
// Convert operand type to glwe tensor.
typing::TypeConverter converter{loweringParameters};
operand.setType(converter.convertType(operand.getType()));
mlir::Value operand = adaptor.a();
// Write the loop nest.
mlir::Type ciphertextScalarType = converter.convertType(operand.getType())
mlir::Type ciphertextScalarType = converter->convertType(operand.getType())
.cast<mlir::RankedTensorType>()
.getElementType();
mlir::Value loopRes = writeUnaryTensorLoop(
@@ -494,16 +466,12 @@ struct MulEintIntOpPattern : CrtOpPattern<FHE::MulEintIntOp> {
: CrtOpPattern<FHE::MulEintIntOp>(context, params, benefit) {}
::mlir::LogicalResult
matchAndRewrite(FHE::MulEintIntOp op,
mlir::PatternRewriter &rewriter) const override {
matchAndRewrite(FHE::MulEintIntOp op, FHE::MulEintIntOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::TypeConverter *converter = this->getTypeConverter();
mlir::Location location = op.getLoc();
mlir::Value eintOperand = op.a();
mlir::Value intOperand = op.b();
// Convert operand type to glwe tensor.
typing::TypeConverter converter{loweringParameters};
eintOperand.setType(converter.convertType(eintOperand.getType()));
mlir::Value eintOperand = adaptor.a();
mlir::Value intOperand = adaptor.b();
// Write cleartext "encoding"
mlir::Value encodedCleartext = rewriter.create<mlir::arith::ExtSIOp>(
@@ -511,7 +479,7 @@ struct MulEintIntOpPattern : CrtOpPattern<FHE::MulEintIntOp> {
// Write the loop nest.
mlir::Type ciphertextScalarType =
converter.convertType(eintOperand.getType())
converter->convertType(eintOperand.getType())
.cast<mlir::RankedTensorType>()
.getElementType();
mlir::Value loopRes = writeUnaryTensorLoop(
@@ -545,9 +513,9 @@ struct ApplyLookupTableEintOpPattern
::mlir::LogicalResult
matchAndRewrite(FHE::ApplyLookupTableEintOp op,
mlir::PatternRewriter &rewriter) const override {
typing::TypeConverter converter(loweringParameters);
FHE::ApplyLookupTableEintOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::TypeConverter *converter = this->getTypeConverter();
mlir::Value newLut =
rewriter
@@ -556,7 +524,7 @@ struct ApplyLookupTableEintOpPattern
mlir::RankedTensorType::get(
mlir::ArrayRef<int64_t>(loweringParameters.lutSize),
rewriter.getI64Type()),
op.lut(),
adaptor.lut(),
rewriter.getI64ArrayAttr(
mlir::ArrayRef<int64_t>(loweringParameters.mods)),
rewriter.getI64ArrayAttr(
@@ -567,11 +535,9 @@ struct ApplyLookupTableEintOpPattern
// Replace the lut with an encoded / expanded one.
auto wopPBS = rewriter.create<TFHE::WopPBSGLWEOp>(
op.getLoc(), op.getType(), op.a(), newLut, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, rewriter.getI64ArrayAttr({}));
op.getLoc(), converter->convertType(op.getType()), adaptor.a(), newLut,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, rewriter.getI64ArrayAttr({}));
concretelang::convertOperandAndResultTypes(rewriter, wopPBS,
converter.getConversionLambda());
rewriter.replaceOp(op, {wopPBS.getResult()});
return ::mlir::success();
};
@@ -587,7 +553,9 @@ struct TensorExtractOpPattern : public CrtOpPattern<mlir::tensor::ExtractOp> {
::mlir::LogicalResult
matchAndRewrite(mlir::tensor::ExtractOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::tensor::ExtractOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::TypeConverter *converter = this->getTypeConverter();
if (!op.getTensor()
.getType()
@@ -601,7 +569,7 @@ struct TensorExtractOpPattern : public CrtOpPattern<mlir::tensor::ExtractOp> {
.isa<TFHE::GLWECipherTextType>()) {
return mlir::success();
}
typing::TypeConverter converter{loweringParameters};
mlir::SmallVector<mlir::OpFoldResult> offsets;
mlir::SmallVector<mlir::OpFoldResult> sizes;
mlir::SmallVector<mlir::OpFoldResult> strides;
@@ -617,11 +585,10 @@ struct TensorExtractOpPattern : public CrtOpPattern<mlir::tensor::ExtractOp> {
strides.push_back(rewriter.getI64IntegerAttr(1));
auto newOp = rewriter.create<mlir::tensor::ExtractSliceOp>(
op.getLoc(),
converter.convertType(op.getResult().getType())
converter->convertType(op.getResult().getType())
.cast<mlir::RankedTensorType>(),
op.getTensor(), offsets, sizes, strides);
concretelang::convertOperandAndResultTypes(rewriter, newOp,
converter.getConversionLambda());
adaptor.getTensor(), offsets, sizes, strides);
rewriter.replaceOp(op, {newOp.getResult()});
return mlir::success();
}
@@ -637,8 +604,8 @@ struct TensorInsertOpPattern : public CrtOpPattern<mlir::tensor::InsertOp> {
::mlir::LogicalResult
matchAndRewrite(mlir::tensor::InsertOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::tensor::InsertOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
if (!op.getDest()
.getType()
.cast<mlir::TensorType>()
@@ -651,7 +618,7 @@ struct TensorInsertOpPattern : public CrtOpPattern<mlir::tensor::InsertOp> {
.isa<TFHE::GLWECipherTextType>()) {
return mlir::success();
}
typing::TypeConverter converter{loweringParameters};
mlir::SmallVector<mlir::OpFoldResult> offsets;
mlir::SmallVector<mlir::OpFoldResult> sizes;
mlir::SmallVector<mlir::OpFoldResult> strides;
@@ -666,9 +633,9 @@ struct TensorInsertOpPattern : public CrtOpPattern<mlir::tensor::InsertOp> {
sizes.push_back(rewriter.getI64IntegerAttr(loweringParameters.nMods));
strides.push_back(rewriter.getI64IntegerAttr(1));
auto newOp = rewriter.create<mlir::tensor::InsertSliceOp>(
op.getLoc(), op.getScalar(), op.getDest(), offsets, sizes, strides);
concretelang::convertOperandAndResultTypes(rewriter, newOp,
converter.getConversionLambda());
op.getLoc(), adaptor.getScalar(), op.getDest(), offsets, sizes,
strides);
rewriter.replaceOp(op, {newOp});
return mlir::success();
}
@@ -685,7 +652,10 @@ struct TensorFromElementsOpPattern
::mlir::LogicalResult
matchAndRewrite(mlir::tensor::FromElementsOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::tensor::FromElementsOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::TypeConverter *converter = this->getTypeConverter();
if (!op.getResult()
.getType()
.cast<mlir::RankedTensorType>()
@@ -699,13 +669,11 @@ struct TensorFromElementsOpPattern
return mlir::success();
}
typing::TypeConverter converter{loweringParameters};
// Create dest tensor allocation op
mlir::Value outputTensor =
rewriter.create<mlir::bufferization::AllocTensorOp>(
op.getLoc(),
converter.convertType(op.getResult().getType())
converter->convertType(op.getResult().getType())
.cast<mlir::RankedTensorType>(),
mlir::ValueRange{});
@@ -723,15 +691,13 @@ struct TensorFromElementsOpPattern
strides.push_back(rewriter.getI64IntegerAttr(1));
offsets.push_back(rewriter.getI64IntegerAttr(0));
}
for (size_t insertionIndex = 0; insertionIndex < op.getElements().size();
++insertionIndex) {
for (size_t insertionIndex = 0;
insertionIndex < adaptor.getElements().size(); ++insertionIndex) {
offsets[0] = rewriter.getI64IntegerAttr(insertionIndex);
mlir::tensor::InsertSliceOp insertOp =
rewriter.create<mlir::tensor::InsertSliceOp>(
op.getLoc(), op.getElements()[insertionIndex], outputTensor,
op.getLoc(), adaptor.getElements()[insertionIndex], outputTensor,
offsets, sizes, strides);
concretelang::convertOperandAndResultTypes(
rewriter, insertOp, converter.getConversionLambda());
outputTensor = insertOp.getResult();
}
rewriter.replaceOp(op, {outputTensor});
@@ -780,7 +746,11 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
op, converter);
});
target.addLegalOp<mlir::func::CallOp>();
target.addLegalOp<mlir::bufferization::AllocTensorOp>();
concretelang::addDynamicallyLegalTypeOp<mlir::bufferization::AllocTensorOp>(
target, converter);
concretelang::addDynamicallyLegalTypeOp<mlir::func::ReturnOp>(target,
converter);
concretelang::addDynamicallyLegalTypeOp<mlir::tensor::ExtractSliceOp>(
target, converter);
concretelang::addDynamicallyLegalTypeOp<mlir::tensor::InsertSliceOp>(
@@ -812,15 +782,20 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
//---------------------------------------------------------- Adding patterns
mlir::RewritePatternSet patterns(&getContext());
// Patterns for `bufferization` dialect operations.
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
mlir::bufferization::AllocTensorOp, true>>(patterns.getContext(),
converter);
// Patterns for the `FHE` dialect operations
patterns.add<
// |_ `FHE::zero_eint`
concretelang::GenericTypeAndOpConverterPattern<FHE::ZeroEintOp,
TFHE::ZeroGLWEOp>,
concretelang::GenericOneToOneOpConversionPattern<FHE::ZeroEintOp,
TFHE::ZeroGLWEOp>,
// |_ `FHE::zero_tensor`
concretelang::GenericTypeAndOpConverterPattern<FHE::ZeroTensorOp,
TFHE::ZeroTensorGLWEOp>>(
&getContext(), converter);
concretelang::GenericOneToOneOpConversionPattern<
FHE::ZeroTensorOp, TFHE::ZeroTensorGLWEOp>>(&getContext(),
converter);
// |_ `FHE::add_eint_int`
patterns.add<lowering::AddEintIntOpPattern,
// |_ `FHE::add_eint`
@@ -841,37 +816,35 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
// Patterns for the relics of the `FHELinalg` dialect operations.
// |_ `linalg::generic` turned to nested `scf::for`
patterns.add<concretelang::GenericTypeConverterPattern<mlir::scf::ForOp>>(
patterns.getContext(), converter);
patterns.add<concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>>(
patterns.add<
concretelang::TypeConvertingReinstantiationPattern<mlir::scf::ForOp>>(
patterns.getContext(), converter);
patterns.add<
RegionOpTypeConverterPattern<mlir::scf::ForOp, typing::TypeConverter>>(
concretelang::TypeConvertingReinstantiationPattern<mlir::scf::YieldOp>>(
patterns.getContext(), converter);
patterns.add<
concretelang::TypeConvertingReinstantiationPattern<mlir::scf::ForOp>>(
&getContext(), converter);
patterns.add<lowering::TensorExtractOpPattern>(&getContext(),
loweringParameters);
patterns.add<lowering::TensorInsertOpPattern>(&getContext(),
loweringParameters);
patterns.add<concretelang::GenericTypeConverterPattern<
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
mlir::tensor::ExtractSliceOp>>(patterns.getContext(), converter);
patterns.add<
concretelang::GenericTypeConverterPattern<mlir::tensor::InsertSliceOp>>(
patterns.getContext(), converter);
patterns.add<concretelang::GenericTypeConverterPattern<
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
mlir::tensor::InsertSliceOp>>(patterns.getContext(), converter);
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
mlir::tensor::CollapseShapeOp>>(patterns.getContext(), converter);
patterns.add<
concretelang::GenericTypeConverterPattern<mlir::tensor::ExpandShapeOp>>(
patterns.getContext(), converter);
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
typing::TypeConverter>>(
&getContext(), converter);
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
mlir::tensor::ExpandShapeOp>>(patterns.getContext(), converter);
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
mlir::tensor::GenerateOp, true>>(&getContext(), converter);
// Patterns for `func` dialect operations.
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
patterns, converter);
patterns
.add<concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>>(
patterns.getContext(), converter);
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
mlir::func::ReturnOp>>(patterns.getContext(), converter);
patterns.add<FunctionConstantOpConversion<typing::TypeConverter>>(
&getContext(), converter);
@@ -880,26 +853,27 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
loweringParameters);
// Patterns for the `RT` dialect operations.
patterns
.add<concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
concretelang::GenericTypeConverterPattern<
concretelang::RT::MakeReadyFutureOp>,
concretelang::GenericTypeConverterPattern<
concretelang::RT::AwaitFutureOp>,
concretelang::GenericTypeConverterPattern<
concretelang::RT::CreateAsyncTaskOp>,
concretelang::GenericTypeConverterPattern<
concretelang::RT::BuildReturnPtrPlaceholderOp>,
concretelang::GenericTypeConverterPattern<
concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
concretelang::GenericTypeConverterPattern<
concretelang::RT::DerefReturnPtrPlaceholderOp>,
concretelang::GenericTypeConverterPattern<
concretelang::RT::WorkFunctionReturnOp>,
concretelang::GenericTypeConverterPattern<
concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
converter);
patterns.add<
// concretelang::TypeConvertingReinstantiationPattern<
// mlir::func::ReturnOp>,
concretelang::TypeConvertingReinstantiationPattern<mlir::scf::YieldOp>,
concretelang::TypeConvertingReinstantiationPattern<
concretelang::RT::MakeReadyFutureOp>,
concretelang::TypeConvertingReinstantiationPattern<
concretelang::RT::AwaitFutureOp>,
concretelang::TypeConvertingReinstantiationPattern<
concretelang::RT::CreateAsyncTaskOp, true>,
concretelang::TypeConvertingReinstantiationPattern<
concretelang::RT::BuildReturnPtrPlaceholderOp>,
concretelang::TypeConvertingReinstantiationPattern<
concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
concretelang::TypeConvertingReinstantiationPattern<
concretelang::RT::DerefReturnPtrPlaceholderOp>,
concretelang::TypeConvertingReinstantiationPattern<
concretelang::RT::WorkFunctionReturnOp>,
concretelang::TypeConvertingReinstantiationPattern<
concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
converter);
//--------------------------------------------------------- Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns))

View File

@@ -16,8 +16,8 @@
#include "concretelang/Conversion/FHEToTFHEScalar/Pass.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Tools.h"
#include "concretelang/Conversion/Utils/Dialects/SCF.h"
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
@@ -110,16 +110,17 @@ namespace lowering {
/// A pattern rewriter superclass used by most op rewriters during the
/// conversion.
template <typename T>
struct ScalarOpPattern : public mlir::OpRewritePattern<T> {
struct ScalarOpPattern : public mlir::OpConversionPattern<T> {
ScalarOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<T>(context, benefit) {}
ScalarOpPattern(mlir::TypeConverter &converter, mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpConversionPattern<T>(converter, context, benefit) {}
/// Writes the encoding of a plaintext of arbitrary precision using a shift.
mlir::Value
writePlaintextShiftEncoding(mlir::Location location, mlir::Value rawPlaintext,
int64_t encryptedWidth,
mlir::PatternRewriter &rewriter) const {
mlir::ConversionPatternRewriter &rewriter) const {
int64_t intShift = 64 - 1 - encryptedWidth;
mlir::Value castedInt = rewriter.create<mlir::arith::ExtUIOp>(
location, rewriter.getIntegerType(64), rawPlaintext);
@@ -131,53 +132,25 @@ struct ScalarOpPattern : public mlir::OpRewritePattern<T> {
}
};
/// Rewriter for the `FHE::zero` operation.
struct ZeroEintOpPattern : public mlir::OpRewritePattern<FHE::ZeroEintOp> {
ZeroEintOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<FHE::ZeroEintOp>(context, benefit) {}
mlir::LogicalResult
matchAndRewrite(FHE::ZeroEintOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::Location location = op.getLoc();
typing::TypeConverter converter;
TFHE::ZeroGLWEOp newOp =
rewriter.create<TFHE::ZeroGLWEOp>(location, op.getType());
concretelang::convertOperandAndResultTypes(rewriter, newOp,
converter.getConversionLambda());
rewriter.replaceOp(op, {newOp.getResult()});
return mlir::success();
}
};
/// Rewriter for the `FHE::add_eint_int` operation.
struct AddEintIntOpPattern : public ScalarOpPattern<FHE::AddEintIntOp> {
AddEintIntOpPattern(mlir::MLIRContext *context,
AddEintIntOpPattern(mlir::TypeConverter &converter,
mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ScalarOpPattern<FHE::AddEintIntOp>(context, benefit) {}
: ScalarOpPattern<FHE::AddEintIntOp>(converter, context, benefit) {}
mlir::LogicalResult
matchAndRewrite(FHE::AddEintIntOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::Location location = op.getLoc();
mlir::Value eintOperand = op.a();
mlir::Value intOperand = op.b();
matchAndRewrite(FHE::AddEintIntOp op, FHE::AddEintIntOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
// Write the plaintext encoding
mlir::Value encodedInt = writePlaintextShiftEncoding(
op.getLoc(), intOperand,
eintOperand.getType().cast<FHE::EncryptedIntegerType>().getWidth(),
rewriter);
op.getLoc(), adaptor.b(),
op.getType().cast<FHE::EncryptedIntegerType>().getWidth(), rewriter);
// Write the new op
auto newOp = rewriter.create<TFHE::AddGLWEIntOp>(location, op.getType(),
eintOperand, encodedInt);
typing::TypeConverter converter;
concretelang::convertOperandAndResultTypes(rewriter, newOp,
converter.getConversionLambda());
rewriter.replaceOp(op, {newOp.getResult()});
rewriter.replaceOpWithNewOp<TFHE::AddGLWEIntOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.a(),
encodedInt);
return mlir::success();
}
@@ -185,13 +158,14 @@ struct AddEintIntOpPattern : public ScalarOpPattern<FHE::AddEintIntOp> {
/// Rewriter for the `FHE::sub_eint_int` operation.
struct SubEintIntOpPattern : public ScalarOpPattern<FHE::SubEintIntOp> {
SubEintIntOpPattern(mlir::MLIRContext *context,
SubEintIntOpPattern(mlir::TypeConverter &converter,
mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ScalarOpPattern<FHE::SubEintIntOp>(context, benefit) {}
: ScalarOpPattern<FHE::SubEintIntOp>(converter, context, benefit) {}
mlir::LogicalResult
matchAndRewrite(FHE::SubEintIntOp op,
mlir::PatternRewriter &rewriter) const override {
matchAndRewrite(FHE::SubEintIntOp op, FHE::SubEintIntOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Location location = op.getLoc();
mlir::Value eintOperand = op.a();
mlir::Value intOperand = op.b();
@@ -213,15 +187,9 @@ struct SubEintIntOpPattern : public ScalarOpPattern<FHE::SubEintIntOp> {
rewriter);
// Write the new op
auto newOp = rewriter.create<TFHE::AddGLWEIntOp>(location, op.getType(),
eintOperand, encodedInt);
typing::TypeConverter converter;
// Convert the types
concretelang::convertOperandAndResultTypes(rewriter, newOp,
converter.getConversionLambda());
rewriter.replaceOp(op, {newOp.getResult()});
rewriter.replaceOpWithNewOp<TFHE::AddGLWEIntOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.a(),
encodedInt);
return mlir::success();
};
@@ -229,31 +197,24 @@ struct SubEintIntOpPattern : public ScalarOpPattern<FHE::SubEintIntOp> {
/// Rewriter for the `FHE::sub_int_eint` operation.
struct SubIntEintOpPattern : public ScalarOpPattern<FHE::SubIntEintOp> {
SubIntEintOpPattern(mlir::MLIRContext *context,
SubIntEintOpPattern(mlir::TypeConverter &converter,
mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ScalarOpPattern<FHE::SubIntEintOp>(context, benefit) {}
: ScalarOpPattern<FHE::SubIntEintOp>(converter, context, benefit) {}
mlir::LogicalResult
matchAndRewrite(FHE::SubIntEintOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::Location location = op.getLoc();
mlir::Value intOperand = op.a();
mlir::Value eintOperand = op.b();
matchAndRewrite(FHE::SubIntEintOp op, FHE::SubIntEintOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
// Write the plaintext encoding
mlir::Value encodedInt = writePlaintextShiftEncoding(
op.getLoc(), intOperand,
eintOperand.getType().cast<FHE::EncryptedIntegerType>().getWidth(),
op.getLoc(), adaptor.a(),
op.b().getType().cast<FHE::EncryptedIntegerType>().getWidth(),
rewriter);
// Write the new op
auto newOp = rewriter.create<TFHE::SubGLWEIntOp>(location, op.getType(),
encodedInt, eintOperand);
typing::TypeConverter converter;
concretelang::convertOperandAndResultTypes(rewriter, newOp,
converter.getConversionLambda());
rewriter.replaceOp(op, {newOp.getResult()});
rewriter.replaceOpWithNewOp<TFHE::SubGLWEIntOp>(
op, getTypeConverter()->convertType(op.getType()), encodedInt,
adaptor.b());
return mlir::success();
};
@@ -261,60 +222,53 @@ struct SubIntEintOpPattern : public ScalarOpPattern<FHE::SubIntEintOp> {
/// Rewriter for the `FHE::sub_eint` operation.
struct SubEintOpPattern : public ScalarOpPattern<FHE::SubEintOp> {
SubEintOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
: ScalarOpPattern<FHE::SubEintOp>(context, benefit) {}
SubEintOpPattern(mlir::TypeConverter &converter, mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ScalarOpPattern<FHE::SubEintOp>(converter, context, benefit) {}
mlir::LogicalResult
matchAndRewrite(FHE::SubEintOp op,
mlir::PatternRewriter &rewriter) const override {
matchAndRewrite(FHE::SubEintOp op, FHE::SubEintOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Location location = op.getLoc();
mlir::Value lhsOperand = op.a();
mlir::Value rhsOperand = op.b();
mlir::Value lhsOperand = adaptor.a();
mlir::Value rhsOperand = adaptor.b();
// Write rhs negation
auto negative = rewriter.create<TFHE::NegGLWEOp>(
location, rhsOperand.getType(), rhsOperand);
typing::TypeConverter converter;
concretelang::convertOperandAndResultTypes(rewriter, negative,
converter.getConversionLambda());
// Write new op.
auto newOp = rewriter.create<TFHE::AddGLWEOp>(
location, op.getType(), lhsOperand, negative.getResult());
concretelang::convertOperandAndResultTypes(rewriter, newOp,
converter.getConversionLambda());
rewriter.replaceOpWithNewOp<TFHE::AddGLWEOp>(
op, getTypeConverter()->convertType(op.getType()), lhsOperand,
negative.getResult());
rewriter.replaceOp(op, {newOp.getResult()});
return mlir::success();
};
};
/// Rewriter for the `FHE::mul_eint_int` operation.
struct MulEintIntOpPattern : public ScalarOpPattern<FHE::MulEintIntOp> {
MulEintIntOpPattern(mlir::MLIRContext *context,
MulEintIntOpPattern(mlir::TypeConverter &converter,
mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ScalarOpPattern<FHE::MulEintIntOp>(context, benefit) {}
: ScalarOpPattern<FHE::MulEintIntOp>(converter, context, benefit) {}
mlir::LogicalResult
matchAndRewrite(FHE::MulEintIntOp op,
mlir::PatternRewriter &rewriter) const override {
matchAndRewrite(FHE::MulEintIntOp op, FHE::MulEintIntOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Location location = op.getLoc();
mlir::Value eintOperand = op.a();
mlir::Value intOperand = op.b();
mlir::Value eintOperand = adaptor.a();
mlir::Value intOperand = adaptor.b();
// Write the cleartext "encoding"
mlir::Value castedCleartext = rewriter.create<mlir::arith::ExtSIOp>(
location, rewriter.getIntegerType(64), intOperand);
// Write the new op.
auto newOp = rewriter.create<TFHE::MulGLWEIntOp>(
location, op.getType(), eintOperand, castedCleartext);
typing::TypeConverter converter;
concretelang::convertOperandAndResultTypes(rewriter, newOp,
converter.getConversionLambda());
rewriter.replaceOp(op, {newOp.getResult()});
rewriter.replaceOpWithNewOp<TFHE::MulGLWEIntOp>(
op, getTypeConverter()->convertType(op.getType()), eintOperand,
castedCleartext);
return mlir::success();
}
@@ -324,15 +278,17 @@ struct MulEintIntOpPattern : public ScalarOpPattern<FHE::MulEintIntOp> {
struct ApplyLookupTableEintOpPattern
: public ScalarOpPattern<FHE::ApplyLookupTableEintOp> {
ApplyLookupTableEintOpPattern(
mlir::MLIRContext *context,
mlir::TypeConverter &converter, mlir::MLIRContext *context,
concretelang::ScalarLoweringParameters loweringParams,
mlir::PatternBenefit benefit = 1)
: ScalarOpPattern<FHE::ApplyLookupTableEintOp>(context, benefit),
: ScalarOpPattern<FHE::ApplyLookupTableEintOp>(converter, context,
benefit),
loweringParameters(loweringParams) {}
mlir::LogicalResult
matchAndRewrite(FHE::ApplyLookupTableEintOp op,
mlir::PatternRewriter &rewriter) const override {
FHE::ApplyLookupTableEintOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
size_t outputBits =
op.getResult().getType().cast<FHE::EncryptedIntegerType>().getWidth();
@@ -350,16 +306,12 @@ struct ApplyLookupTableEintOpPattern
// Insert keyswitch
auto ksOp = rewriter.create<TFHE::KeySwitchGLWEOp>(
op.getLoc(), op.a().getType(), op.a(), -1, -1);
typing::TypeConverter converter;
concretelang::convertOperandAndResultTypes(rewriter, ksOp,
converter.getConversionLambda());
op.getLoc(), adaptor.a().getType(), adaptor.a(), -1, -1);
// Insert bootstrap
auto bsOp = rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
op, op.getType(), ksOp, newLut, -1, -1, -1, -1);
concretelang::convertOperandAndResultTypes(rewriter, bsOp,
converter.getConversionLambda());
rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
op, getTypeConverter()->convertType(op.getType()), ksOp, newLut, -1, -1,
-1, -1);
return mlir::success();
};
@@ -473,20 +425,20 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
// Patterns for the `FHE` dialect operations
patterns.add<
// |_ `FHE::zero_eint`
concretelang::GenericTypeAndOpConverterPattern<FHE::ZeroEintOp,
TFHE::ZeroGLWEOp>,
concretelang::GenericOneToOneOpConversionPattern<FHE::ZeroEintOp,
TFHE::ZeroGLWEOp>,
// |_ `FHE::zero_tensor`
concretelang::GenericTypeAndOpConverterPattern<FHE::ZeroTensorOp,
TFHE::ZeroTensorGLWEOp>,
concretelang::GenericOneToOneOpConversionPattern<
FHE::ZeroTensorOp, TFHE::ZeroTensorGLWEOp>,
// |_ `FHE::neg_eint`
concretelang::GenericTypeAndOpConverterPattern<FHE::NegEintOp,
TFHE::NegGLWEOp>,
concretelang::GenericOneToOneOpConversionPattern<FHE::NegEintOp,
TFHE::NegGLWEOp>,
// |_ `FHE::not`
concretelang::GenericTypeAndOpConverterPattern<FHE::BoolNotOp,
TFHE::NegGLWEOp>,
concretelang::GenericOneToOneOpConversionPattern<FHE::BoolNotOp,
TFHE::NegGLWEOp>,
// |_ `FHE::add_eint`
concretelang::GenericTypeAndOpConverterPattern<FHE::AddEintOp,
TFHE::AddGLWEOp>>(
concretelang::GenericOneToOneOpConversionPattern<FHE::AddEintOp,
TFHE::AddGLWEOp>>(
&getContext(), converter);
// |_ `FHE::add_eint_int`
patterns.add<lowering::AddEintIntOpPattern,
@@ -497,10 +449,10 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
// |_ `FHE::sub_eint`
lowering::SubEintOpPattern,
// |_ `FHE::mul_eint_int`
lowering::MulEintIntOpPattern>(&getContext());
lowering::MulEintIntOpPattern>(converter, &getContext());
// |_ `FHE::apply_lookup_table`
patterns.add<lowering::ApplyLookupTableEintOpPattern>(&getContext(),
loweringParameters);
patterns.add<lowering::ApplyLookupTableEintOpPattern>(
converter, &getContext(), loweringParameters);
// Patterns for boolean conversion ops
patterns.add<lowering::FromBoolOpPattern, lowering::ToBoolOpPattern>(
@@ -508,14 +460,12 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
// Patterns for the relics of the `FHELinalg` dialect operations.
// |_ `linalg::generic` turned to nested `scf::for`
patterns
.add<concretelang::GenericTypeConverterPattern<mlir::linalg::YieldOp>>(
patterns.getContext(), converter);
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
typing::TypeConverter>>(
&getContext(), converter);
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
mlir::linalg::YieldOp>>(patterns.getContext(), converter);
patterns.add<
RegionOpTypeConverterPattern<mlir::scf::ForOp, typing::TypeConverter>>(
concretelang::TypeConvertingReinstantiationPattern<
mlir::tensor::GenerateOp, true>,
concretelang::TypeConvertingReinstantiationPattern<mlir::scf::ForOp>>(
&getContext(), converter);
concretelang::populateWithTensorTypeConverterPatterns(patterns, target,
converter);
@@ -523,33 +473,46 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
// Patterns for `func` dialect operations.
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
patterns, converter);
patterns
.add<concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>>(
patterns.getContext(), converter);
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
mlir::func::ReturnOp>>(patterns.getContext(), converter);
concretelang::addDynamicallyLegalTypeOp<mlir::func::ReturnOp>(target,
converter);
concretelang::addDynamicallyLegalTypeOp<mlir::scf::YieldOp>(target,
converter);
concretelang::addDynamicallyLegalTypeOp<mlir::scf::ForOp>(target,
converter);
patterns.add<FunctionConstantOpConversion<typing::TypeConverter>>(
&getContext(), converter);
// Patterns for `bufferization` dialect operations.
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
mlir::bufferization::AllocTensorOp, true>>(patterns.getContext(),
converter);
concretelang::addDynamicallyLegalTypeOp<mlir::bufferization::AllocTensorOp>(
target, converter);
// Patterns for the `RT` dialect operations.
patterns
.add<concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
concretelang::GenericTypeConverterPattern<
concretelang::RT::MakeReadyFutureOp>,
concretelang::GenericTypeConverterPattern<
concretelang::RT::AwaitFutureOp>,
concretelang::GenericTypeConverterPattern<
concretelang::RT::CreateAsyncTaskOp>,
concretelang::GenericTypeConverterPattern<
concretelang::RT::BuildReturnPtrPlaceholderOp>,
concretelang::GenericTypeConverterPattern<
concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
concretelang::GenericTypeConverterPattern<
concretelang::RT::DerefReturnPtrPlaceholderOp>,
concretelang::GenericTypeConverterPattern<
concretelang::RT::WorkFunctionReturnOp>,
concretelang::GenericTypeConverterPattern<
concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
converter);
patterns.add<
concretelang::TypeConvertingReinstantiationPattern<mlir::scf::YieldOp>,
concretelang::TypeConvertingReinstantiationPattern<
concretelang::RT::MakeReadyFutureOp>,
concretelang::TypeConvertingReinstantiationPattern<
concretelang::RT::AwaitFutureOp>,
concretelang::TypeConvertingReinstantiationPattern<
concretelang::RT::CreateAsyncTaskOp, true>,
concretelang::TypeConvertingReinstantiationPattern<
concretelang::RT::BuildReturnPtrPlaceholderOp>,
concretelang::TypeConvertingReinstantiationPattern<
concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
concretelang::TypeConvertingReinstantiationPattern<
concretelang::RT::DerefReturnPtrPlaceholderOp>,
concretelang::TypeConvertingReinstantiationPattern<
concretelang::RT::WorkFunctionReturnOp>,
concretelang::TypeConvertingReinstantiationPattern<
concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
converter);
//--------------------------------------------------------- Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns))
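The registrations above all rely on the default rewrite rule of `TypeConvertingReinstantiationPattern`. As a rough, hypothetical sketch of that rule (the class name and structure below are illustrative, not the project's actual implementation), such a pattern converts the result types through the attached type converter and re-instantiates the operation from the adaptor's already-converted operands, with the boolean template parameter deciding whether the original attributes are carried over:

// Hypothetical sketch only; assumes OpTy has the generic builder taking
// (result types, operands, attributes).
#include "mlir/Transforms/DialectConversion.h"

template <typename OpTy, bool CopyAttributes = false>
struct ReinstantiationSketch : public mlir::OpConversionPattern<OpTy> {
  using mlir::OpConversionPattern<OpTy>::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(OpTy oldOp,
                  typename mlir::OpConversionPattern<OpTy>::OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    // Convert the result types with the type converter attached to the pattern
    mlir::SmallVector<mlir::Type> resultTypes;
    if (mlir::failed(this->getTypeConverter()->convertTypes(
            oldOp->getResultTypes(), resultTypes)))
      return mlir::failure();

    // Re-create the op with the converted operands provided by the adaptor,
    // either keeping or dismissing the original attributes
    rewriter.replaceOpWithNewOp<OpTy>(
        oldOp, resultTypes, adaptor.getOperands(),
        CopyAttributes ? oldOp->getAttrs()
                       : llvm::ArrayRef<mlir::NamedAttribute>{});
    return mlir::success();
  }
};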


@@ -16,6 +16,7 @@
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
#include "concretelang/Support/Constants.h"
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
namespace TFHE = mlir::concretelang::TFHE;
@@ -300,6 +301,12 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
// Add all patterns to convert TFHE types
populateWithTFHEOpTypeConversionPatterns(patterns, target, converter);
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
mlir::bufferization::AllocTensorOp>>(&getContext(), converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::bufferization::AllocTensorOp>(target, converter);
patterns.add<RegionOpTypeConverterPattern<
mlir::linalg::GenericOp, TFHEGlobalParametrizationTypeConverter>>(
&getContext(), converter);
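Both this pass and the scalar lowering pass use `addDynamicallyLegalTypeOp` to keep an operation illegal until its types have been converted. The helper's definition is not part of this diff; a minimal sketch of the assumed semantics, using only standard dialect-conversion APIs, is:

// Assumed semantics only: OpTy counts as legal once the type converter
// accepts all of its operand and result types.
#include "mlir/Transforms/DialectConversion.h"

template <typename OpTy>
void addDynamicallyLegalTypeOpSketch(mlir::ConversionTarget &target,
                                     mlir::TypeConverter &converter) {
  target.addDynamicallyLegalOp<OpTy>(
      [&converter](OpTy op) { return converter.isLegal(op.getOperation()); });
}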


@@ -4,6 +4,7 @@
// for license information.
#include <iostream>
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -231,6 +232,8 @@ void TFHEToConcretePass::runOnOperation() {
patterns.add<
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
mlir::concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::bufferization::AllocTensorOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::MakeReadyFutureOp>,
mlir::concretelang::GenericTypeConverterPattern<
@@ -270,6 +273,8 @@ void TFHEToConcretePass::runOnOperation() {
target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::linalg::YieldOp>(
target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::bufferization::AllocTensorOp>(target, converter);
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {


@@ -0,0 +1 @@
add_subdirectory(Dialects)


@@ -0,0 +1 @@


@@ -0,0 +1,40 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang/Conversion/Utils/Dialects/SCF.h"
#include "mlir/Transforms/RegionUtils.h"
#include <mlir/IR/BlockAndValueMapping.h>
namespace mlir {
namespace concretelang {
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<scf::ForOp, false>::matchAndRewrite(
scf::ForOp oldOp, mlir::OpConversionPattern<scf::ForOp>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
// Create new for loop with empty body, but converted iter args
scf::ForOp newForOp = rewriter.replaceOpWithNewOp<scf::ForOp>(
oldOp, adaptor.getLowerBound(), adaptor.getUpperBound(),
adaptor.getStep(), adaptor.getInitArgs(),
[&](OpBuilder &builder, Location loc, Value iv, ValueRange args) {});
// Move operations from old for op to new one
auto &newOperations = newForOp.getBody()->getOperations();
mlir::Block *oldBody = oldOp.getBody();
newOperations.splice(newOperations.begin(), oldBody->getOperations(),
oldBody->begin(), oldBody->end());
// Remap iter args and IV
for (auto argsPair : llvm::zip(oldOp.getBody()->getArguments(),
newForOp.getBody()->getArguments())) {
replaceAllUsesInRegionWith(std::get<0>(argsPair), std::get<1>(argsPair),
newForOp.getRegion());
}
return mlir::success();
}
} // namespace concretelang
} // namespace mlir
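The specialization above is needed because `scf::ForOp` carries a region: besides re-creating the loop with converted iteration arguments, the body has to be moved into the new loop and its block arguments remapped, so a plain re-instantiation from operands and result types is not enough. A hedged registration sketch, mirroring the calls in the FHE-to-TFHE scalar pass and assuming the usual pass-local `patterns`, `target` and `converter` objects, could look like this:

// Illustrative only; relies on the project's header introduced in this commit.
#include "concretelang/Conversion/Utils/Dialects/SCF.h"
#include "mlir/Transforms/DialectConversion.h"

static void registerForOpConversionSketch(mlir::MLIRContext *ctx,
                                          mlir::TypeConverter &converter,
                                          mlir::ConversionTarget &target,
                                          mlir::RewritePatternSet &patterns) {
  // The pattern picks up the specialized matchAndRewrite() defined above.
  patterns.add<mlir::concretelang::TypeConvertingReinstantiationPattern<
      mlir::scf::ForOp, false>>(ctx, converter);
  // Keep the loop illegal as long as its operand or result types still need
  // conversion, so that the pattern gets a chance to fire.
  target.addDynamicallyLegalOp<mlir::scf::ForOp>(
      [&converter](mlir::scf::ForOp op) {
        return converter.isLegal(op.getOperation());
      });
}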


@@ -0,0 +1,107 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang/Conversion/Utils/Dialects/Tensor.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
namespace mlir {
namespace concretelang {
//
// Specializations for CollapseShapeOp
//
// Specialization copying attributes not necessary, as the base
// template works correctly
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<tensor::CollapseShapeOp, false>::
matchAndRewrite(
tensor::CollapseShapeOp oldOp,
mlir::OpConversionPattern<tensor::CollapseShapeOp>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::SmallVector<mlir::Type> resultTypes = convertResultTypes(oldOp);
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
oldOp, mlir::TypeRange{resultTypes}, adaptor.getSrc(),
oldOp.getReassociation());
return mlir::success();
}
//
// Specializations for FromElementsOp
//
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<mlir::tensor::FromElementsOp, false>::
matchAndRewrite(
tensor::FromElementsOp oldOp,
mlir::OpConversionPattern<mlir::tensor::FromElementsOp>::OpAdaptor
adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Type resultType = convertResultType(oldOp);
rewriter.replaceOpWithNewOp<mlir::tensor::FromElementsOp>(
oldOp, resultType, adaptor.getElements());
return mlir::success();
}
//
// Specializations for ExpandShapeOp
//
// Specialization copying attributes not necessary, as the base
// template works correctly
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<tensor::ExpandShapeOp, false>::
matchAndRewrite(
tensor::ExpandShapeOp oldOp,
mlir::OpConversionPattern<tensor::ExpandShapeOp>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::SmallVector<mlir::Type> resultTypes = convertResultTypes(oldOp);
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
oldOp, mlir::TypeRange{resultTypes}, adaptor.getSrc(),
oldOp.getReassociation());
return mlir::success();
}
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<tensor::GenerateOp, true>::matchAndRewrite(
tensor::GenerateOp oldOp,
mlir::OpConversionPattern<tensor::GenerateOp>::OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::SmallVector<mlir::Type> resultTypes = convertResultTypes(oldOp);
rewriter.setInsertionPointAfter(oldOp);
tensor::GenerateOp newGenerateOp = rewriter.create<tensor::GenerateOp>(
oldOp.getLoc(), resultTypes, adaptor.getOperands(), oldOp->getAttrs());
mlir::Block &oldBlock = oldOp.getBody().getBlocks().front();
mlir::Block &newBlock = newGenerateOp.getBody().getBlocks().front();
auto begin = oldBlock.begin();
auto nOps = oldBlock.getOperations().size();
newBlock.getOperations().splice(newBlock.getOperations().begin(),
oldBlock.getOperations(), begin,
std::next(begin, nOps - 1));
for (auto argsPair : llvm::zip(oldOp.getRegion().getArguments(),
newGenerateOp.getRegion().getArguments())) {
replaceAllUsesInRegionWith(std::get<0>(argsPair), std::get<1>(argsPair),
newGenerateOp.getRegion());
}
rewriter.replaceOp(oldOp, newGenerateOp.getResult());
return mlir::success();
}
} // namespace concretelang
} // namespace mlir
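As with the SCF case, these specializations only take effect once the corresponding pattern instantiations are registered in a pass. A hedged sketch of such a registration, modeled on the calls visible in the scalar lowering pass (only the `tensor::GenerateOp` registration with attribute copying is actually shown in this diff; the other entries are assumptions), could look like this:

// Illustrative only; the surrounding pass is assumed to provide the usual
// `patterns` and `converter` objects.
#include "concretelang/Conversion/Utils/Dialects/Tensor.h"

static void addTensorReinstantiationPatternsSketch(
    mlir::RewritePatternSet &patterns, mlir::TypeConverter &converter) {
  patterns.add<
      mlir::concretelang::TypeConvertingReinstantiationPattern<
          mlir::tensor::CollapseShapeOp>,
      mlir::concretelang::TypeConvertingReinstantiationPattern<
          mlir::tensor::ExpandShapeOp>,
      mlir::concretelang::TypeConvertingReinstantiationPattern<
          mlir::tensor::FromElementsOp>,
      // GenerateOp keeps its attributes, hence the `true` flag, matching the
      // registration in the FHE-to-TFHE scalar pass.
      mlir::concretelang::TypeConvertingReinstantiationPattern<
          mlir::tensor::GenerateOp, true>>(patterns.getContext(), converter);
}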


@@ -15,6 +15,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include <concretelang/Conversion/Utils/FuncConstOpConversion.h>
#include <concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h>
#include <concretelang/Conversion/Utils/Legality.h>
#include <llvm/IR/Instructions.h>
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
#include <mlir/Dialect/Bufferization/Transforms/Bufferize.h>


@@ -2,7 +2,7 @@
// CHECK-LABEL: func.func @add_eint(%arg0: !TFHE.glwe<{_,_,_}{7}>, %arg1: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
func.func @add_eint(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-NEXT: %[[V1:.*]] = "TFHE.add_glwe"(%arg0, %arg1) {MANP = 2 : ui3} : (!TFHE.glwe<{_,_,_}{7}>, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: %[[V1:.*]] = "TFHE.add_glwe"(%arg0, %arg1) : (!TFHE.glwe<{_,_,_}{7}>, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: return %[[V1]] : !TFHE.glwe<{_,_,_}{7}>
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)


@@ -2,7 +2,7 @@
// CHECK-LABEL: func.func @neg_eint(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) {MANP = 1 : ui1} : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: return %0 : !TFHE.glwe<{_,_,_}{7}>
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>)
@@ -11,7 +11,7 @@ func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-LABEL: func.func @not(%arg0: !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
func.func @not(%arg0: !FHE.ebool) -> !FHE.ebool {
// CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) {MANP = 1 : ui1} : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT: return %0 : !TFHE.glwe<{_,_,_}{2}>
%1 = "FHE.not"(%arg0) : (!FHE.ebool) -> !FHE.ebool
return %1: !FHE.ebool