From 73fd6c5fe751c502bcc4c7c68da4a79b753fa791 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Thu, 26 Jan 2023 14:54:32 +0100 Subject: [PATCH] refactor(compiler): FHE to TFHE: Use OpConversionPattern for dialect conversion Use `OpConversionPattern` instead of `OpRewritePattern` for operation conversion during dialect conversion. This makes explicit and in-place type conversions unnecessary, since `OpConversionPattern` already properly converts operand types and provides them to the rewrite rule through an operation adaptor. The main contributions of this commit are the two class templates `TypeConvertingReinstantiationPattern` and `GenericOneToOneOpConversionPattern`. The former allows for the definition of a simple replacement rule that re-instantiates an operation after the types of its operands have been converted. This is especially useful for type-polymorphic operations during dialect conversion. The latter allows for the definition of patterns, where one operation needs to be replaced with a different operation after conversion of its operands. The default implementations for the class templates provide conversions rules for operations that have a generic builder method that takes the desired return type(s), the operands and (optionally) a set of attributes. How attributes are discarded during a conversion (either by omitting the builder argument or by passing an empty set of attributes) can be defined through specialization of `ReinstantiationAttributeDismissalStrategy`. Custom replacement rules that deviate from the scheme above should be implemented by specializing `TypeConvertingReinstantiationPattern::matchAndRewrite()` and `GenericOneToOneOpConversionPattern::matchAndRewrite()`. --- .../Conversion/Utils/Dialects/SCF.h | 29 ++ .../Conversion/Utils/Dialects/Tensor.h | 70 +++++ .../Utils/GenericOpTypeConversionPattern.h | 10 - .../concretelang/Conversion/Utils/Legality.h | 26 ++ .../Utils/ReinstantiatingOpTypeConversion.h | 216 +++++++++++++ .../Conversion/Utils/TensorOpTypeConversion.h | 29 +- compiler/lib/Conversion/CMakeLists.txt | 10 +- .../ConcreteToBConcrete.cpp | 1 + .../Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp | 286 ++++++++---------- .../FHEToTFHEScalar/FHEToTFHEScalar.cpp | 271 +++++++---------- .../TFHEGlobalParametrization.cpp | 7 + .../TFHEToConcrete/TFHEToConcrete.cpp | 5 + compiler/lib/Conversion/Utils/CMakeLists.txt | 1 + .../Conversion/Utils/Dialects/CMakeLists.txt | 1 + .../lib/Conversion/Utils/Dialects/SCF.cpp | 40 +++ .../lib/Conversion/Utils/Dialects/Tensor.cpp | 107 +++++++ .../RT/Analysis/BufferizeDataflowTaskOps.cpp | 1 + .../Conversion/FHEToTFHEScalar/add_eint.mlir | 2 +- .../Conversion/FHEToTFHEScalar/neg_eint.mlir | 4 +- 19 files changed, 781 insertions(+), 335 deletions(-) create mode 100644 compiler/include/concretelang/Conversion/Utils/Dialects/SCF.h create mode 100644 compiler/include/concretelang/Conversion/Utils/Dialects/Tensor.h create mode 100644 compiler/include/concretelang/Conversion/Utils/Legality.h create mode 100644 compiler/include/concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h create mode 100644 compiler/lib/Conversion/Utils/CMakeLists.txt create mode 100644 compiler/lib/Conversion/Utils/Dialects/CMakeLists.txt create mode 100644 compiler/lib/Conversion/Utils/Dialects/SCF.cpp create mode 100644 compiler/lib/Conversion/Utils/Dialects/Tensor.cpp diff --git a/compiler/include/concretelang/Conversion/Utils/Dialects/SCF.h b/compiler/include/concretelang/Conversion/Utils/Dialects/SCF.h new file mode 100644 index 000000000..c97e878b6 --- /dev/null +++ b/compiler/include/concretelang/Conversion/Utils/Dialects/SCF.h @@ -0,0 +1,29 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_CONVERSION_UTILS_DIALECTS_SCF_H_ +#define CONCRETELANG_CONVERSION_UTILS_DIALECTS_SCF_H_ + +#include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h" +#include "mlir/Dialect/SCF/IR/SCF.h" + +namespace mlir { +namespace concretelang { + +// +// Specializations for ForOp +// + +// Specialization copying attributes omitted +template <> +mlir::LogicalResult +TypeConvertingReinstantiationPattern::matchAndRewrite( + scf::ForOp oldOp, mlir::OpConversionPattern::OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const; + +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compiler/include/concretelang/Conversion/Utils/Dialects/Tensor.h b/compiler/include/concretelang/Conversion/Utils/Dialects/Tensor.h new file mode 100644 index 000000000..d48ba7241 --- /dev/null +++ b/compiler/include/concretelang/Conversion/Utils/Dialects/Tensor.h @@ -0,0 +1,70 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_CONVERSION_UTILS_DIALECTS_TENSOR_H_ +#define CONCRETELANG_CONVERSION_UTILS_DIALECTS_TENSOR_H_ + +#include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +namespace mlir { +namespace concretelang { + +// +// Specializations for CollapseShapeOp +// + +// Specialization copying attributes not necessary, as the base +// template works correctly +template <> +mlir::LogicalResult +TypeConvertingReinstantiationPattern:: + matchAndRewrite( + tensor::CollapseShapeOp oldOp, + mlir::OpConversionPattern::OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const; +// +// Specializations for FromElementsOp +// +template <> +mlir::LogicalResult +TypeConvertingReinstantiationPattern:: + matchAndRewrite( + tensor::FromElementsOp oldOp, + mlir::OpConversionPattern::OpAdaptor + adaptor, + mlir::ConversionPatternRewriter &rewriter) const; + +// +// Specializations for ExpandShapeOp +// + +// Specialization copying attributes not necessary, as the base +// template works correctly + +template <> +mlir::LogicalResult +TypeConvertingReinstantiationPattern:: + matchAndRewrite( + tensor::ExpandShapeOp oldOp, + mlir::OpConversionPattern::OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const; + +// +// Specializations for GenerateOp +// + +// Specialization NOT copying attributes omitted +template <> +mlir::LogicalResult +TypeConvertingReinstantiationPattern::matchAndRewrite( + tensor::GenerateOp oldOp, + mlir::OpConversionPattern::OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const; + +} // namespace concretelang +} // namespace mlir + +#endif // CONCRETELANG_CONVERSION_UTILS_DIALECTS_TENSOR_H_ diff --git a/compiler/include/concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h b/compiler/include/concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h index 5b2645bda..ce6f91bbf 100644 --- a/compiler/include/concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h +++ b/compiler/include/concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h @@ -96,16 +96,6 @@ struct GenericTypeAndOpConverterPattern : public mlir::OpRewritePattern { private: mlir::TypeConverter &converter; }; - -template -void addDynamicallyLegalTypeOp(mlir::ConversionTarget &target, - mlir::TypeConverter &typeConverter) { - target.addDynamicallyLegalOp([&](Op op) { - return typeConverter.isLegal(op->getOperandTypes()) && - typeConverter.isLegal(op->getResultTypes()); - }); -} - } // namespace concretelang } // namespace mlir diff --git a/compiler/include/concretelang/Conversion/Utils/Legality.h b/compiler/include/concretelang/Conversion/Utils/Legality.h new file mode 100644 index 000000000..c74c8e81d --- /dev/null +++ b/compiler/include/concretelang/Conversion/Utils/Legality.h @@ -0,0 +1,26 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_CONVERSION_UTILS_LEGALITY_H_ +#define CONCRETELANG_CONVERSION_UTILS_LEGALITY_H_ + +#include + +namespace mlir { +namespace concretelang { + +template +void addDynamicallyLegalTypeOp(mlir::ConversionTarget &target, + mlir::TypeConverter &typeConverter) { + target.addDynamicallyLegalOp([&](Op op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); +} + +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compiler/include/concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h b/compiler/include/concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h new file mode 100644 index 000000000..1697054e5 --- /dev/null +++ b/compiler/include/concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h @@ -0,0 +1,216 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_CONVERSION_UTILS_REINSTANTIATINGOPTYPECONVERSION_H_ +#define CONCRETELANG_CONVERSION_UTILS_REINSTANTIATINGOPTYPECONVERSION_H_ + +#include + +namespace mlir { +namespace concretelang { + +// Set of types defining how attributes should be handled when +// invocating the build method of an operation upon reinstantiation +struct ReinstantiationAttributeHandling { + // Copy attributes + struct copy {}; + + // Completely dismiss attributes by not passing a set of arguments + // to the builder at all + struct dismiss {}; + + // Dismiss attributes by passing an empty set of arguments to the + // builder + struct pass_empty_vector {}; +}; + +// Template defining how attributes should be dismissed when invoking +// the build method of an operation upon reinstantiation. In the +// default case, the argument for attributes is simply dismissed. +template struct ReinstantiationAttributeDismissalStrategy { + typedef ReinstantiationAttributeHandling::dismiss strategy; +}; + +// Template defining how attributes should be copied when invoking the +// build method of an operation upon reinstantiation. In the default +// case, the argument for attributes is forwarded to the build method. +template struct ReinstantiationAttributeCopyStrategy { + typedef ReinstantiationAttributeHandling::copy strategy; +}; + +namespace { +// Class template that defines the attribute handling strategy for +// either dismissal of attributes (if `copyAttrsSwitch` is `false`) or copying +// attributes (if `copyAttrsSwitch` is `true`). +template struct AttributeHandlingSwitch {}; + +template struct AttributeHandlingSwitch { + typedef typename ReinstantiationAttributeCopyStrategy::strategy strategy; +}; + +template struct AttributeHandlingSwitch { + typedef + typename ReinstantiationAttributeDismissalStrategy::strategy strategy; +}; + +// Simple functor-like template invoking a rewriter with a variable +// set of arguments and an op's attributes as the last argument. +template +struct ReplaceOpWithNewOpCopyAttrs { + static NewOpTy replace(mlir::ConversionPatternRewriter &rewriter, + mlir::Operation *op, mlir::TypeRange resultTypes, + mlir::ValueRange operands) { + return rewriter.replaceOpWithNewOp(op, resultTypes, operands, + op->getAttrs()); + } +}; + +// Simple functor-like template invoking a rewriter with a variable +// set of arguments dismissing the attributes passed as the last +// argument. +template +struct ReplaceOpWithNewOpDismissAttrs { + static NewOpTy replace(mlir::ConversionPatternRewriter &rewriter, + mlir::Operation *op, mlir::TypeRange resultTypes, + mlir::ValueRange operands) { + return rewriter.replaceOpWithNewOp(op, resultTypes, operands); + } +}; + +// Simple functor-like template invoking a rewriter with a variable +// set of arguments dismissing the attributes by passing an empty +// set of arguments to the builder. +template +struct ReplaceOpWithNewOpEmptyAttrs { + static NewOpTy replace(mlir::ConversionPatternRewriter &rewriter, + mlir::Operation *op, mlir::TypeRange resultTypes, + mlir::ValueRange operands) { + llvm::SmallVector attrs{}; + return rewriter.replaceOpWithNewOp(op, resultTypes, operands, + attrs); + } +}; + +// Functor-like template that either forwards to +// `ReplaceOpWithNewOpCopyAttrs` or `ReplaceOpWithNewOpDismissAttrs` +// depending on the value of `copyAttrs`. +template +struct ReplaceOpWithNewOpAttrSwitch {}; + +// Specialization of `ReplaceOpWithNewOpAttrSwitch` that does copy +// attributes. +template +struct ReplaceOpWithNewOpAttrSwitch { + typedef ReplaceOpWithNewOpCopyAttrs instantiator; +}; + +// Specialization of `ReplaceOpWithNewOpAttrSwitch` that does NOT copy +// attributes by not passing attributes to the builder at all. +template +struct ReplaceOpWithNewOpAttrSwitch { + typedef ReplaceOpWithNewOpDismissAttrs instantiator; +}; + +// Specialization of `ReplaceOpWithNewOpAttrSwitch` that does NOT copy +// attributes by passing an empty set of attributes to the builder. +template +struct ReplaceOpWithNewOpAttrSwitch< + ReinstantiationAttributeHandling::pass_empty_vector, OpTy, Args...> { + typedef ReplaceOpWithNewOpEmptyAttrs instantiator; +}; + +} // namespace + +template +struct GenericOneToOneOpConversionPatternBase + : public mlir::OpConversionPattern { + GenericOneToOneOpConversionPatternBase(mlir::MLIRContext *context, + mlir::TypeConverter &converter, + mlir::PatternBenefit benefit = 100) + : mlir::OpConversionPattern(converter, context, benefit) {} + + mlir::SmallVector convertResultTypes(OldOp oldOp) const { + mlir::TypeConverter *converter = this->getTypeConverter(); + + // Convert result types + mlir::SmallVector resultTypes(oldOp->getNumResults()); + + for (unsigned i = 0; i < oldOp->getNumResults(); i++) { + auto result = oldOp->getResult(i); + resultTypes[i] = converter->convertType(result.getType()); + } + + return resultTypes; + } + + mlir::Type convertResultType(OldOp oldOp) const { + mlir::TypeConverter *converter = this->getTypeConverter(); + return converter->convertType(oldOp->getResult(0).getType()); + } +}; + +// Conversion pattern that replaces an instance of an operation of the type +// `OldOp` with an instance of the type `NewOp`, taking into account operands, +// return types and possible copying attributes (iff copyAttrs is `true`). +template +struct GenericOneToOneOpConversionPattern + : public GenericOneToOneOpConversionPatternBase { + GenericOneToOneOpConversionPattern(mlir::MLIRContext *context, + mlir::TypeConverter &converter, + mlir::PatternBenefit benefit = 100) + : GenericOneToOneOpConversionPatternBase( + context, converter, benefit) {} + + virtual mlir::LogicalResult + matchAndRewrite(OldOp oldOp, + typename mlir::OpConversionPattern::OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::SmallVector resultTypes = this->convertResultTypes(oldOp); + + ReplaceOpWithNewOpAttrSwitch< + typename AttributeHandlingSwitch::strategy, + NewOp>::instantiator::replace(rewriter, oldOp, + mlir::TypeRange{resultTypes}, + mlir::ValueRange{adaptor.getOperands()}); + + return mlir::success(); + } +}; + +// Conversion pattern that retrieves the converted operands of an +// operation of the type `Op`, converts the types of the results of +// the operation and re-instantiates the operation type with the +// converted operands and result types. +template +struct TypeConvertingReinstantiationPattern + : public GenericOneToOneOpConversionPatternBase { + TypeConvertingReinstantiationPattern(mlir::MLIRContext *context, + mlir::TypeConverter &converter, + mlir::PatternBenefit benefit = 100) + : GenericOneToOneOpConversionPatternBase( + context, converter, benefit) {} + // Simple forward that makes the method specializable out of class + // directly for this class rather than for its base + virtual mlir::LogicalResult + matchAndRewrite(Op op, + typename mlir::OpConversionPattern::OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::SmallVector resultTypes = this->convertResultTypes(op); + + ReplaceOpWithNewOpAttrSwitch< + typename AttributeHandlingSwitch::strategy, + Op>::instantiator::replace(rewriter, op, mlir::TypeRange{resultTypes}, + mlir::ValueRange{adaptor.getOperands()}); + + return mlir::success(); + } +}; + +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compiler/include/concretelang/Conversion/Utils/TensorOpTypeConversion.h b/compiler/include/concretelang/Conversion/Utils/TensorOpTypeConversion.h index a61f3399b..2d16c5924 100644 --- a/compiler/include/concretelang/Conversion/Utils/TensorOpTypeConversion.h +++ b/compiler/include/concretelang/Conversion/Utils/TensorOpTypeConversion.h @@ -10,7 +10,8 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/DialectConversion.h" -#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h" +#include "concretelang/Conversion/Utils/Dialects/Tensor.h" +#include "concretelang/Conversion/Utils/Legality.h" namespace mlir { namespace concretelang { @@ -20,37 +21,43 @@ populateWithTensorTypeConverterPatterns(mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target, mlir::TypeConverter &typeConverter) { // ExtractOp - patterns.add>( + patterns.add>( patterns.getContext(), typeConverter); addDynamicallyLegalTypeOp(target, typeConverter); + // ExtractSliceOp - patterns.add>( + patterns.add< + TypeConvertingReinstantiationPattern>( patterns.getContext(), typeConverter); addDynamicallyLegalTypeOp(target, typeConverter); // InsertOp - patterns.add>( + patterns.add>( patterns.getContext(), typeConverter); addDynamicallyLegalTypeOp(target, typeConverter); // InsertSliceOp - patterns.add>( + patterns.add< + TypeConvertingReinstantiationPattern>( patterns.getContext(), typeConverter); addDynamicallyLegalTypeOp(target, typeConverter); // FromElementsOp - patterns.add>( - patterns.getContext(), typeConverter); + patterns + .add>( + patterns.getContext(), typeConverter); addDynamicallyLegalTypeOp(target, typeConverter); // TensorCollapseShapeOp - patterns.add>( - patterns.getContext(), typeConverter); + patterns + .add>( + patterns.getContext(), typeConverter); addDynamicallyLegalTypeOp(target, typeConverter); // TensorExpandShapeOp - patterns.add>( - patterns.getContext(), typeConverter); + patterns + .add>( + patterns.getContext(), typeConverter); addDynamicallyLegalTypeOp(target, typeConverter); } } // namespace concretelang diff --git a/compiler/lib/Conversion/CMakeLists.txt b/compiler/lib/Conversion/CMakeLists.txt index b30536202..999fc2107 100644 --- a/compiler/lib/Conversion/CMakeLists.txt +++ b/compiler/lib/Conversion/CMakeLists.txt @@ -9,5 +9,13 @@ add_subdirectory(SDFGToStreamEmulator) add_subdirectory(MLIRLowerableDialectsToLLVM) add_subdirectory(LinalgExtras) add_subdirectory(ExtractSDFGOps) +add_subdirectory(Utils) -add_mlir_library(ConcretelangConversion Tools.cpp LINK_LIBS PUBLIC MLIRIR) +add_mlir_library( + ConcretelangConversion + Tools.cpp + Utils/Dialects/SCF.cpp + Utils/Dialects/Tensor.cpp + LINK_LIBS + PUBLIC + MLIRIR) diff --git a/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp b/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp index 71168a4cc..ec32fd1cf 100644 --- a/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp +++ b/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp @@ -27,6 +27,7 @@ #include "concretelang/Conversion/Passes.h" #include "concretelang/Conversion/Utils/FuncConstOpConversion.h" +#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h" #include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h" #include "concretelang/Conversion/Utils/TensorOpTypeConversion.h" #include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" diff --git a/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp b/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp index d29ba721a..4f8a08ffc 100644 --- a/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp +++ b/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp @@ -10,13 +10,15 @@ #include #include +#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h" +#include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "concretelang/Conversion/FHEToTFHECrt/Pass.h" #include "concretelang/Conversion/Passes.h" +#include "concretelang/Conversion/Utils/Dialects/SCF.h" #include "concretelang/Conversion/Utils/FuncConstOpConversion.h" -#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h" #include "concretelang/Conversion/Utils/TensorOpTypeConversion.h" #include "concretelang/Dialect/FHE/IR/FHEDialect.h" #include "concretelang/Dialect/FHE/IR/FHEOps.h" @@ -112,7 +114,8 @@ namespace lowering { /// A pattern rewriter superclass used by most op rewriters during the /// conversion. -template struct CrtOpPattern : public mlir::OpRewritePattern { +template +struct CrtOpPattern : public mlir::OpConversionPattern { /// The lowering parameters are bound to the op rewriter. concretelang::CrtLoweringParameters loweringParameters; @@ -120,8 +123,8 @@ template struct CrtOpPattern : public mlir::OpRewritePattern { CrtOpPattern(mlir::MLIRContext *context, concretelang::CrtLoweringParameters params, mlir::PatternBenefit benefit = 1) - : mlir::OpRewritePattern(context, benefit), - loweringParameters(params) {} + : mlir::OpConversionPattern(typeConverter, context, benefit), + loweringParameters(params), typeConverter(params) {} /// Writes an `scf::for` that loops over the crt dimension of one tensor and /// execute the input lambda to write the loop body. Returns the first result @@ -154,11 +157,6 @@ template struct CrtOpPattern : public mlir::OpRewritePattern { location, zeroConstantOp, crtSizeConstantOp, oneConstantOp, tensor, body); - // Convert the types of the new operation - typing::TypeConverter converter(loweringParameters); - concretelang::convertOperandAndResultTypes(rewriter, newOp, - converter.getConversionLambda()); - return newOp.getResult(0); } @@ -176,6 +174,9 @@ template struct CrtOpPattern : public mlir::OpRewritePattern { castedPlaintext, rewriter.getI64ArrayAttr(loweringParameters.mods), rewriter.getI64IntegerAttr(loweringParameters.modsProd)); } + +protected: + typing::TypeConverter typeConverter; }; /// Rewriter for the `FHE::add_eint_int` operation. @@ -187,17 +188,12 @@ struct AddEintIntOpPattern : public CrtOpPattern { : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult - matchAndRewrite(FHE::AddEintIntOp op, - mlir::PatternRewriter &rewriter) const override { - + matchAndRewrite(FHE::AddEintIntOp op, FHE::AddEintIntOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::TypeConverter *converter = this->getTypeConverter(); mlir::Location location = op.getLoc(); - mlir::Value eintOperand = op.a(); - mlir::Value intOperand = op.b(); - - // Convert operand type to glwe tensor. - typing::TypeConverter converter(loweringParameters); - intOperand.setType(converter.convertType(intOperand.getType())); - eintOperand.setType(converter.convertType(eintOperand.getType())); + mlir::Value eintOperand = adaptor.a(); + mlir::Value intOperand = adaptor.b(); // Write plaintext encoding mlir::Value encodedPlaintextTensor = @@ -205,7 +201,7 @@ struct AddEintIntOpPattern : public CrtOpPattern { // Write add loop. mlir::Type ciphertextScalarType = - converter.convertType(eintOperand.getType()) + converter->convertType(eintOperand.getType()) .cast() .getElementType(); mlir::Value output = writeUnaryTensorLoop( @@ -239,17 +235,12 @@ struct SubIntEintOpPattern : public CrtOpPattern { : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult - matchAndRewrite(FHE::SubIntEintOp op, - mlir::PatternRewriter &rewriter) const override { - + matchAndRewrite(FHE::SubIntEintOp op, FHE::SubIntEintOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::TypeConverter *converter = this->getTypeConverter(); mlir::Location location = op.getLoc(); - mlir::Value intOperand = op.a(); - mlir::Value eintOperand = op.b(); - - // Convert operand type to glwe tensor. - typing::TypeConverter converter(loweringParameters); - intOperand.setType(converter.convertType(intOperand.getType())); - eintOperand.setType(converter.convertType(eintOperand.getType())); + mlir::Value intOperand = adaptor.a(); + mlir::Value eintOperand = adaptor.b(); // Write plaintext encoding mlir::Value encodedPlaintextTensor = @@ -257,7 +248,7 @@ struct SubIntEintOpPattern : public CrtOpPattern { // Write add loop. mlir::Type ciphertextScalarType = - converter.convertType(eintOperand.getType()) + converter->convertType(eintOperand.getType()) .cast() .getElementType(); mlir::Value output = writeUnaryTensorLoop( @@ -291,17 +282,12 @@ struct SubEintIntOpPattern : public CrtOpPattern { : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult - matchAndRewrite(FHE::SubEintIntOp op, - mlir::PatternRewriter &rewriter) const override { - + matchAndRewrite(FHE::SubEintIntOp op, FHE::SubEintIntOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::TypeConverter *converter = this->getTypeConverter(); mlir::Location location = op.getLoc(); - mlir::Value eintOperand = op.a(); - mlir::Value intOperand = op.b(); - - // Convert operand type to glwe tensor. - typing::TypeConverter converter(loweringParameters); - intOperand.setType(converter.convertType(intOperand.getType())); - eintOperand.setType(converter.convertType(eintOperand.getType())); + mlir::Value eintOperand = adaptor.a(); + mlir::Value intOperand = adaptor.b(); // Write plaintext negation mlir::Type intType = intOperand.getType(); @@ -319,7 +305,7 @@ struct SubEintIntOpPattern : public CrtOpPattern { // Write add loop. mlir::Type ciphertextScalarType = - converter.convertType(eintOperand.getType()) + converter->convertType(eintOperand.getType()) .cast() .getElementType(); mlir::Value output = writeUnaryTensorLoop( @@ -353,21 +339,16 @@ struct AddEintOpPattern : CrtOpPattern { : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult - matchAndRewrite(FHE::AddEintOp op, - mlir::PatternRewriter &rewriter) const override { - + matchAndRewrite(FHE::AddEintOp op, FHE::AddEintOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::TypeConverter *converter = this->getTypeConverter(); mlir::Location location = op.getLoc(); - mlir::Value lhsOperand = op.a(); - mlir::Value rhsOperand = op.b(); - - // Convert operand type to glwe tensor. - typing::TypeConverter converter(loweringParameters); - lhsOperand.setType(converter.convertType(lhsOperand.getType())); - rhsOperand.setType(converter.convertType(rhsOperand.getType())); + mlir::Value lhsOperand = adaptor.a(); + mlir::Value rhsOperand = adaptor.b(); // Write add loop. mlir::Type ciphertextScalarType = - converter.convertType(lhsOperand.getType()) + converter->convertType(lhsOperand.getType()) .cast() .getElementType(); mlir::Value output = writeUnaryTensorLoop( @@ -401,21 +382,16 @@ struct SubEintOpPattern : CrtOpPattern { : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult - matchAndRewrite(FHE::SubEintOp op, - mlir::PatternRewriter &rewriter) const override { - + matchAndRewrite(FHE::SubEintOp op, FHE::SubEintOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::TypeConverter *converter = this->getTypeConverter(); mlir::Location location = op.getLoc(); - mlir::Value lhsOperand = op.a(); - mlir::Value rhsOperand = op.b(); - - // Convert operand type to glwe tensor. - typing::TypeConverter converter(loweringParameters); - lhsOperand.setType(converter.convertType(lhsOperand.getType())); - rhsOperand.setType(converter.convertType(rhsOperand.getType())); + mlir::Value lhsOperand = adaptor.a(); + mlir::Value rhsOperand = adaptor.b(); // Write sub loop. mlir::Type ciphertextScalarType = - converter.convertType(lhsOperand.getType()) + converter->convertType(lhsOperand.getType()) .cast() .getElementType(); mlir::Value output = writeUnaryTensorLoop( @@ -451,18 +427,14 @@ struct NegEintOpPattern : CrtOpPattern { : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult - matchAndRewrite(FHE::NegEintOp op, - mlir::PatternRewriter &rewriter) const override { - + matchAndRewrite(FHE::NegEintOp op, FHE::NegEintOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::TypeConverter *converter = this->getTypeConverter(); mlir::Location location = op.getLoc(); - mlir::Value operand = op.a(); - - // Convert operand type to glwe tensor. - typing::TypeConverter converter{loweringParameters}; - operand.setType(converter.convertType(operand.getType())); + mlir::Value operand = adaptor.a(); // Write the loop nest. - mlir::Type ciphertextScalarType = converter.convertType(operand.getType()) + mlir::Type ciphertextScalarType = converter->convertType(operand.getType()) .cast() .getElementType(); mlir::Value loopRes = writeUnaryTensorLoop( @@ -494,16 +466,12 @@ struct MulEintIntOpPattern : CrtOpPattern { : CrtOpPattern(context, params, benefit) {} ::mlir::LogicalResult - matchAndRewrite(FHE::MulEintIntOp op, - mlir::PatternRewriter &rewriter) const override { - + matchAndRewrite(FHE::MulEintIntOp op, FHE::MulEintIntOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::TypeConverter *converter = this->getTypeConverter(); mlir::Location location = op.getLoc(); - mlir::Value eintOperand = op.a(); - mlir::Value intOperand = op.b(); - - // Convert operand type to glwe tensor. - typing::TypeConverter converter{loweringParameters}; - eintOperand.setType(converter.convertType(eintOperand.getType())); + mlir::Value eintOperand = adaptor.a(); + mlir::Value intOperand = adaptor.b(); // Write cleartext "encoding" mlir::Value encodedCleartext = rewriter.create( @@ -511,7 +479,7 @@ struct MulEintIntOpPattern : CrtOpPattern { // Write the loop nest. mlir::Type ciphertextScalarType = - converter.convertType(eintOperand.getType()) + converter->convertType(eintOperand.getType()) .cast() .getElementType(); mlir::Value loopRes = writeUnaryTensorLoop( @@ -545,9 +513,9 @@ struct ApplyLookupTableEintOpPattern ::mlir::LogicalResult matchAndRewrite(FHE::ApplyLookupTableEintOp op, - mlir::PatternRewriter &rewriter) const override { - - typing::TypeConverter converter(loweringParameters); + FHE::ApplyLookupTableEintOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::TypeConverter *converter = this->getTypeConverter(); mlir::Value newLut = rewriter @@ -556,7 +524,7 @@ struct ApplyLookupTableEintOpPattern mlir::RankedTensorType::get( mlir::ArrayRef(loweringParameters.lutSize), rewriter.getI64Type()), - op.lut(), + adaptor.lut(), rewriter.getI64ArrayAttr( mlir::ArrayRef(loweringParameters.mods)), rewriter.getI64ArrayAttr( @@ -567,11 +535,9 @@ struct ApplyLookupTableEintOpPattern // Replace the lut with an encoded / expanded one. auto wopPBS = rewriter.create( - op.getLoc(), op.getType(), op.a(), newLut, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, rewriter.getI64ArrayAttr({})); + op.getLoc(), converter->convertType(op.getType()), adaptor.a(), newLut, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, rewriter.getI64ArrayAttr({})); - concretelang::convertOperandAndResultTypes(rewriter, wopPBS, - converter.getConversionLambda()); rewriter.replaceOp(op, {wopPBS.getResult()}); return ::mlir::success(); }; @@ -587,7 +553,9 @@ struct TensorExtractOpPattern : public CrtOpPattern { ::mlir::LogicalResult matchAndRewrite(mlir::tensor::ExtractOp op, - mlir::PatternRewriter &rewriter) const override { + mlir::tensor::ExtractOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::TypeConverter *converter = this->getTypeConverter(); if (!op.getTensor() .getType() @@ -601,7 +569,7 @@ struct TensorExtractOpPattern : public CrtOpPattern { .isa()) { return mlir::success(); } - typing::TypeConverter converter{loweringParameters}; + mlir::SmallVector offsets; mlir::SmallVector sizes; mlir::SmallVector strides; @@ -617,11 +585,10 @@ struct TensorExtractOpPattern : public CrtOpPattern { strides.push_back(rewriter.getI64IntegerAttr(1)); auto newOp = rewriter.create( op.getLoc(), - converter.convertType(op.getResult().getType()) + converter->convertType(op.getResult().getType()) .cast(), - op.getTensor(), offsets, sizes, strides); - concretelang::convertOperandAndResultTypes(rewriter, newOp, - converter.getConversionLambda()); + adaptor.getTensor(), offsets, sizes, strides); + rewriter.replaceOp(op, {newOp.getResult()}); return mlir::success(); } @@ -637,8 +604,8 @@ struct TensorInsertOpPattern : public CrtOpPattern { ::mlir::LogicalResult matchAndRewrite(mlir::tensor::InsertOp op, - mlir::PatternRewriter &rewriter) const override { - + mlir::tensor::InsertOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { if (!op.getDest() .getType() .cast() @@ -651,7 +618,7 @@ struct TensorInsertOpPattern : public CrtOpPattern { .isa()) { return mlir::success(); } - typing::TypeConverter converter{loweringParameters}; + mlir::SmallVector offsets; mlir::SmallVector sizes; mlir::SmallVector strides; @@ -666,9 +633,9 @@ struct TensorInsertOpPattern : public CrtOpPattern { sizes.push_back(rewriter.getI64IntegerAttr(loweringParameters.nMods)); strides.push_back(rewriter.getI64IntegerAttr(1)); auto newOp = rewriter.create( - op.getLoc(), op.getScalar(), op.getDest(), offsets, sizes, strides); - concretelang::convertOperandAndResultTypes(rewriter, newOp, - converter.getConversionLambda()); + op.getLoc(), adaptor.getScalar(), op.getDest(), offsets, sizes, + strides); + rewriter.replaceOp(op, {newOp}); return mlir::success(); } @@ -685,7 +652,10 @@ struct TensorFromElementsOpPattern ::mlir::LogicalResult matchAndRewrite(mlir::tensor::FromElementsOp op, - mlir::PatternRewriter &rewriter) const override { + mlir::tensor::FromElementsOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::TypeConverter *converter = this->getTypeConverter(); + if (!op.getResult() .getType() .cast() @@ -699,13 +669,11 @@ struct TensorFromElementsOpPattern return mlir::success(); } - typing::TypeConverter converter{loweringParameters}; - // Create dest tensor allocation op mlir::Value outputTensor = rewriter.create( op.getLoc(), - converter.convertType(op.getResult().getType()) + converter->convertType(op.getResult().getType()) .cast(), mlir::ValueRange{}); @@ -723,15 +691,13 @@ struct TensorFromElementsOpPattern strides.push_back(rewriter.getI64IntegerAttr(1)); offsets.push_back(rewriter.getI64IntegerAttr(0)); } - for (size_t insertionIndex = 0; insertionIndex < op.getElements().size(); - ++insertionIndex) { + for (size_t insertionIndex = 0; + insertionIndex < adaptor.getElements().size(); ++insertionIndex) { offsets[0] = rewriter.getI64IntegerAttr(insertionIndex); mlir::tensor::InsertSliceOp insertOp = rewriter.create( - op.getLoc(), op.getElements()[insertionIndex], outputTensor, + op.getLoc(), adaptor.getElements()[insertionIndex], outputTensor, offsets, sizes, strides); - concretelang::convertOperandAndResultTypes( - rewriter, insertOp, converter.getConversionLambda()); outputTensor = insertOp.getResult(); } rewriter.replaceOp(op, {outputTensor}); @@ -780,7 +746,11 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase { op, converter); }); target.addLegalOp(); - target.addLegalOp(); + + concretelang::addDynamicallyLegalTypeOp( + target, converter); + concretelang::addDynamicallyLegalTypeOp(target, + converter); concretelang::addDynamicallyLegalTypeOp( target, converter); concretelang::addDynamicallyLegalTypeOp( @@ -812,15 +782,20 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase { //---------------------------------------------------------- Adding patterns mlir::RewritePatternSet patterns(&getContext()); + // Patterns for `bufferization` dialect operations. + patterns.add>(patterns.getContext(), + converter); + // Patterns for the `FHE` dialect operations patterns.add< // |_ `FHE::zero_eint` - concretelang::GenericTypeAndOpConverterPattern, + concretelang::GenericOneToOneOpConversionPattern, // |_ `FHE::zero_tensor` - concretelang::GenericTypeAndOpConverterPattern>( - &getContext(), converter); + concretelang::GenericOneToOneOpConversionPattern< + FHE::ZeroTensorOp, TFHE::ZeroTensorGLWEOp>>(&getContext(), + converter); // |_ `FHE::add_eint_int` patterns.add { // Patterns for the relics of the `FHELinalg` dialect operations. // |_ `linalg::generic` turned to nested `scf::for` - patterns.add>( - patterns.getContext(), converter); - patterns.add>( + patterns.add< + concretelang::TypeConvertingReinstantiationPattern>( patterns.getContext(), converter); patterns.add< - RegionOpTypeConverterPattern>( + concretelang::TypeConvertingReinstantiationPattern>( + patterns.getContext(), converter); + patterns.add< + concretelang::TypeConvertingReinstantiationPattern>( &getContext(), converter); patterns.add(&getContext(), loweringParameters); patterns.add(&getContext(), loweringParameters); - patterns.add>(patterns.getContext(), converter); - patterns.add< - concretelang::GenericTypeConverterPattern>( - patterns.getContext(), converter); - patterns.add>(patterns.getContext(), converter); + patterns.add>(patterns.getContext(), converter); - patterns.add< - concretelang::GenericTypeConverterPattern>( - patterns.getContext(), converter); - patterns.add>( - &getContext(), converter); + patterns.add>(patterns.getContext(), converter); + patterns.add>(&getContext(), converter); // Patterns for `func` dialect operations. mlir::populateFunctionOpInterfaceTypeConversionPattern( patterns, converter); - patterns - .add>( - patterns.getContext(), converter); + patterns.add>(patterns.getContext(), converter); patterns.add>( &getContext(), converter); @@ -880,26 +853,27 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase { loweringParameters); // Patterns for the `RT` dialect operations. - patterns - .add, - concretelang::GenericTypeConverterPattern, - concretelang::GenericTypeConverterPattern< - concretelang::RT::MakeReadyFutureOp>, - concretelang::GenericTypeConverterPattern< - concretelang::RT::AwaitFutureOp>, - concretelang::GenericTypeConverterPattern< - concretelang::RT::CreateAsyncTaskOp>, - concretelang::GenericTypeConverterPattern< - concretelang::RT::BuildReturnPtrPlaceholderOp>, - concretelang::GenericTypeConverterPattern< - concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>, - concretelang::GenericTypeConverterPattern< - concretelang::RT::DerefReturnPtrPlaceholderOp>, - concretelang::GenericTypeConverterPattern< - concretelang::RT::WorkFunctionReturnOp>, - concretelang::GenericTypeConverterPattern< - concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(), - converter); + patterns.add< + // concretelang::TypeConvertingReinstantiationPattern< + // mlir::func::ReturnOp>, + concretelang::TypeConvertingReinstantiationPattern, + concretelang::TypeConvertingReinstantiationPattern< + concretelang::RT::MakeReadyFutureOp>, + concretelang::TypeConvertingReinstantiationPattern< + concretelang::RT::AwaitFutureOp>, + concretelang::TypeConvertingReinstantiationPattern< + concretelang::RT::CreateAsyncTaskOp, true>, + concretelang::TypeConvertingReinstantiationPattern< + concretelang::RT::BuildReturnPtrPlaceholderOp>, + concretelang::TypeConvertingReinstantiationPattern< + concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>, + concretelang::TypeConvertingReinstantiationPattern< + concretelang::RT::DerefReturnPtrPlaceholderOp>, + concretelang::TypeConvertingReinstantiationPattern< + concretelang::RT::WorkFunctionReturnOp>, + concretelang::TypeConvertingReinstantiationPattern< + concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(), + converter); //--------------------------------------------------------- Apply conversion if (mlir::applyPartialConversion(op, target, std::move(patterns)) diff --git a/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp b/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp index 5e7727909..79172bec0 100644 --- a/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp +++ b/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp @@ -16,8 +16,8 @@ #include "concretelang/Conversion/FHEToTFHEScalar/Pass.h" #include "concretelang/Conversion/Passes.h" #include "concretelang/Conversion/Tools.h" +#include "concretelang/Conversion/Utils/Dialects/SCF.h" #include "concretelang/Conversion/Utils/FuncConstOpConversion.h" -#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h" #include "concretelang/Conversion/Utils/TensorOpTypeConversion.h" #include "concretelang/Dialect/FHE/IR/FHEDialect.h" #include "concretelang/Dialect/FHE/IR/FHEOps.h" @@ -110,16 +110,17 @@ namespace lowering { /// A pattern rewriter superclass used by most op rewriters during the /// conversion. template -struct ScalarOpPattern : public mlir::OpRewritePattern { +struct ScalarOpPattern : public mlir::OpConversionPattern { - ScalarOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) - : mlir::OpRewritePattern(context, benefit) {} + ScalarOpPattern(mlir::TypeConverter &converter, mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : mlir::OpConversionPattern(converter, context, benefit) {} /// Writes the encoding of a plaintext of arbitrary precision using shift. mlir::Value writePlaintextShiftEncoding(mlir::Location location, mlir::Value rawPlaintext, int64_t encryptedWidth, - mlir::PatternRewriter &rewriter) const { + mlir::ConversionPatternRewriter &rewriter) const { int64_t intShift = 64 - 1 - encryptedWidth; mlir::Value castedInt = rewriter.create( location, rewriter.getIntegerType(64), rawPlaintext); @@ -131,53 +132,25 @@ struct ScalarOpPattern : public mlir::OpRewritePattern { } }; -/// Rewriter for the `FHE::zero` operation. -struct ZeroEintOpPattern : public mlir::OpRewritePattern { - ZeroEintOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : mlir::OpRewritePattern(context, benefit) {} - - mlir::LogicalResult - matchAndRewrite(FHE::ZeroEintOp op, - mlir::PatternRewriter &rewriter) const override { - mlir::Location location = op.getLoc(); - typing::TypeConverter converter; - TFHE::ZeroGLWEOp newOp = - rewriter.create(location, op.getType()); - concretelang::convertOperandAndResultTypes(rewriter, newOp, - converter.getConversionLambda()); - rewriter.replaceOp(op, {newOp.getResult()}); - return mlir::success(); - } -}; - /// Rewriter for the `FHE::add_eint_int` operation. struct AddEintIntOpPattern : public ScalarOpPattern { - AddEintIntOpPattern(mlir::MLIRContext *context, + AddEintIntOpPattern(mlir::TypeConverter &converter, + mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) - : ScalarOpPattern(context, benefit) {} + : ScalarOpPattern(converter, context, benefit) {} mlir::LogicalResult - matchAndRewrite(FHE::AddEintIntOp op, - mlir::PatternRewriter &rewriter) const override { - mlir::Location location = op.getLoc(); - mlir::Value eintOperand = op.a(); - mlir::Value intOperand = op.b(); - + matchAndRewrite(FHE::AddEintIntOp op, FHE::AddEintIntOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { // Write the plaintext encoding mlir::Value encodedInt = writePlaintextShiftEncoding( - op.getLoc(), intOperand, - eintOperand.getType().cast().getWidth(), - rewriter); + op.getLoc(), adaptor.b(), + op.getType().cast().getWidth(), rewriter); // Write the new op - auto newOp = rewriter.create(location, op.getType(), - eintOperand, encodedInt); - typing::TypeConverter converter; - concretelang::convertOperandAndResultTypes(rewriter, newOp, - converter.getConversionLambda()); - - rewriter.replaceOp(op, {newOp.getResult()}); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), adaptor.a(), + encodedInt); return mlir::success(); } @@ -185,13 +158,14 @@ struct AddEintIntOpPattern : public ScalarOpPattern { /// Rewriter for the `FHE::sub_eint_int` operation. struct SubEintIntOpPattern : public ScalarOpPattern { - SubEintIntOpPattern(mlir::MLIRContext *context, + SubEintIntOpPattern(mlir::TypeConverter &converter, + mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) - : ScalarOpPattern(context, benefit) {} + : ScalarOpPattern(converter, context, benefit) {} mlir::LogicalResult - matchAndRewrite(FHE::SubEintIntOp op, - mlir::PatternRewriter &rewriter) const override { + matchAndRewrite(FHE::SubEintIntOp op, FHE::SubEintIntOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location location = op.getLoc(); mlir::Value eintOperand = op.a(); mlir::Value intOperand = op.b(); @@ -213,15 +187,9 @@ struct SubEintIntOpPattern : public ScalarOpPattern { rewriter); // Write the new op - auto newOp = rewriter.create(location, op.getType(), - eintOperand, encodedInt); - typing::TypeConverter converter; - - // Convert the types - concretelang::convertOperandAndResultTypes(rewriter, newOp, - converter.getConversionLambda()); - - rewriter.replaceOp(op, {newOp.getResult()}); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), adaptor.a(), + encodedInt); return mlir::success(); }; @@ -229,31 +197,24 @@ struct SubEintIntOpPattern : public ScalarOpPattern { /// Rewriter for the `FHE::sub_int_eint` operation. struct SubIntEintOpPattern : public ScalarOpPattern { - SubIntEintOpPattern(mlir::MLIRContext *context, + SubIntEintOpPattern(mlir::TypeConverter &converter, + mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) - : ScalarOpPattern(context, benefit) {} + : ScalarOpPattern(converter, context, benefit) {} mlir::LogicalResult - matchAndRewrite(FHE::SubIntEintOp op, - mlir::PatternRewriter &rewriter) const override { - mlir::Location location = op.getLoc(); - mlir::Value intOperand = op.a(); - mlir::Value eintOperand = op.b(); - + matchAndRewrite(FHE::SubIntEintOp op, FHE::SubIntEintOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { // Write the plaintext encoding mlir::Value encodedInt = writePlaintextShiftEncoding( - op.getLoc(), intOperand, - eintOperand.getType().cast().getWidth(), + op.getLoc(), adaptor.a(), + op.b().getType().cast().getWidth(), rewriter); // Write the new op - auto newOp = rewriter.create(location, op.getType(), - encodedInt, eintOperand); - typing::TypeConverter converter; - concretelang::convertOperandAndResultTypes(rewriter, newOp, - converter.getConversionLambda()); - - rewriter.replaceOp(op, {newOp.getResult()}); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), encodedInt, + adaptor.b()); return mlir::success(); }; @@ -261,60 +222,53 @@ struct SubIntEintOpPattern : public ScalarOpPattern { /// Rewriter for the `FHE::sub_eint` operation. struct SubEintOpPattern : public ScalarOpPattern { - SubEintOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) - : ScalarOpPattern(context, benefit) {} + SubEintOpPattern(mlir::TypeConverter &converter, mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : ScalarOpPattern(converter, context, benefit) {} mlir::LogicalResult - matchAndRewrite(FHE::SubEintOp op, - mlir::PatternRewriter &rewriter) const override { + matchAndRewrite(FHE::SubEintOp op, FHE::SubEintOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location location = op.getLoc(); - mlir::Value lhsOperand = op.a(); - mlir::Value rhsOperand = op.b(); + mlir::Value lhsOperand = adaptor.a(); + mlir::Value rhsOperand = adaptor.b(); // Write rhs negation auto negative = rewriter.create( location, rhsOperand.getType(), rhsOperand); - typing::TypeConverter converter; - concretelang::convertOperandAndResultTypes(rewriter, negative, - converter.getConversionLambda()); // Write new op. - auto newOp = rewriter.create( - location, op.getType(), lhsOperand, negative.getResult()); - concretelang::convertOperandAndResultTypes(rewriter, newOp, - converter.getConversionLambda()); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), lhsOperand, + negative.getResult()); - rewriter.replaceOp(op, {newOp.getResult()}); return mlir::success(); }; }; /// Rewriter for the `FHE::mul_eint_int` operation. struct MulEintIntOpPattern : public ScalarOpPattern { - MulEintIntOpPattern(mlir::MLIRContext *context, + MulEintIntOpPattern(mlir::TypeConverter &converter, + mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) - : ScalarOpPattern(context, benefit) {} + : ScalarOpPattern(converter, context, benefit) {} mlir::LogicalResult - matchAndRewrite(FHE::MulEintIntOp op, - mlir::PatternRewriter &rewriter) const override { + matchAndRewrite(FHE::MulEintIntOp op, FHE::MulEintIntOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location location = op.getLoc(); - mlir::Value eintOperand = op.a(); - mlir::Value intOperand = op.b(); + mlir::Value eintOperand = adaptor.a(); + mlir::Value intOperand = adaptor.b(); // Write the cleartext "encoding" mlir::Value castedCleartext = rewriter.create( location, rewriter.getIntegerType(64), intOperand); // Write the new op. - auto newOp = rewriter.create( - location, op.getType(), eintOperand, castedCleartext); - typing::TypeConverter converter; - concretelang::convertOperandAndResultTypes(rewriter, newOp, - converter.getConversionLambda()); - - rewriter.replaceOp(op, {newOp.getResult()}); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), eintOperand, + castedCleartext); return mlir::success(); } @@ -324,15 +278,17 @@ struct MulEintIntOpPattern : public ScalarOpPattern { struct ApplyLookupTableEintOpPattern : public ScalarOpPattern { ApplyLookupTableEintOpPattern( - mlir::MLIRContext *context, + mlir::TypeConverter &converter, mlir::MLIRContext *context, concretelang::ScalarLoweringParameters loweringParams, mlir::PatternBenefit benefit = 1) - : ScalarOpPattern(context, benefit), + : ScalarOpPattern(converter, context, + benefit), loweringParameters(loweringParams) {} mlir::LogicalResult matchAndRewrite(FHE::ApplyLookupTableEintOp op, - mlir::PatternRewriter &rewriter) const override { + FHE::ApplyLookupTableEintOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { size_t outputBits = op.getResult().getType().cast().getWidth(); @@ -350,16 +306,12 @@ struct ApplyLookupTableEintOpPattern // Insert keyswitch auto ksOp = rewriter.create( - op.getLoc(), op.a().getType(), op.a(), -1, -1); - typing::TypeConverter converter; - concretelang::convertOperandAndResultTypes(rewriter, ksOp, - converter.getConversionLambda()); + op.getLoc(), adaptor.a().getType(), adaptor.a(), -1, -1); // Insert bootstrap - auto bsOp = rewriter.replaceOpWithNewOp( - op, op.getType(), ksOp, newLut, -1, -1, -1, -1); - concretelang::convertOperandAndResultTypes(rewriter, bsOp, - converter.getConversionLambda()); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), ksOp, newLut, -1, -1, + -1, -1); return mlir::success(); }; @@ -473,20 +425,20 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { // Patterns for the `FHE` dialect operations patterns.add< // |_ `FHE::zero_eint` - concretelang::GenericTypeAndOpConverterPattern, + concretelang::GenericOneToOneOpConversionPattern, // |_ `FHE::zero_tensor` - concretelang::GenericTypeAndOpConverterPattern, + concretelang::GenericOneToOneOpConversionPattern< + FHE::ZeroTensorOp, TFHE::ZeroTensorGLWEOp>, // |_ `FHE::neg_eint` - concretelang::GenericTypeAndOpConverterPattern, + concretelang::GenericOneToOneOpConversionPattern, // |_ `FHE::not` - concretelang::GenericTypeAndOpConverterPattern, + concretelang::GenericOneToOneOpConversionPattern, // |_ `FHE::add_eint` - concretelang::GenericTypeAndOpConverterPattern>( + concretelang::GenericOneToOneOpConversionPattern>( &getContext(), converter); // |_ `FHE::add_eint_int` patterns.add { // |_ `FHE::sub_eint` lowering::SubEintOpPattern, // |_ `FHE::mul_eint_int` - lowering::MulEintIntOpPattern>(&getContext()); + lowering::MulEintIntOpPattern>(converter, &getContext()); // |_ `FHE::apply_lookup_table` - patterns.add(&getContext(), - loweringParameters); + patterns.add( + converter, &getContext(), loweringParameters); // Patterns for boolean conversion ops patterns.add( @@ -508,14 +460,12 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { // Patterns for the relics of the `FHELinalg` dialect operations. // |_ `linalg::generic` turned to nested `scf::for` - patterns - .add>( - patterns.getContext(), converter); - patterns.add>( - &getContext(), converter); + patterns.add>(patterns.getContext(), converter); patterns.add< - RegionOpTypeConverterPattern>( + concretelang::TypeConvertingReinstantiationPattern< + mlir::tensor::GenerateOp, true>, + concretelang::TypeConvertingReinstantiationPattern>( &getContext(), converter); concretelang::populateWithTensorTypeConverterPatterns(patterns, target, converter); @@ -523,33 +473,46 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { // Patterns for `func` dialect operations. mlir::populateFunctionOpInterfaceTypeConversionPattern( patterns, converter); - patterns - .add>( - patterns.getContext(), converter); + patterns.add>(patterns.getContext(), converter); + + concretelang::addDynamicallyLegalTypeOp(target, + converter); + concretelang::addDynamicallyLegalTypeOp(target, + converter); + concretelang::addDynamicallyLegalTypeOp(target, + converter); + patterns.add>( &getContext(), converter); + // Patterns for `bufferization` dialect operations. + patterns.add>(patterns.getContext(), + converter); + concretelang::addDynamicallyLegalTypeOp( + target, converter); + // Patterns for the `RT` dialect operations. - patterns - .add, - concretelang::GenericTypeConverterPattern, - concretelang::GenericTypeConverterPattern< - concretelang::RT::MakeReadyFutureOp>, - concretelang::GenericTypeConverterPattern< - concretelang::RT::AwaitFutureOp>, - concretelang::GenericTypeConverterPattern< - concretelang::RT::CreateAsyncTaskOp>, - concretelang::GenericTypeConverterPattern< - concretelang::RT::BuildReturnPtrPlaceholderOp>, - concretelang::GenericTypeConverterPattern< - concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>, - concretelang::GenericTypeConverterPattern< - concretelang::RT::DerefReturnPtrPlaceholderOp>, - concretelang::GenericTypeConverterPattern< - concretelang::RT::WorkFunctionReturnOp>, - concretelang::GenericTypeConverterPattern< - concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(), - converter); + patterns.add< + concretelang::TypeConvertingReinstantiationPattern, + concretelang::TypeConvertingReinstantiationPattern< + concretelang::RT::MakeReadyFutureOp>, + concretelang::TypeConvertingReinstantiationPattern< + concretelang::RT::AwaitFutureOp>, + concretelang::TypeConvertingReinstantiationPattern< + concretelang::RT::CreateAsyncTaskOp, true>, + concretelang::TypeConvertingReinstantiationPattern< + concretelang::RT::BuildReturnPtrPlaceholderOp>, + concretelang::TypeConvertingReinstantiationPattern< + concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>, + concretelang::TypeConvertingReinstantiationPattern< + concretelang::RT::DerefReturnPtrPlaceholderOp>, + concretelang::TypeConvertingReinstantiationPattern< + concretelang::RT::WorkFunctionReturnOp>, + concretelang::TypeConvertingReinstantiationPattern< + concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(), + converter); //--------------------------------------------------------- Apply conversion if (mlir::applyPartialConversion(op, target, std::move(patterns)) diff --git a/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp b/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp index af869ffe6..2e875d5c7 100644 --- a/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp +++ b/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp @@ -16,6 +16,7 @@ #include "concretelang/Dialect/TFHE/IR/TFHEOps.h" #include "concretelang/Dialect/TFHE/IR/TFHETypes.h" #include "concretelang/Support/Constants.h" +#include namespace TFHE = mlir::concretelang::TFHE; @@ -300,6 +301,12 @@ void TFHEGlobalParametrizationPass::runOnOperation() { // Add all patterns to convert TFHE types populateWithTFHEOpTypeConversionPatterns(patterns, target, converter); + + patterns.add>(&getContext(), converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::bufferization::AllocTensorOp>(target, converter); + patterns.add>( &getContext(), converter); diff --git a/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp b/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp index e699feaca..e51d202ed 100644 --- a/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp +++ b/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp @@ -4,6 +4,7 @@ // for license information. #include +#include #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -231,6 +232,8 @@ void TFHEToConcretePass::runOnOperation() { patterns.add< mlir::concretelang::GenericTypeConverterPattern, mlir::concretelang::GenericTypeConverterPattern, + mlir::concretelang::GenericTypeConverterPattern< + mlir::bufferization::AllocTensorOp>, mlir::concretelang::GenericTypeConverterPattern< mlir::concretelang::RT::MakeReadyFutureOp>, mlir::concretelang::GenericTypeConverterPattern< @@ -270,6 +273,8 @@ void TFHEToConcretePass::runOnOperation() { target, converter); mlir::concretelang::addDynamicallyLegalTypeOp( target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::bufferization::AllocTensorOp>(target, converter); // Apply conversion if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) { diff --git a/compiler/lib/Conversion/Utils/CMakeLists.txt b/compiler/lib/Conversion/Utils/CMakeLists.txt new file mode 100644 index 000000000..7315b804d --- /dev/null +++ b/compiler/lib/Conversion/Utils/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Dialects) diff --git a/compiler/lib/Conversion/Utils/Dialects/CMakeLists.txt b/compiler/lib/Conversion/Utils/Dialects/CMakeLists.txt new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/compiler/lib/Conversion/Utils/Dialects/CMakeLists.txt @@ -0,0 +1 @@ + diff --git a/compiler/lib/Conversion/Utils/Dialects/SCF.cpp b/compiler/lib/Conversion/Utils/Dialects/SCF.cpp new file mode 100644 index 000000000..7886874a6 --- /dev/null +++ b/compiler/lib/Conversion/Utils/Dialects/SCF.cpp @@ -0,0 +1,40 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "concretelang/Conversion/Utils/Dialects/SCF.h" +#include "mlir/Transforms/RegionUtils.h" +#include + +namespace mlir { +namespace concretelang { +template <> +mlir::LogicalResult +TypeConvertingReinstantiationPattern::matchAndRewrite( + scf::ForOp oldOp, mlir::OpConversionPattern::OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // Create new for loop with empty body, but converted iter args + scf::ForOp newForOp = rewriter.replaceOpWithNewOp( + oldOp, adaptor.getLowerBound(), adaptor.getUpperBound(), + adaptor.getStep(), adaptor.getInitArgs(), + [&](OpBuilder &builder, Location loc, Value iv, ValueRange args) {}); + + // Move operations from old for op to new one + auto &newOperations = newForOp.getBody()->getOperations(); + mlir::Block *oldBody = oldOp.getBody(); + + newOperations.splice(newOperations.begin(), oldBody->getOperations(), + oldBody->begin(), oldBody->end()); + + // Remap iter args and IV + for (auto argsPair : llvm::zip(oldOp.getBody()->getArguments(), + newForOp.getBody()->getArguments())) { + replaceAllUsesInRegionWith(std::get<0>(argsPair), std::get<1>(argsPair), + newForOp.getRegion()); + } + + return mlir::success(); +} +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Conversion/Utils/Dialects/Tensor.cpp b/compiler/lib/Conversion/Utils/Dialects/Tensor.cpp new file mode 100644 index 000000000..6c4c959a6 --- /dev/null +++ b/compiler/lib/Conversion/Utils/Dialects/Tensor.cpp @@ -0,0 +1,107 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "concretelang/Conversion/Utils/Dialects/Tensor.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/STLExtras.h" + +namespace mlir { +namespace concretelang { + +// +// Specializations for CollapseShapeOp +// + +// Specialization copying attributes not necessary, as the base +// template works correctly +template <> +mlir::LogicalResult +TypeConvertingReinstantiationPattern:: + matchAndRewrite( + tensor::CollapseShapeOp oldOp, + mlir::OpConversionPattern::OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + mlir::SmallVector resultTypes = convertResultTypes(oldOp); + rewriter.replaceOpWithNewOp( + oldOp, mlir::TypeRange{resultTypes}, adaptor.getSrc(), + oldOp.getReassociation()); + + return mlir::success(); +} + +// +// Specializations for FromElementsOp +// +template <> +mlir::LogicalResult +TypeConvertingReinstantiationPattern:: + matchAndRewrite( + tensor::FromElementsOp oldOp, + mlir::OpConversionPattern::OpAdaptor + adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + mlir::Type resultType = convertResultType(oldOp); + rewriter.replaceOpWithNewOp( + oldOp, resultType, adaptor.getElements()); + + return mlir::success(); +} + +// +// Specializations for ExpandShapeOp +// + +// Specialization copying attributes not necessary, as the base +// template works correctly + +template <> +mlir::LogicalResult +TypeConvertingReinstantiationPattern:: + matchAndRewrite( + tensor::ExpandShapeOp oldOp, + mlir::OpConversionPattern::OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + mlir::SmallVector resultTypes = convertResultTypes(oldOp); + rewriter.replaceOpWithNewOp( + oldOp, mlir::TypeRange{resultTypes}, adaptor.getSrc(), + oldOp.getReassociation()); + + return mlir::success(); +} + +template <> +mlir::LogicalResult +TypeConvertingReinstantiationPattern::matchAndRewrite( + tensor::GenerateOp oldOp, + mlir::OpConversionPattern::OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + mlir::SmallVector resultTypes = convertResultTypes(oldOp); + + rewriter.setInsertionPointAfter(oldOp); + tensor::GenerateOp newGenerateOp = rewriter.create( + oldOp.getLoc(), resultTypes, adaptor.getOperands(), oldOp->getAttrs()); + + mlir::Block &oldBlock = oldOp.getBody().getBlocks().front(); + mlir::Block &newBlock = newGenerateOp.getBody().getBlocks().front(); + auto begin = oldBlock.begin(); + auto nOps = oldBlock.getOperations().size(); + + newBlock.getOperations().splice(newBlock.getOperations().begin(), + oldBlock.getOperations(), begin, + std::next(begin, nOps - 1)); + + for (auto argsPair : llvm::zip(oldOp.getRegion().getArguments(), + newGenerateOp.getRegion().getArguments())) { + replaceAllUsesInRegionWith(std::get<0>(argsPair), std::get<1>(argsPair), + newGenerateOp.getRegion()); + } + + rewriter.replaceOp(oldOp, newGenerateOp.getResult()); + + return mlir::success(); +} + +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp b/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp index dfab45cf4..b2de10f3a 100644 --- a/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp +++ b/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp @@ -15,6 +15,7 @@ #include "mlir/Transforms/DialectConversion.h" #include #include +#include #include #include #include diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/add_eint.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/add_eint.mlir index 05ead4426..c56585f24 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/add_eint.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/add_eint.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: func.func @add_eint(%arg0: !TFHE.glwe<{_,_,_}{7}>, %arg1: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> func.func @add_eint(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> { - // CHECK-NEXT: %[[V1:.*]] = "TFHE.add_glwe"(%arg0, %arg1) {MANP = 2 : ui3} : (!TFHE.glwe<{_,_,_}{7}>, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> + // CHECK-NEXT: %[[V1:.*]] = "TFHE.add_glwe"(%arg0, %arg1) : (!TFHE.glwe<{_,_,_}{7}>, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> // CHECK-NEXT: return %[[V1]] : !TFHE.glwe<{_,_,_}{7}> %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>) diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/neg_eint.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/neg_eint.mlir index d68f4d565..c162c2df4 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/neg_eint.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/neg_eint.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: func.func @neg_eint(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { - // CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) {MANP = 1 : ui1} : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> + // CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> // CHECK-NEXT: return %0 : !TFHE.glwe<{_,_,_}{7}> %1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>) @@ -11,7 +11,7 @@ func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { // CHECK-LABEL: func.func @not(%arg0: !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}> func.func @not(%arg0: !FHE.ebool) -> !FHE.ebool { - // CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) {MANP = 1 : ui1} : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}> + // CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}> // CHECK-NEXT: return %0 : !TFHE.glwe<{_,_,_}{2}> %1 = "FHE.not"(%arg0) : (!FHE.ebool) -> !FHE.ebool return %1: !FHE.ebool