mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
refactor(compiler): FHE to TFHE: Use OpConversionPattern for dialect conversion
Use `OpConversionPattern` instead of `OpRewritePattern` for operation conversion during dialect conversion. This makes explicit and in-place type conversions unnecessary, since `OpConversionPattern` already properly converts operand types and provides them to the rewrite rule through an operation adaptor. The main contributions of this commit are the two class templates `TypeConvertingReinstantiationPattern` and `GenericOneToOneOpConversionPattern`. The former allows for the definition of a simple replacement rule that re-instantiates an operation after the types of its operands have been converted. This is especially useful for type-polymorphic operations during dialect conversion. The latter allows for the definition of patterns, where one operation needs to be replaced with a different operation after conversion of its operands. The default implementations for the class templates provide conversions rules for operations that have a generic builder method that takes the desired return type(s), the operands and (optionally) a set of attributes. How attributes are discarded during a conversion (either by omitting the builder argument or by passing an empty set of attributes) can be defined through specialization of `ReinstantiationAttributeDismissalStrategy`. Custom replacement rules that deviate from the scheme above should be implemented by specializing `TypeConvertingReinstantiationPattern::matchAndRewrite()` and `GenericOneToOneOpConversionPattern::matchAndRewrite()`.
This commit is contained in:
@@ -0,0 +1,29 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_UTILS_DIALECTS_SCF_H_
|
||||
#define CONCRETELANG_CONVERSION_UTILS_DIALECTS_SCF_H_
|
||||
|
||||
#include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
//
|
||||
// Specializations for ForOp
|
||||
//
|
||||
|
||||
// Specialization copying attributes omitted
|
||||
template <>
|
||||
mlir::LogicalResult
|
||||
TypeConvertingReinstantiationPattern<scf::ForOp, false>::matchAndRewrite(
|
||||
scf::ForOp oldOp, mlir::OpConversionPattern<scf::ForOp>::OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,70 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_UTILS_DIALECTS_TENSOR_H_
|
||||
#define CONCRETELANG_CONVERSION_UTILS_DIALECTS_TENSOR_H_
|
||||
|
||||
#include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
//
|
||||
// Specializations for CollapseShapeOp
|
||||
//
|
||||
|
||||
// Specialization copying attributes not necessary, as the base
|
||||
// template works correctly
|
||||
template <>
|
||||
mlir::LogicalResult
|
||||
TypeConvertingReinstantiationPattern<tensor::CollapseShapeOp, false>::
|
||||
matchAndRewrite(
|
||||
tensor::CollapseShapeOp oldOp,
|
||||
mlir::OpConversionPattern<tensor::CollapseShapeOp>::OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const;
|
||||
//
|
||||
// Specializations for FromElementsOp
|
||||
//
|
||||
template <>
|
||||
mlir::LogicalResult
|
||||
TypeConvertingReinstantiationPattern<mlir::tensor::FromElementsOp, false>::
|
||||
matchAndRewrite(
|
||||
tensor::FromElementsOp oldOp,
|
||||
mlir::OpConversionPattern<mlir::tensor::FromElementsOp>::OpAdaptor
|
||||
adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
//
|
||||
// Specializations for ExpandShapeOp
|
||||
//
|
||||
|
||||
// Specialization copying attributes not necessary, as the base
|
||||
// template works correctly
|
||||
|
||||
template <>
|
||||
mlir::LogicalResult
|
||||
TypeConvertingReinstantiationPattern<tensor::ExpandShapeOp, false>::
|
||||
matchAndRewrite(
|
||||
tensor::ExpandShapeOp oldOp,
|
||||
mlir::OpConversionPattern<tensor::ExpandShapeOp>::OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
//
|
||||
// Specializations for GenerateOp
|
||||
//
|
||||
|
||||
// Specialization NOT copying attributes omitted
|
||||
template <>
|
||||
mlir::LogicalResult
|
||||
TypeConvertingReinstantiationPattern<tensor::GenerateOp, true>::matchAndRewrite(
|
||||
tensor::GenerateOp oldOp,
|
||||
mlir::OpConversionPattern<tensor::GenerateOp>::OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif // CONCRETELANG_CONVERSION_UTILS_DIALECTS_TENSOR_H_
|
||||
@@ -96,16 +96,6 @@ struct GenericTypeAndOpConverterPattern : public mlir::OpRewritePattern<OldOp> {
|
||||
private:
|
||||
mlir::TypeConverter &converter;
|
||||
};
|
||||
|
||||
template <typename Op>
|
||||
void addDynamicallyLegalTypeOp(mlir::ConversionTarget &target,
|
||||
mlir::TypeConverter &typeConverter) {
|
||||
target.addDynamicallyLegalOp<Op>([&](Op op) {
|
||||
return typeConverter.isLegal(op->getOperandTypes()) &&
|
||||
typeConverter.isLegal(op->getResultTypes());
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
26
compiler/include/concretelang/Conversion/Utils/Legality.h
Normal file
26
compiler/include/concretelang/Conversion/Utils/Legality.h
Normal file
@@ -0,0 +1,26 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_UTILS_LEGALITY_H_
|
||||
#define CONCRETELANG_CONVERSION_UTILS_LEGALITY_H_
|
||||
|
||||
#include <mlir/Transforms/DialectConversion.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
template <typename Op>
|
||||
void addDynamicallyLegalTypeOp(mlir::ConversionTarget &target,
|
||||
mlir::TypeConverter &typeConverter) {
|
||||
target.addDynamicallyLegalOp<Op>([&](Op op) {
|
||||
return typeConverter.isLegal(op->getOperandTypes()) &&
|
||||
typeConverter.isLegal(op->getResultTypes());
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,216 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_UTILS_REINSTANTIATINGOPTYPECONVERSION_H_
|
||||
#define CONCRETELANG_CONVERSION_UTILS_REINSTANTIATINGOPTYPECONVERSION_H_
|
||||
|
||||
#include <mlir/Transforms/DialectConversion.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
// Set of types defining how attributes should be handled when
|
||||
// invocating the build method of an operation upon reinstantiation
|
||||
struct ReinstantiationAttributeHandling {
|
||||
// Copy attributes
|
||||
struct copy {};
|
||||
|
||||
// Completely dismiss attributes by not passing a set of arguments
|
||||
// to the builder at all
|
||||
struct dismiss {};
|
||||
|
||||
// Dismiss attributes by passing an empty set of arguments to the
|
||||
// builder
|
||||
struct pass_empty_vector {};
|
||||
};
|
||||
|
||||
// Template defining how attributes should be dismissed when invoking
|
||||
// the build method of an operation upon reinstantiation. In the
|
||||
// default case, the argument for attributes is simply dismissed.
|
||||
template <typename T> struct ReinstantiationAttributeDismissalStrategy {
|
||||
typedef ReinstantiationAttributeHandling::dismiss strategy;
|
||||
};
|
||||
|
||||
// Template defining how attributes should be copied when invoking the
|
||||
// build method of an operation upon reinstantiation. In the default
|
||||
// case, the argument for attributes is forwarded to the build method.
|
||||
template <typename T> struct ReinstantiationAttributeCopyStrategy {
|
||||
typedef ReinstantiationAttributeHandling::copy strategy;
|
||||
};
|
||||
|
||||
namespace {
|
||||
// Class template that defines the attribute handling strategy for
|
||||
// either dismissal of attributes (if `copyAttrsSwitch` is `false`) or copying
|
||||
// attributes (if `copyAttrsSwitch` is `true`).
|
||||
template <typename T, bool copyAttrsSwitch> struct AttributeHandlingSwitch {};
|
||||
|
||||
template <typename T> struct AttributeHandlingSwitch<T, true> {
|
||||
typedef typename ReinstantiationAttributeCopyStrategy<T>::strategy strategy;
|
||||
};
|
||||
|
||||
template <typename T> struct AttributeHandlingSwitch<T, false> {
|
||||
typedef
|
||||
typename ReinstantiationAttributeDismissalStrategy<T>::strategy strategy;
|
||||
};
|
||||
|
||||
// Simple functor-like template invoking a rewriter with a variable
|
||||
// set of arguments and an op's attributes as the last argument.
|
||||
template <typename NewOpTy, typename... Args>
|
||||
struct ReplaceOpWithNewOpCopyAttrs {
|
||||
static NewOpTy replace(mlir::ConversionPatternRewriter &rewriter,
|
||||
mlir::Operation *op, mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange operands) {
|
||||
return rewriter.replaceOpWithNewOp<NewOpTy>(op, resultTypes, operands,
|
||||
op->getAttrs());
|
||||
}
|
||||
};
|
||||
|
||||
// Simple functor-like template invoking a rewriter with a variable
|
||||
// set of arguments dismissing the attributes passed as the last
|
||||
// argument.
|
||||
template <typename NewOpTy, typename... Args>
|
||||
struct ReplaceOpWithNewOpDismissAttrs {
|
||||
static NewOpTy replace(mlir::ConversionPatternRewriter &rewriter,
|
||||
mlir::Operation *op, mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange operands) {
|
||||
return rewriter.replaceOpWithNewOp<NewOpTy>(op, resultTypes, operands);
|
||||
}
|
||||
};
|
||||
|
||||
// Simple functor-like template invoking a rewriter with a variable
|
||||
// set of arguments dismissing the attributes by passing an empty
|
||||
// set of arguments to the builder.
|
||||
template <typename NewOpTy, typename... Args>
|
||||
struct ReplaceOpWithNewOpEmptyAttrs {
|
||||
static NewOpTy replace(mlir::ConversionPatternRewriter &rewriter,
|
||||
mlir::Operation *op, mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange operands) {
|
||||
llvm::SmallVector<mlir::NamedAttribute> attrs{};
|
||||
return rewriter.replaceOpWithNewOp<NewOpTy>(op, resultTypes, operands,
|
||||
attrs);
|
||||
}
|
||||
};
|
||||
|
||||
// Functor-like template that either forwards to
|
||||
// `ReplaceOpWithNewOpCopyAttrs` or `ReplaceOpWithNewOpDismissAttrs`
|
||||
// depending on the value of `copyAttrs`.
|
||||
template <typename copyAttrsSwitch, typename OpTy, typename... Args>
|
||||
struct ReplaceOpWithNewOpAttrSwitch {};
|
||||
|
||||
// Specialization of `ReplaceOpWithNewOpAttrSwitch` that does copy
|
||||
// attributes.
|
||||
template <typename OpTy, typename... Args>
|
||||
struct ReplaceOpWithNewOpAttrSwitch<ReinstantiationAttributeHandling::copy,
|
||||
OpTy, Args...> {
|
||||
typedef ReplaceOpWithNewOpCopyAttrs<OpTy, Args...> instantiator;
|
||||
};
|
||||
|
||||
// Specialization of `ReplaceOpWithNewOpAttrSwitch` that does NOT copy
|
||||
// attributes by not passing attributes to the builder at all.
|
||||
template <typename OpTy, typename... Args>
|
||||
struct ReplaceOpWithNewOpAttrSwitch<ReinstantiationAttributeHandling::dismiss,
|
||||
OpTy, Args...> {
|
||||
typedef ReplaceOpWithNewOpDismissAttrs<OpTy, Args...> instantiator;
|
||||
};
|
||||
|
||||
// Specialization of `ReplaceOpWithNewOpAttrSwitch` that does NOT copy
|
||||
// attributes by passing an empty set of attributes to the builder.
|
||||
template <typename OpTy, typename... Args>
|
||||
struct ReplaceOpWithNewOpAttrSwitch<
|
||||
ReinstantiationAttributeHandling::pass_empty_vector, OpTy, Args...> {
|
||||
typedef ReplaceOpWithNewOpEmptyAttrs<OpTy, Args...> instantiator;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename OldOp, typename NewOp, bool copyAttrs = false>
|
||||
struct GenericOneToOneOpConversionPatternBase
|
||||
: public mlir::OpConversionPattern<OldOp> {
|
||||
GenericOneToOneOpConversionPatternBase(mlir::MLIRContext *context,
|
||||
mlir::TypeConverter &converter,
|
||||
mlir::PatternBenefit benefit = 100)
|
||||
: mlir::OpConversionPattern<OldOp>(converter, context, benefit) {}
|
||||
|
||||
mlir::SmallVector<mlir::Type> convertResultTypes(OldOp oldOp) const {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
|
||||
// Convert result types
|
||||
mlir::SmallVector<mlir::Type> resultTypes(oldOp->getNumResults());
|
||||
|
||||
for (unsigned i = 0; i < oldOp->getNumResults(); i++) {
|
||||
auto result = oldOp->getResult(i);
|
||||
resultTypes[i] = converter->convertType(result.getType());
|
||||
}
|
||||
|
||||
return resultTypes;
|
||||
}
|
||||
|
||||
mlir::Type convertResultType(OldOp oldOp) const {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
return converter->convertType(oldOp->getResult(0).getType());
|
||||
}
|
||||
};
|
||||
|
||||
// Conversion pattern that replaces an instance of an operation of the type
|
||||
// `OldOp` with an instance of the type `NewOp`, taking into account operands,
|
||||
// return types and possible copying attributes (iff copyAttrs is `true`).
|
||||
template <typename OldOp, typename NewOp, bool copyAttrs = false>
|
||||
struct GenericOneToOneOpConversionPattern
|
||||
: public GenericOneToOneOpConversionPatternBase<OldOp, NewOp, copyAttrs> {
|
||||
GenericOneToOneOpConversionPattern(mlir::MLIRContext *context,
|
||||
mlir::TypeConverter &converter,
|
||||
mlir::PatternBenefit benefit = 100)
|
||||
: GenericOneToOneOpConversionPatternBase<OldOp, NewOp, copyAttrs>(
|
||||
context, converter, benefit) {}
|
||||
|
||||
virtual mlir::LogicalResult
|
||||
matchAndRewrite(OldOp oldOp,
|
||||
typename mlir::OpConversionPattern<OldOp>::OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::SmallVector<mlir::Type> resultTypes = this->convertResultTypes(oldOp);
|
||||
|
||||
ReplaceOpWithNewOpAttrSwitch<
|
||||
typename AttributeHandlingSwitch<NewOp, copyAttrs>::strategy,
|
||||
NewOp>::instantiator::replace(rewriter, oldOp,
|
||||
mlir::TypeRange{resultTypes},
|
||||
mlir::ValueRange{adaptor.getOperands()});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
// Conversion pattern that retrieves the converted operands of an
|
||||
// operation of the type `Op`, converts the types of the results of
|
||||
// the operation and re-instantiates the operation type with the
|
||||
// converted operands and result types.
|
||||
template <typename Op, bool copyAttrs = false>
|
||||
struct TypeConvertingReinstantiationPattern
|
||||
: public GenericOneToOneOpConversionPatternBase<Op, Op, copyAttrs> {
|
||||
TypeConvertingReinstantiationPattern(mlir::MLIRContext *context,
|
||||
mlir::TypeConverter &converter,
|
||||
mlir::PatternBenefit benefit = 100)
|
||||
: GenericOneToOneOpConversionPatternBase<Op, Op, copyAttrs>(
|
||||
context, converter, benefit) {}
|
||||
// Simple forward that makes the method specializable out of class
|
||||
// directly for this class rather than for its base
|
||||
virtual mlir::LogicalResult
|
||||
matchAndRewrite(Op op,
|
||||
typename mlir::OpConversionPattern<Op>::OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::SmallVector<mlir::Type> resultTypes = this->convertResultTypes(op);
|
||||
|
||||
ReplaceOpWithNewOpAttrSwitch<
|
||||
typename AttributeHandlingSwitch<Op, copyAttrs>::strategy,
|
||||
Op>::instantiator::replace(rewriter, op, mlir::TypeRange{resultTypes},
|
||||
mlir::ValueRange{adaptor.getOperands()});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -10,7 +10,8 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h"
|
||||
#include "concretelang/Conversion/Utils/Dialects/Tensor.h"
|
||||
#include "concretelang/Conversion/Utils/Legality.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
@@ -20,37 +21,43 @@ populateWithTensorTypeConverterPatterns(mlir::RewritePatternSet &patterns,
|
||||
mlir::ConversionTarget &target,
|
||||
mlir::TypeConverter &typeConverter) {
|
||||
// ExtractOp
|
||||
patterns.add<GenericTypeConverterPattern<mlir::tensor::ExtractOp>>(
|
||||
patterns.add<TypeConvertingReinstantiationPattern<mlir::tensor::ExtractOp>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
addDynamicallyLegalTypeOp<mlir::tensor::ExtractOp>(target, typeConverter);
|
||||
|
||||
// ExtractSliceOp
|
||||
patterns.add<GenericTypeConverterPattern<mlir::tensor::ExtractSliceOp>>(
|
||||
patterns.add<
|
||||
TypeConvertingReinstantiationPattern<mlir::tensor::ExtractSliceOp, true>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
addDynamicallyLegalTypeOp<mlir::tensor::ExtractSliceOp>(target,
|
||||
typeConverter);
|
||||
|
||||
// InsertOp
|
||||
patterns.add<GenericTypeConverterPattern<mlir::tensor::InsertOp>>(
|
||||
patterns.add<TypeConvertingReinstantiationPattern<mlir::tensor::InsertOp>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
addDynamicallyLegalTypeOp<mlir::tensor::InsertOp>(target, typeConverter);
|
||||
// InsertSliceOp
|
||||
patterns.add<GenericTypeConverterPattern<mlir::tensor::InsertSliceOp>>(
|
||||
patterns.add<
|
||||
TypeConvertingReinstantiationPattern<mlir::tensor::InsertSliceOp, true>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
addDynamicallyLegalTypeOp<mlir::tensor::InsertSliceOp>(target, typeConverter);
|
||||
|
||||
// FromElementsOp
|
||||
patterns.add<GenericTypeConverterPattern<mlir::tensor::FromElementsOp>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
patterns
|
||||
.add<TypeConvertingReinstantiationPattern<mlir::tensor::FromElementsOp>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
addDynamicallyLegalTypeOp<mlir::tensor::FromElementsOp>(target,
|
||||
typeConverter);
|
||||
// TensorCollapseShapeOp
|
||||
patterns.add<GenericTypeConverterPattern<mlir::tensor::CollapseShapeOp>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
patterns
|
||||
.add<TypeConvertingReinstantiationPattern<mlir::tensor::CollapseShapeOp>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
addDynamicallyLegalTypeOp<mlir::tensor::CollapseShapeOp>(target,
|
||||
typeConverter);
|
||||
// TensorExpandShapeOp
|
||||
patterns.add<GenericTypeConverterPattern<mlir::tensor::ExpandShapeOp>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
patterns
|
||||
.add<TypeConvertingReinstantiationPattern<mlir::tensor::ExpandShapeOp>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
addDynamicallyLegalTypeOp<mlir::tensor::ExpandShapeOp>(target, typeConverter);
|
||||
}
|
||||
} // namespace concretelang
|
||||
|
||||
@@ -9,5 +9,13 @@ add_subdirectory(SDFGToStreamEmulator)
|
||||
add_subdirectory(MLIRLowerableDialectsToLLVM)
|
||||
add_subdirectory(LinalgExtras)
|
||||
add_subdirectory(ExtractSDFGOps)
|
||||
add_subdirectory(Utils)
|
||||
|
||||
add_mlir_library(ConcretelangConversion Tools.cpp LINK_LIBS PUBLIC MLIRIR)
|
||||
add_mlir_library(
|
||||
ConcretelangConversion
|
||||
Tools.cpp
|
||||
Utils/Dialects/SCF.cpp
|
||||
Utils/Dialects/Tensor.cpp
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
MLIRIR)
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
|
||||
#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h"
|
||||
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
|
||||
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
|
||||
|
||||
@@ -10,13 +10,15 @@
|
||||
#include <mlir/Dialect/Linalg/IR/Linalg.h>
|
||||
#include <mlir/IR/Operation.h>
|
||||
|
||||
#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h"
|
||||
#include "concretelang/Conversion/Utils/ReinstantiatingOpTypeConversion.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "concretelang/Conversion/FHEToTFHECrt/Pass.h"
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Utils/Dialects/SCF.h"
|
||||
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
|
||||
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
|
||||
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
|
||||
@@ -112,7 +114,8 @@ namespace lowering {
|
||||
|
||||
/// A pattern rewriter superclass used by most op rewriters during the
|
||||
/// conversion.
|
||||
template <typename T> struct CrtOpPattern : public mlir::OpRewritePattern<T> {
|
||||
template <typename T>
|
||||
struct CrtOpPattern : public mlir::OpConversionPattern<T> {
|
||||
|
||||
/// The lowering parameters are bound to the op rewriter.
|
||||
concretelang::CrtLoweringParameters loweringParameters;
|
||||
@@ -120,8 +123,8 @@ template <typename T> struct CrtOpPattern : public mlir::OpRewritePattern<T> {
|
||||
CrtOpPattern(mlir::MLIRContext *context,
|
||||
concretelang::CrtLoweringParameters params,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<T>(context, benefit),
|
||||
loweringParameters(params) {}
|
||||
: mlir::OpConversionPattern<T>(typeConverter, context, benefit),
|
||||
loweringParameters(params), typeConverter(params) {}
|
||||
|
||||
/// Writes an `scf::for` that loops over the crt dimension of one tensor and
|
||||
/// execute the input lambda to write the loop body. Returns the first result
|
||||
@@ -154,11 +157,6 @@ template <typename T> struct CrtOpPattern : public mlir::OpRewritePattern<T> {
|
||||
location, zeroConstantOp, crtSizeConstantOp, oneConstantOp, tensor,
|
||||
body);
|
||||
|
||||
// Convert the types of the new operation
|
||||
typing::TypeConverter converter(loweringParameters);
|
||||
concretelang::convertOperandAndResultTypes(rewriter, newOp,
|
||||
converter.getConversionLambda());
|
||||
|
||||
return newOp.getResult(0);
|
||||
}
|
||||
|
||||
@@ -176,6 +174,9 @@ template <typename T> struct CrtOpPattern : public mlir::OpRewritePattern<T> {
|
||||
castedPlaintext, rewriter.getI64ArrayAttr(loweringParameters.mods),
|
||||
rewriter.getI64IntegerAttr(loweringParameters.modsProd));
|
||||
}
|
||||
|
||||
protected:
|
||||
typing::TypeConverter typeConverter;
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::add_eint_int` operation.
|
||||
@@ -187,17 +188,12 @@ struct AddEintIntOpPattern : public CrtOpPattern<FHE::AddEintIntOp> {
|
||||
: CrtOpPattern<FHE::AddEintIntOp>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::AddEintIntOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
matchAndRewrite(FHE::AddEintIntOp op, FHE::AddEintIntOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value eintOperand = op.a();
|
||||
mlir::Value intOperand = op.b();
|
||||
|
||||
// Convert operand type to glwe tensor.
|
||||
typing::TypeConverter converter(loweringParameters);
|
||||
intOperand.setType(converter.convertType(intOperand.getType()));
|
||||
eintOperand.setType(converter.convertType(eintOperand.getType()));
|
||||
mlir::Value eintOperand = adaptor.a();
|
||||
mlir::Value intOperand = adaptor.b();
|
||||
|
||||
// Write plaintext encoding
|
||||
mlir::Value encodedPlaintextTensor =
|
||||
@@ -205,7 +201,7 @@ struct AddEintIntOpPattern : public CrtOpPattern<FHE::AddEintIntOp> {
|
||||
|
||||
// Write add loop.
|
||||
mlir::Type ciphertextScalarType =
|
||||
converter.convertType(eintOperand.getType())
|
||||
converter->convertType(eintOperand.getType())
|
||||
.cast<mlir::RankedTensorType>()
|
||||
.getElementType();
|
||||
mlir::Value output = writeUnaryTensorLoop(
|
||||
@@ -239,17 +235,12 @@ struct SubIntEintOpPattern : public CrtOpPattern<FHE::SubIntEintOp> {
|
||||
: CrtOpPattern<FHE::SubIntEintOp>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::SubIntEintOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
matchAndRewrite(FHE::SubIntEintOp op, FHE::SubIntEintOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value intOperand = op.a();
|
||||
mlir::Value eintOperand = op.b();
|
||||
|
||||
// Convert operand type to glwe tensor.
|
||||
typing::TypeConverter converter(loweringParameters);
|
||||
intOperand.setType(converter.convertType(intOperand.getType()));
|
||||
eintOperand.setType(converter.convertType(eintOperand.getType()));
|
||||
mlir::Value intOperand = adaptor.a();
|
||||
mlir::Value eintOperand = adaptor.b();
|
||||
|
||||
// Write plaintext encoding
|
||||
mlir::Value encodedPlaintextTensor =
|
||||
@@ -257,7 +248,7 @@ struct SubIntEintOpPattern : public CrtOpPattern<FHE::SubIntEintOp> {
|
||||
|
||||
// Write add loop.
|
||||
mlir::Type ciphertextScalarType =
|
||||
converter.convertType(eintOperand.getType())
|
||||
converter->convertType(eintOperand.getType())
|
||||
.cast<mlir::RankedTensorType>()
|
||||
.getElementType();
|
||||
mlir::Value output = writeUnaryTensorLoop(
|
||||
@@ -291,17 +282,12 @@ struct SubEintIntOpPattern : public CrtOpPattern<FHE::SubEintIntOp> {
|
||||
: CrtOpPattern<FHE::SubEintIntOp>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::SubEintIntOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
matchAndRewrite(FHE::SubEintIntOp op, FHE::SubEintIntOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value eintOperand = op.a();
|
||||
mlir::Value intOperand = op.b();
|
||||
|
||||
// Convert operand type to glwe tensor.
|
||||
typing::TypeConverter converter(loweringParameters);
|
||||
intOperand.setType(converter.convertType(intOperand.getType()));
|
||||
eintOperand.setType(converter.convertType(eintOperand.getType()));
|
||||
mlir::Value eintOperand = adaptor.a();
|
||||
mlir::Value intOperand = adaptor.b();
|
||||
|
||||
// Write plaintext negation
|
||||
mlir::Type intType = intOperand.getType();
|
||||
@@ -319,7 +305,7 @@ struct SubEintIntOpPattern : public CrtOpPattern<FHE::SubEintIntOp> {
|
||||
|
||||
// Write add loop.
|
||||
mlir::Type ciphertextScalarType =
|
||||
converter.convertType(eintOperand.getType())
|
||||
converter->convertType(eintOperand.getType())
|
||||
.cast<mlir::RankedTensorType>()
|
||||
.getElementType();
|
||||
mlir::Value output = writeUnaryTensorLoop(
|
||||
@@ -353,21 +339,16 @@ struct AddEintOpPattern : CrtOpPattern<FHE::AddEintOp> {
|
||||
: CrtOpPattern<FHE::AddEintOp>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::AddEintOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
matchAndRewrite(FHE::AddEintOp op, FHE::AddEintOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value lhsOperand = op.a();
|
||||
mlir::Value rhsOperand = op.b();
|
||||
|
||||
// Convert operand type to glwe tensor.
|
||||
typing::TypeConverter converter(loweringParameters);
|
||||
lhsOperand.setType(converter.convertType(lhsOperand.getType()));
|
||||
rhsOperand.setType(converter.convertType(rhsOperand.getType()));
|
||||
mlir::Value lhsOperand = adaptor.a();
|
||||
mlir::Value rhsOperand = adaptor.b();
|
||||
|
||||
// Write add loop.
|
||||
mlir::Type ciphertextScalarType =
|
||||
converter.convertType(lhsOperand.getType())
|
||||
converter->convertType(lhsOperand.getType())
|
||||
.cast<mlir::RankedTensorType>()
|
||||
.getElementType();
|
||||
mlir::Value output = writeUnaryTensorLoop(
|
||||
@@ -401,21 +382,16 @@ struct SubEintOpPattern : CrtOpPattern<FHE::SubEintOp> {
|
||||
: CrtOpPattern<FHE::SubEintOp>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::SubEintOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
matchAndRewrite(FHE::SubEintOp op, FHE::SubEintOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value lhsOperand = op.a();
|
||||
mlir::Value rhsOperand = op.b();
|
||||
|
||||
// Convert operand type to glwe tensor.
|
||||
typing::TypeConverter converter(loweringParameters);
|
||||
lhsOperand.setType(converter.convertType(lhsOperand.getType()));
|
||||
rhsOperand.setType(converter.convertType(rhsOperand.getType()));
|
||||
mlir::Value lhsOperand = adaptor.a();
|
||||
mlir::Value rhsOperand = adaptor.b();
|
||||
|
||||
// Write sub loop.
|
||||
mlir::Type ciphertextScalarType =
|
||||
converter.convertType(lhsOperand.getType())
|
||||
converter->convertType(lhsOperand.getType())
|
||||
.cast<mlir::RankedTensorType>()
|
||||
.getElementType();
|
||||
mlir::Value output = writeUnaryTensorLoop(
|
||||
@@ -451,18 +427,14 @@ struct NegEintOpPattern : CrtOpPattern<FHE::NegEintOp> {
|
||||
: CrtOpPattern<FHE::NegEintOp>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::NegEintOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
matchAndRewrite(FHE::NegEintOp op, FHE::NegEintOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value operand = op.a();
|
||||
|
||||
// Convert operand type to glwe tensor.
|
||||
typing::TypeConverter converter{loweringParameters};
|
||||
operand.setType(converter.convertType(operand.getType()));
|
||||
mlir::Value operand = adaptor.a();
|
||||
|
||||
// Write the loop nest.
|
||||
mlir::Type ciphertextScalarType = converter.convertType(operand.getType())
|
||||
mlir::Type ciphertextScalarType = converter->convertType(operand.getType())
|
||||
.cast<mlir::RankedTensorType>()
|
||||
.getElementType();
|
||||
mlir::Value loopRes = writeUnaryTensorLoop(
|
||||
@@ -494,16 +466,12 @@ struct MulEintIntOpPattern : CrtOpPattern<FHE::MulEintIntOp> {
|
||||
: CrtOpPattern<FHE::MulEintIntOp>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::MulEintIntOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
matchAndRewrite(FHE::MulEintIntOp op, FHE::MulEintIntOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value eintOperand = op.a();
|
||||
mlir::Value intOperand = op.b();
|
||||
|
||||
// Convert operand type to glwe tensor.
|
||||
typing::TypeConverter converter{loweringParameters};
|
||||
eintOperand.setType(converter.convertType(eintOperand.getType()));
|
||||
mlir::Value eintOperand = adaptor.a();
|
||||
mlir::Value intOperand = adaptor.b();
|
||||
|
||||
// Write cleartext "encoding"
|
||||
mlir::Value encodedCleartext = rewriter.create<mlir::arith::ExtSIOp>(
|
||||
@@ -511,7 +479,7 @@ struct MulEintIntOpPattern : CrtOpPattern<FHE::MulEintIntOp> {
|
||||
|
||||
// Write the loop nest.
|
||||
mlir::Type ciphertextScalarType =
|
||||
converter.convertType(eintOperand.getType())
|
||||
converter->convertType(eintOperand.getType())
|
||||
.cast<mlir::RankedTensorType>()
|
||||
.getElementType();
|
||||
mlir::Value loopRes = writeUnaryTensorLoop(
|
||||
@@ -545,9 +513,9 @@ struct ApplyLookupTableEintOpPattern
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::ApplyLookupTableEintOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
typing::TypeConverter converter(loweringParameters);
|
||||
FHE::ApplyLookupTableEintOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
|
||||
mlir::Value newLut =
|
||||
rewriter
|
||||
@@ -556,7 +524,7 @@ struct ApplyLookupTableEintOpPattern
|
||||
mlir::RankedTensorType::get(
|
||||
mlir::ArrayRef<int64_t>(loweringParameters.lutSize),
|
||||
rewriter.getI64Type()),
|
||||
op.lut(),
|
||||
adaptor.lut(),
|
||||
rewriter.getI64ArrayAttr(
|
||||
mlir::ArrayRef<int64_t>(loweringParameters.mods)),
|
||||
rewriter.getI64ArrayAttr(
|
||||
@@ -567,11 +535,9 @@ struct ApplyLookupTableEintOpPattern
|
||||
|
||||
// Replace the lut with an encoded / expanded one.
|
||||
auto wopPBS = rewriter.create<TFHE::WopPBSGLWEOp>(
|
||||
op.getLoc(), op.getType(), op.a(), newLut, -1, -1, -1, -1, -1, -1, -1,
|
||||
-1, -1, -1, rewriter.getI64ArrayAttr({}));
|
||||
op.getLoc(), converter->convertType(op.getType()), adaptor.a(), newLut,
|
||||
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, rewriter.getI64ArrayAttr({}));
|
||||
|
||||
concretelang::convertOperandAndResultTypes(rewriter, wopPBS,
|
||||
converter.getConversionLambda());
|
||||
rewriter.replaceOp(op, {wopPBS.getResult()});
|
||||
return ::mlir::success();
|
||||
};
|
||||
@@ -587,7 +553,9 @@ struct TensorExtractOpPattern : public CrtOpPattern<mlir::tensor::ExtractOp> {
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::tensor::ExtractOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::tensor::ExtractOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
|
||||
if (!op.getTensor()
|
||||
.getType()
|
||||
@@ -601,7 +569,7 @@ struct TensorExtractOpPattern : public CrtOpPattern<mlir::tensor::ExtractOp> {
|
||||
.isa<TFHE::GLWECipherTextType>()) {
|
||||
return mlir::success();
|
||||
}
|
||||
typing::TypeConverter converter{loweringParameters};
|
||||
|
||||
mlir::SmallVector<mlir::OpFoldResult> offsets;
|
||||
mlir::SmallVector<mlir::OpFoldResult> sizes;
|
||||
mlir::SmallVector<mlir::OpFoldResult> strides;
|
||||
@@ -617,11 +585,10 @@ struct TensorExtractOpPattern : public CrtOpPattern<mlir::tensor::ExtractOp> {
|
||||
strides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
auto newOp = rewriter.create<mlir::tensor::ExtractSliceOp>(
|
||||
op.getLoc(),
|
||||
converter.convertType(op.getResult().getType())
|
||||
converter->convertType(op.getResult().getType())
|
||||
.cast<mlir::RankedTensorType>(),
|
||||
op.getTensor(), offsets, sizes, strides);
|
||||
concretelang::convertOperandAndResultTypes(rewriter, newOp,
|
||||
converter.getConversionLambda());
|
||||
adaptor.getTensor(), offsets, sizes, strides);
|
||||
|
||||
rewriter.replaceOp(op, {newOp.getResult()});
|
||||
return mlir::success();
|
||||
}
|
||||
@@ -637,8 +604,8 @@ struct TensorInsertOpPattern : public CrtOpPattern<mlir::tensor::InsertOp> {
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::tensor::InsertOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::tensor::InsertOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
if (!op.getDest()
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
@@ -651,7 +618,7 @@ struct TensorInsertOpPattern : public CrtOpPattern<mlir::tensor::InsertOp> {
|
||||
.isa<TFHE::GLWECipherTextType>()) {
|
||||
return mlir::success();
|
||||
}
|
||||
typing::TypeConverter converter{loweringParameters};
|
||||
|
||||
mlir::SmallVector<mlir::OpFoldResult> offsets;
|
||||
mlir::SmallVector<mlir::OpFoldResult> sizes;
|
||||
mlir::SmallVector<mlir::OpFoldResult> strides;
|
||||
@@ -666,9 +633,9 @@ struct TensorInsertOpPattern : public CrtOpPattern<mlir::tensor::InsertOp> {
|
||||
sizes.push_back(rewriter.getI64IntegerAttr(loweringParameters.nMods));
|
||||
strides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
auto newOp = rewriter.create<mlir::tensor::InsertSliceOp>(
|
||||
op.getLoc(), op.getScalar(), op.getDest(), offsets, sizes, strides);
|
||||
concretelang::convertOperandAndResultTypes(rewriter, newOp,
|
||||
converter.getConversionLambda());
|
||||
op.getLoc(), adaptor.getScalar(), op.getDest(), offsets, sizes,
|
||||
strides);
|
||||
|
||||
rewriter.replaceOp(op, {newOp});
|
||||
return mlir::success();
|
||||
}
|
||||
@@ -685,7 +652,10 @@ struct TensorFromElementsOpPattern
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::tensor::FromElementsOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::tensor::FromElementsOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
|
||||
if (!op.getResult()
|
||||
.getType()
|
||||
.cast<mlir::RankedTensorType>()
|
||||
@@ -699,13 +669,11 @@ struct TensorFromElementsOpPattern
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
typing::TypeConverter converter{loweringParameters};
|
||||
|
||||
// Create dest tensor allocation op
|
||||
mlir::Value outputTensor =
|
||||
rewriter.create<mlir::bufferization::AllocTensorOp>(
|
||||
op.getLoc(),
|
||||
converter.convertType(op.getResult().getType())
|
||||
converter->convertType(op.getResult().getType())
|
||||
.cast<mlir::RankedTensorType>(),
|
||||
mlir::ValueRange{});
|
||||
|
||||
@@ -723,15 +691,13 @@ struct TensorFromElementsOpPattern
|
||||
strides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
offsets.push_back(rewriter.getI64IntegerAttr(0));
|
||||
}
|
||||
for (size_t insertionIndex = 0; insertionIndex < op.getElements().size();
|
||||
++insertionIndex) {
|
||||
for (size_t insertionIndex = 0;
|
||||
insertionIndex < adaptor.getElements().size(); ++insertionIndex) {
|
||||
offsets[0] = rewriter.getI64IntegerAttr(insertionIndex);
|
||||
mlir::tensor::InsertSliceOp insertOp =
|
||||
rewriter.create<mlir::tensor::InsertSliceOp>(
|
||||
op.getLoc(), op.getElements()[insertionIndex], outputTensor,
|
||||
op.getLoc(), adaptor.getElements()[insertionIndex], outputTensor,
|
||||
offsets, sizes, strides);
|
||||
concretelang::convertOperandAndResultTypes(
|
||||
rewriter, insertOp, converter.getConversionLambda());
|
||||
outputTensor = insertOp.getResult();
|
||||
}
|
||||
rewriter.replaceOp(op, {outputTensor});
|
||||
@@ -780,7 +746,11 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
|
||||
op, converter);
|
||||
});
|
||||
target.addLegalOp<mlir::func::CallOp>();
|
||||
target.addLegalOp<mlir::bufferization::AllocTensorOp>();
|
||||
|
||||
concretelang::addDynamicallyLegalTypeOp<mlir::bufferization::AllocTensorOp>(
|
||||
target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<mlir::func::ReturnOp>(target,
|
||||
converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<mlir::tensor::ExtractSliceOp>(
|
||||
target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<mlir::tensor::InsertSliceOp>(
|
||||
@@ -812,15 +782,20 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
|
||||
//---------------------------------------------------------- Adding patterns
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
// Patterns for `bufferization` dialect operations.
|
||||
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::bufferization::AllocTensorOp, true>>(patterns.getContext(),
|
||||
converter);
|
||||
|
||||
// Patterns for the `FHE` dialect operations
|
||||
patterns.add<
|
||||
// |_ `FHE::zero_eint`
|
||||
concretelang::GenericTypeAndOpConverterPattern<FHE::ZeroEintOp,
|
||||
TFHE::ZeroGLWEOp>,
|
||||
concretelang::GenericOneToOneOpConversionPattern<FHE::ZeroEintOp,
|
||||
TFHE::ZeroGLWEOp>,
|
||||
// |_ `FHE::zero_tensor`
|
||||
concretelang::GenericTypeAndOpConverterPattern<FHE::ZeroTensorOp,
|
||||
TFHE::ZeroTensorGLWEOp>>(
|
||||
&getContext(), converter);
|
||||
concretelang::GenericOneToOneOpConversionPattern<
|
||||
FHE::ZeroTensorOp, TFHE::ZeroTensorGLWEOp>>(&getContext(),
|
||||
converter);
|
||||
// |_ `FHE::add_eint_int`
|
||||
patterns.add<lowering::AddEintIntOpPattern,
|
||||
// |_ `FHE::add_eint`
|
||||
@@ -841,37 +816,35 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
|
||||
|
||||
// Patterns for the relics of the `FHELinalg` dialect operations.
|
||||
// |_ `linalg::generic` turned to nested `scf::for`
|
||||
patterns.add<concretelang::GenericTypeConverterPattern<mlir::scf::ForOp>>(
|
||||
patterns.getContext(), converter);
|
||||
patterns.add<concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>>(
|
||||
patterns.add<
|
||||
concretelang::TypeConvertingReinstantiationPattern<mlir::scf::ForOp>>(
|
||||
patterns.getContext(), converter);
|
||||
patterns.add<
|
||||
RegionOpTypeConverterPattern<mlir::scf::ForOp, typing::TypeConverter>>(
|
||||
concretelang::TypeConvertingReinstantiationPattern<mlir::scf::YieldOp>>(
|
||||
patterns.getContext(), converter);
|
||||
patterns.add<
|
||||
concretelang::TypeConvertingReinstantiationPattern<mlir::scf::ForOp>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<lowering::TensorExtractOpPattern>(&getContext(),
|
||||
loweringParameters);
|
||||
patterns.add<lowering::TensorInsertOpPattern>(&getContext(),
|
||||
loweringParameters);
|
||||
patterns.add<concretelang::GenericTypeConverterPattern<
|
||||
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::tensor::ExtractSliceOp>>(patterns.getContext(), converter);
|
||||
patterns.add<
|
||||
concretelang::GenericTypeConverterPattern<mlir::tensor::InsertSliceOp>>(
|
||||
patterns.getContext(), converter);
|
||||
patterns.add<concretelang::GenericTypeConverterPattern<
|
||||
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::tensor::InsertSliceOp>>(patterns.getContext(), converter);
|
||||
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::tensor::CollapseShapeOp>>(patterns.getContext(), converter);
|
||||
patterns.add<
|
||||
concretelang::GenericTypeConverterPattern<mlir::tensor::ExpandShapeOp>>(
|
||||
patterns.getContext(), converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
|
||||
typing::TypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::tensor::ExpandShapeOp>>(patterns.getContext(), converter);
|
||||
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::tensor::GenerateOp, true>>(&getContext(), converter);
|
||||
|
||||
// Patterns for `func` dialect operations.
|
||||
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
|
||||
patterns, converter);
|
||||
patterns
|
||||
.add<concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>>(
|
||||
patterns.getContext(), converter);
|
||||
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::func::ReturnOp>>(patterns.getContext(), converter);
|
||||
patterns.add<FunctionConstantOpConversion<typing::TypeConverter>>(
|
||||
&getContext(), converter);
|
||||
|
||||
@@ -880,26 +853,27 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
|
||||
loweringParameters);
|
||||
|
||||
// Patterns for the `RT` dialect operations.
|
||||
patterns
|
||||
.add<concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
|
||||
concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::MakeReadyFutureOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::AwaitFutureOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::CreateAsyncTaskOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::BuildReturnPtrPlaceholderOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::DerefReturnPtrPlaceholderOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::WorkFunctionReturnOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
|
||||
converter);
|
||||
patterns.add<
|
||||
// concretelang::TypeConvertingReinstantiationPattern<
|
||||
// mlir::func::ReturnOp>,
|
||||
concretelang::TypeConvertingReinstantiationPattern<mlir::scf::YieldOp>,
|
||||
concretelang::TypeConvertingReinstantiationPattern<
|
||||
concretelang::RT::MakeReadyFutureOp>,
|
||||
concretelang::TypeConvertingReinstantiationPattern<
|
||||
concretelang::RT::AwaitFutureOp>,
|
||||
concretelang::TypeConvertingReinstantiationPattern<
|
||||
concretelang::RT::CreateAsyncTaskOp, true>,
|
||||
concretelang::TypeConvertingReinstantiationPattern<
|
||||
concretelang::RT::BuildReturnPtrPlaceholderOp>,
|
||||
concretelang::TypeConvertingReinstantiationPattern<
|
||||
concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
|
||||
concretelang::TypeConvertingReinstantiationPattern<
|
||||
concretelang::RT::DerefReturnPtrPlaceholderOp>,
|
||||
concretelang::TypeConvertingReinstantiationPattern<
|
||||
concretelang::RT::WorkFunctionReturnOp>,
|
||||
concretelang::TypeConvertingReinstantiationPattern<
|
||||
concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
|
||||
converter);
|
||||
|
||||
//--------------------------------------------------------- Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
|
||||
@@ -16,8 +16,8 @@
|
||||
#include "concretelang/Conversion/FHEToTFHEScalar/Pass.h"
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Tools.h"
|
||||
#include "concretelang/Conversion/Utils/Dialects/SCF.h"
|
||||
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
|
||||
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
|
||||
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
|
||||
@@ -110,16 +110,17 @@ namespace lowering {
|
||||
/// A pattern rewriter superclass used by most op rewriters during the
|
||||
/// conversion.
|
||||
template <typename T>
|
||||
struct ScalarOpPattern : public mlir::OpRewritePattern<T> {
|
||||
struct ScalarOpPattern : public mlir::OpConversionPattern<T> {
|
||||
|
||||
ScalarOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<T>(context, benefit) {}
|
||||
ScalarOpPattern(mlir::TypeConverter &converter, mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpConversionPattern<T>(converter, context, benefit) {}
|
||||
|
||||
/// Writes the encoding of a plaintext of arbitrary precision using shift.
|
||||
mlir::Value
|
||||
writePlaintextShiftEncoding(mlir::Location location, mlir::Value rawPlaintext,
|
||||
int64_t encryptedWidth,
|
||||
mlir::PatternRewriter &rewriter) const {
|
||||
mlir::ConversionPatternRewriter &rewriter) const {
|
||||
int64_t intShift = 64 - 1 - encryptedWidth;
|
||||
mlir::Value castedInt = rewriter.create<mlir::arith::ExtUIOp>(
|
||||
location, rewriter.getIntegerType(64), rawPlaintext);
|
||||
@@ -131,53 +132,25 @@ struct ScalarOpPattern : public mlir::OpRewritePattern<T> {
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::zero` operation.
|
||||
struct ZeroEintOpPattern : public mlir::OpRewritePattern<FHE::ZeroEintOp> {
|
||||
ZeroEintOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<FHE::ZeroEintOp>(context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::ZeroEintOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::Location location = op.getLoc();
|
||||
typing::TypeConverter converter;
|
||||
TFHE::ZeroGLWEOp newOp =
|
||||
rewriter.create<TFHE::ZeroGLWEOp>(location, op.getType());
|
||||
concretelang::convertOperandAndResultTypes(rewriter, newOp,
|
||||
converter.getConversionLambda());
|
||||
rewriter.replaceOp(op, {newOp.getResult()});
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::add_eint_int` operation.
|
||||
struct AddEintIntOpPattern : public ScalarOpPattern<FHE::AddEintIntOp> {
|
||||
AddEintIntOpPattern(mlir::MLIRContext *context,
|
||||
AddEintIntOpPattern(mlir::TypeConverter &converter,
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ScalarOpPattern<FHE::AddEintIntOp>(context, benefit) {}
|
||||
: ScalarOpPattern<FHE::AddEintIntOp>(converter, context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::AddEintIntOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value eintOperand = op.a();
|
||||
mlir::Value intOperand = op.b();
|
||||
|
||||
matchAndRewrite(FHE::AddEintIntOp op, FHE::AddEintIntOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
// Write the plaintext encoding
|
||||
mlir::Value encodedInt = writePlaintextShiftEncoding(
|
||||
op.getLoc(), intOperand,
|
||||
eintOperand.getType().cast<FHE::EncryptedIntegerType>().getWidth(),
|
||||
rewriter);
|
||||
op.getLoc(), adaptor.b(),
|
||||
op.getType().cast<FHE::EncryptedIntegerType>().getWidth(), rewriter);
|
||||
|
||||
// Write the new op
|
||||
auto newOp = rewriter.create<TFHE::AddGLWEIntOp>(location, op.getType(),
|
||||
eintOperand, encodedInt);
|
||||
typing::TypeConverter converter;
|
||||
concretelang::convertOperandAndResultTypes(rewriter, newOp,
|
||||
converter.getConversionLambda());
|
||||
|
||||
rewriter.replaceOp(op, {newOp.getResult()});
|
||||
rewriter.replaceOpWithNewOp<TFHE::AddGLWEIntOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), adaptor.a(),
|
||||
encodedInt);
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
@@ -185,13 +158,14 @@ struct AddEintIntOpPattern : public ScalarOpPattern<FHE::AddEintIntOp> {
|
||||
|
||||
/// Rewriter for the `FHE::sub_eint_int` operation.
|
||||
struct SubEintIntOpPattern : public ScalarOpPattern<FHE::SubEintIntOp> {
|
||||
SubEintIntOpPattern(mlir::MLIRContext *context,
|
||||
SubEintIntOpPattern(mlir::TypeConverter &converter,
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ScalarOpPattern<FHE::SubEintIntOp>(context, benefit) {}
|
||||
: ScalarOpPattern<FHE::SubEintIntOp>(converter, context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::SubEintIntOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
matchAndRewrite(FHE::SubEintIntOp op, FHE::SubEintIntOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value eintOperand = op.a();
|
||||
mlir::Value intOperand = op.b();
|
||||
@@ -213,15 +187,9 @@ struct SubEintIntOpPattern : public ScalarOpPattern<FHE::SubEintIntOp> {
|
||||
rewriter);
|
||||
|
||||
// Write the new op
|
||||
auto newOp = rewriter.create<TFHE::AddGLWEIntOp>(location, op.getType(),
|
||||
eintOperand, encodedInt);
|
||||
typing::TypeConverter converter;
|
||||
|
||||
// Convert the types
|
||||
concretelang::convertOperandAndResultTypes(rewriter, newOp,
|
||||
converter.getConversionLambda());
|
||||
|
||||
rewriter.replaceOp(op, {newOp.getResult()});
|
||||
rewriter.replaceOpWithNewOp<TFHE::AddGLWEIntOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), adaptor.a(),
|
||||
encodedInt);
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
@@ -229,31 +197,24 @@ struct SubEintIntOpPattern : public ScalarOpPattern<FHE::SubEintIntOp> {
|
||||
|
||||
/// Rewriter for the `FHE::sub_int_eint` operation.
|
||||
struct SubIntEintOpPattern : public ScalarOpPattern<FHE::SubIntEintOp> {
|
||||
SubIntEintOpPattern(mlir::MLIRContext *context,
|
||||
SubIntEintOpPattern(mlir::TypeConverter &converter,
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ScalarOpPattern<FHE::SubIntEintOp>(context, benefit) {}
|
||||
: ScalarOpPattern<FHE::SubIntEintOp>(converter, context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::SubIntEintOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value intOperand = op.a();
|
||||
mlir::Value eintOperand = op.b();
|
||||
|
||||
matchAndRewrite(FHE::SubIntEintOp op, FHE::SubIntEintOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
// Write the plaintext encoding
|
||||
mlir::Value encodedInt = writePlaintextShiftEncoding(
|
||||
op.getLoc(), intOperand,
|
||||
eintOperand.getType().cast<FHE::EncryptedIntegerType>().getWidth(),
|
||||
op.getLoc(), adaptor.a(),
|
||||
op.b().getType().cast<FHE::EncryptedIntegerType>().getWidth(),
|
||||
rewriter);
|
||||
|
||||
// Write the new op
|
||||
auto newOp = rewriter.create<TFHE::SubGLWEIntOp>(location, op.getType(),
|
||||
encodedInt, eintOperand);
|
||||
typing::TypeConverter converter;
|
||||
concretelang::convertOperandAndResultTypes(rewriter, newOp,
|
||||
converter.getConversionLambda());
|
||||
|
||||
rewriter.replaceOp(op, {newOp.getResult()});
|
||||
rewriter.replaceOpWithNewOp<TFHE::SubGLWEIntOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), encodedInt,
|
||||
adaptor.b());
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
@@ -261,60 +222,53 @@ struct SubIntEintOpPattern : public ScalarOpPattern<FHE::SubIntEintOp> {
|
||||
|
||||
/// Rewriter for the `FHE::sub_eint` operation.
|
||||
struct SubEintOpPattern : public ScalarOpPattern<FHE::SubEintOp> {
|
||||
SubEintOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
: ScalarOpPattern<FHE::SubEintOp>(context, benefit) {}
|
||||
SubEintOpPattern(mlir::TypeConverter &converter, mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ScalarOpPattern<FHE::SubEintOp>(converter, context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::SubEintOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
matchAndRewrite(FHE::SubEintOp op, FHE::SubEintOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value lhsOperand = op.a();
|
||||
mlir::Value rhsOperand = op.b();
|
||||
mlir::Value lhsOperand = adaptor.a();
|
||||
mlir::Value rhsOperand = adaptor.b();
|
||||
|
||||
// Write rhs negation
|
||||
auto negative = rewriter.create<TFHE::NegGLWEOp>(
|
||||
location, rhsOperand.getType(), rhsOperand);
|
||||
typing::TypeConverter converter;
|
||||
concretelang::convertOperandAndResultTypes(rewriter, negative,
|
||||
converter.getConversionLambda());
|
||||
|
||||
// Write new op.
|
||||
auto newOp = rewriter.create<TFHE::AddGLWEOp>(
|
||||
location, op.getType(), lhsOperand, negative.getResult());
|
||||
concretelang::convertOperandAndResultTypes(rewriter, newOp,
|
||||
converter.getConversionLambda());
|
||||
rewriter.replaceOpWithNewOp<TFHE::AddGLWEOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), lhsOperand,
|
||||
negative.getResult());
|
||||
|
||||
rewriter.replaceOp(op, {newOp.getResult()});
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::mul_eint_int` operation.
|
||||
struct MulEintIntOpPattern : public ScalarOpPattern<FHE::MulEintIntOp> {
|
||||
MulEintIntOpPattern(mlir::MLIRContext *context,
|
||||
MulEintIntOpPattern(mlir::TypeConverter &converter,
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ScalarOpPattern<FHE::MulEintIntOp>(context, benefit) {}
|
||||
: ScalarOpPattern<FHE::MulEintIntOp>(converter, context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::MulEintIntOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
matchAndRewrite(FHE::MulEintIntOp op, FHE::MulEintIntOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value eintOperand = op.a();
|
||||
mlir::Value intOperand = op.b();
|
||||
mlir::Value eintOperand = adaptor.a();
|
||||
mlir::Value intOperand = adaptor.b();
|
||||
|
||||
// Write the cleartext "encoding"
|
||||
mlir::Value castedCleartext = rewriter.create<mlir::arith::ExtSIOp>(
|
||||
location, rewriter.getIntegerType(64), intOperand);
|
||||
|
||||
// Write the new op.
|
||||
auto newOp = rewriter.create<TFHE::MulGLWEIntOp>(
|
||||
location, op.getType(), eintOperand, castedCleartext);
|
||||
typing::TypeConverter converter;
|
||||
concretelang::convertOperandAndResultTypes(rewriter, newOp,
|
||||
converter.getConversionLambda());
|
||||
|
||||
rewriter.replaceOp(op, {newOp.getResult()});
|
||||
rewriter.replaceOpWithNewOp<TFHE::MulGLWEIntOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), eintOperand,
|
||||
castedCleartext);
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
@@ -324,15 +278,17 @@ struct MulEintIntOpPattern : public ScalarOpPattern<FHE::MulEintIntOp> {
|
||||
struct ApplyLookupTableEintOpPattern
|
||||
: public ScalarOpPattern<FHE::ApplyLookupTableEintOp> {
|
||||
ApplyLookupTableEintOpPattern(
|
||||
mlir::MLIRContext *context,
|
||||
mlir::TypeConverter &converter, mlir::MLIRContext *context,
|
||||
concretelang::ScalarLoweringParameters loweringParams,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ScalarOpPattern<FHE::ApplyLookupTableEintOp>(context, benefit),
|
||||
: ScalarOpPattern<FHE::ApplyLookupTableEintOp>(converter, context,
|
||||
benefit),
|
||||
loweringParameters(loweringParams) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::ApplyLookupTableEintOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
FHE::ApplyLookupTableEintOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
size_t outputBits =
|
||||
op.getResult().getType().cast<FHE::EncryptedIntegerType>().getWidth();
|
||||
@@ -350,16 +306,12 @@ struct ApplyLookupTableEintOpPattern
|
||||
|
||||
// Insert keyswitch
|
||||
auto ksOp = rewriter.create<TFHE::KeySwitchGLWEOp>(
|
||||
op.getLoc(), op.a().getType(), op.a(), -1, -1);
|
||||
typing::TypeConverter converter;
|
||||
concretelang::convertOperandAndResultTypes(rewriter, ksOp,
|
||||
converter.getConversionLambda());
|
||||
op.getLoc(), adaptor.a().getType(), adaptor.a(), -1, -1);
|
||||
|
||||
// Insert bootstrap
|
||||
auto bsOp = rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
|
||||
op, op.getType(), ksOp, newLut, -1, -1, -1, -1);
|
||||
concretelang::convertOperandAndResultTypes(rewriter, bsOp,
|
||||
converter.getConversionLambda());
|
||||
rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), ksOp, newLut, -1, -1,
|
||||
-1, -1);
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
@@ -473,20 +425,20 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
|
||||
// Patterns for the `FHE` dialect operations
|
||||
patterns.add<
|
||||
// |_ `FHE::zero_eint`
|
||||
concretelang::GenericTypeAndOpConverterPattern<FHE::ZeroEintOp,
|
||||
TFHE::ZeroGLWEOp>,
|
||||
concretelang::GenericOneToOneOpConversionPattern<FHE::ZeroEintOp,
|
||||
TFHE::ZeroGLWEOp>,
|
||||
// |_ `FHE::zero_tensor`
|
||||
concretelang::GenericTypeAndOpConverterPattern<FHE::ZeroTensorOp,
|
||||
TFHE::ZeroTensorGLWEOp>,
|
||||
concretelang::GenericOneToOneOpConversionPattern<
|
||||
FHE::ZeroTensorOp, TFHE::ZeroTensorGLWEOp>,
|
||||
// |_ `FHE::neg_eint`
|
||||
concretelang::GenericTypeAndOpConverterPattern<FHE::NegEintOp,
|
||||
TFHE::NegGLWEOp>,
|
||||
concretelang::GenericOneToOneOpConversionPattern<FHE::NegEintOp,
|
||||
TFHE::NegGLWEOp>,
|
||||
// |_ `FHE::not`
|
||||
concretelang::GenericTypeAndOpConverterPattern<FHE::BoolNotOp,
|
||||
TFHE::NegGLWEOp>,
|
||||
concretelang::GenericOneToOneOpConversionPattern<FHE::BoolNotOp,
|
||||
TFHE::NegGLWEOp>,
|
||||
// |_ `FHE::add_eint`
|
||||
concretelang::GenericTypeAndOpConverterPattern<FHE::AddEintOp,
|
||||
TFHE::AddGLWEOp>>(
|
||||
concretelang::GenericOneToOneOpConversionPattern<FHE::AddEintOp,
|
||||
TFHE::AddGLWEOp>>(
|
||||
&getContext(), converter);
|
||||
// |_ `FHE::add_eint_int`
|
||||
patterns.add<lowering::AddEintIntOpPattern,
|
||||
@@ -497,10 +449,10 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
|
||||
// |_ `FHE::sub_eint`
|
||||
lowering::SubEintOpPattern,
|
||||
// |_ `FHE::mul_eint_int`
|
||||
lowering::MulEintIntOpPattern>(&getContext());
|
||||
lowering::MulEintIntOpPattern>(converter, &getContext());
|
||||
// |_ `FHE::apply_lookup_table`
|
||||
patterns.add<lowering::ApplyLookupTableEintOpPattern>(&getContext(),
|
||||
loweringParameters);
|
||||
patterns.add<lowering::ApplyLookupTableEintOpPattern>(
|
||||
converter, &getContext(), loweringParameters);
|
||||
|
||||
// Patterns for boolean conversion ops
|
||||
patterns.add<lowering::FromBoolOpPattern, lowering::ToBoolOpPattern>(
|
||||
@@ -508,14 +460,12 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
|
||||
|
||||
// Patterns for the relics of the `FHELinalg` dialect operations.
|
||||
// |_ `linalg::generic` turned to nested `scf::for`
|
||||
patterns
|
||||
.add<concretelang::GenericTypeConverterPattern<mlir::linalg::YieldOp>>(
|
||||
patterns.getContext(), converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
|
||||
typing::TypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::linalg::YieldOp>>(patterns.getContext(), converter);
|
||||
patterns.add<
|
||||
RegionOpTypeConverterPattern<mlir::scf::ForOp, typing::TypeConverter>>(
|
||||
concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::tensor::GenerateOp, true>,
|
||||
concretelang::TypeConvertingReinstantiationPattern<mlir::scf::ForOp>>(
|
||||
&getContext(), converter);
|
||||
concretelang::populateWithTensorTypeConverterPatterns(patterns, target,
|
||||
converter);
|
||||
@@ -523,33 +473,46 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
|
||||
// Patterns for `func` dialect operations.
|
||||
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
|
||||
patterns, converter);
|
||||
patterns
|
||||
.add<concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>>(
|
||||
patterns.getContext(), converter);
|
||||
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::func::ReturnOp>>(patterns.getContext(), converter);
|
||||
|
||||
concretelang::addDynamicallyLegalTypeOp<mlir::func::ReturnOp>(target,
|
||||
converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<mlir::scf::YieldOp>(target,
|
||||
converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<mlir::scf::ForOp>(target,
|
||||
converter);
|
||||
|
||||
patterns.add<FunctionConstantOpConversion<typing::TypeConverter>>(
|
||||
&getContext(), converter);
|
||||
|
||||
// Patterns for `bufferization` dialect operations.
|
||||
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::bufferization::AllocTensorOp, true>>(patterns.getContext(),
|
||||
converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<mlir::bufferization::AllocTensorOp>(
|
||||
target, converter);
|
||||
|
||||
// Patterns for the `RT` dialect operations.
|
||||
patterns
|
||||
.add<concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
|
||||
concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::MakeReadyFutureOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::AwaitFutureOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::CreateAsyncTaskOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::BuildReturnPtrPlaceholderOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::DerefReturnPtrPlaceholderOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::WorkFunctionReturnOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
|
||||
converter);
|
||||
patterns.add<
|
||||
concretelang::TypeConvertingReinstantiationPattern<mlir::scf::YieldOp>,
|
||||
concretelang::TypeConvertingReinstantiationPattern<
|
||||
concretelang::RT::MakeReadyFutureOp>,
|
||||
concretelang::TypeConvertingReinstantiationPattern<
|
||||
concretelang::RT::AwaitFutureOp>,
|
||||
concretelang::TypeConvertingReinstantiationPattern<
|
||||
concretelang::RT::CreateAsyncTaskOp, true>,
|
||||
concretelang::TypeConvertingReinstantiationPattern<
|
||||
concretelang::RT::BuildReturnPtrPlaceholderOp>,
|
||||
concretelang::TypeConvertingReinstantiationPattern<
|
||||
concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
|
||||
concretelang::TypeConvertingReinstantiationPattern<
|
||||
concretelang::RT::DerefReturnPtrPlaceholderOp>,
|
||||
concretelang::TypeConvertingReinstantiationPattern<
|
||||
concretelang::RT::WorkFunctionReturnOp>,
|
||||
concretelang::TypeConvertingReinstantiationPattern<
|
||||
concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
|
||||
converter);
|
||||
|
||||
//--------------------------------------------------------- Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
|
||||
#include "concretelang/Support/Constants.h"
|
||||
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
|
||||
|
||||
namespace TFHE = mlir::concretelang::TFHE;
|
||||
|
||||
@@ -300,6 +301,12 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
|
||||
|
||||
// Add all patterns to convert TFHE types
|
||||
populateWithTFHEOpTypeConversionPatterns(patterns, target, converter);
|
||||
|
||||
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::bufferization::AllocTensorOp>>(&getContext(), converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::bufferization::AllocTensorOp>(target, converter);
|
||||
|
||||
patterns.add<RegionOpTypeConverterPattern<
|
||||
mlir::linalg::GenericOp, TFHEGlobalParametrizationTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
// for license information.
|
||||
|
||||
#include <iostream>
|
||||
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
@@ -231,6 +232,8 @@ void TFHEToConcretePass::runOnOperation() {
|
||||
patterns.add<
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::bufferization::AllocTensorOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::MakeReadyFutureOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
@@ -270,6 +273,8 @@ void TFHEToConcretePass::runOnOperation() {
|
||||
target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::linalg::YieldOp>(
|
||||
target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::bufferization::AllocTensorOp>(target, converter);
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
|
||||
|
||||
1
compiler/lib/Conversion/Utils/CMakeLists.txt
Normal file
1
compiler/lib/Conversion/Utils/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_subdirectory(Dialects)
|
||||
1
compiler/lib/Conversion/Utils/Dialects/CMakeLists.txt
Normal file
1
compiler/lib/Conversion/Utils/Dialects/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
40
compiler/lib/Conversion/Utils/Dialects/SCF.cpp
Normal file
40
compiler/lib/Conversion/Utils/Dialects/SCF.cpp
Normal file
@@ -0,0 +1,40 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/Conversion/Utils/Dialects/SCF.h"
|
||||
#include "mlir/Transforms/RegionUtils.h"
|
||||
#include <mlir/IR/BlockAndValueMapping.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
template <>
|
||||
mlir::LogicalResult
|
||||
TypeConvertingReinstantiationPattern<scf::ForOp, false>::matchAndRewrite(
|
||||
scf::ForOp oldOp, mlir::OpConversionPattern<scf::ForOp>::OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const {
|
||||
// Create new for loop with empty body, but converted iter args
|
||||
scf::ForOp newForOp = rewriter.replaceOpWithNewOp<scf::ForOp>(
|
||||
oldOp, adaptor.getLowerBound(), adaptor.getUpperBound(),
|
||||
adaptor.getStep(), adaptor.getInitArgs(),
|
||||
[&](OpBuilder &builder, Location loc, Value iv, ValueRange args) {});
|
||||
|
||||
// Move operations from old for op to new one
|
||||
auto &newOperations = newForOp.getBody()->getOperations();
|
||||
mlir::Block *oldBody = oldOp.getBody();
|
||||
|
||||
newOperations.splice(newOperations.begin(), oldBody->getOperations(),
|
||||
oldBody->begin(), oldBody->end());
|
||||
|
||||
// Remap iter args and IV
|
||||
for (auto argsPair : llvm::zip(oldOp.getBody()->getArguments(),
|
||||
newForOp.getBody()->getArguments())) {
|
||||
replaceAllUsesInRegionWith(std::get<0>(argsPair), std::get<1>(argsPair),
|
||||
newForOp.getRegion());
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
107
compiler/lib/Conversion/Utils/Dialects/Tensor.cpp
Normal file
107
compiler/lib/Conversion/Utils/Dialects/Tensor.cpp
Normal file
@@ -0,0 +1,107 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/Conversion/Utils/Dialects/Tensor.h"
|
||||
#include "mlir/Transforms/RegionUtils.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
//
|
||||
// Specializations for CollapseShapeOp
|
||||
//
|
||||
|
||||
// Specialization copying attributes not necessary, as the base
|
||||
// template works correctly
|
||||
template <>
|
||||
mlir::LogicalResult
|
||||
TypeConvertingReinstantiationPattern<tensor::CollapseShapeOp, false>::
|
||||
matchAndRewrite(
|
||||
tensor::CollapseShapeOp oldOp,
|
||||
mlir::OpConversionPattern<tensor::CollapseShapeOp>::OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const {
|
||||
mlir::SmallVector<mlir::Type> resultTypes = convertResultTypes(oldOp);
|
||||
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
|
||||
oldOp, mlir::TypeRange{resultTypes}, adaptor.getSrc(),
|
||||
oldOp.getReassociation());
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
//
|
||||
// Specializations for FromElementsOp
|
||||
//
|
||||
template <>
|
||||
mlir::LogicalResult
|
||||
TypeConvertingReinstantiationPattern<mlir::tensor::FromElementsOp, false>::
|
||||
matchAndRewrite(
|
||||
tensor::FromElementsOp oldOp,
|
||||
mlir::OpConversionPattern<mlir::tensor::FromElementsOp>::OpAdaptor
|
||||
adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const {
|
||||
mlir::Type resultType = convertResultType(oldOp);
|
||||
rewriter.replaceOpWithNewOp<mlir::tensor::FromElementsOp>(
|
||||
oldOp, resultType, adaptor.getElements());
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
//
|
||||
// Specializations for ExpandShapeOp
|
||||
//
|
||||
|
||||
// Specialization copying attributes not necessary, as the base
|
||||
// template works correctly
|
||||
|
||||
template <>
|
||||
mlir::LogicalResult
|
||||
TypeConvertingReinstantiationPattern<tensor::ExpandShapeOp, false>::
|
||||
matchAndRewrite(
|
||||
tensor::ExpandShapeOp oldOp,
|
||||
mlir::OpConversionPattern<tensor::ExpandShapeOp>::OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const {
|
||||
mlir::SmallVector<mlir::Type> resultTypes = convertResultTypes(oldOp);
|
||||
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
|
||||
oldOp, mlir::TypeRange{resultTypes}, adaptor.getSrc(),
|
||||
oldOp.getReassociation());
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
template <>
|
||||
mlir::LogicalResult
|
||||
TypeConvertingReinstantiationPattern<tensor::GenerateOp, true>::matchAndRewrite(
|
||||
tensor::GenerateOp oldOp,
|
||||
mlir::OpConversionPattern<tensor::GenerateOp>::OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const {
|
||||
mlir::SmallVector<mlir::Type> resultTypes = convertResultTypes(oldOp);
|
||||
|
||||
rewriter.setInsertionPointAfter(oldOp);
|
||||
tensor::GenerateOp newGenerateOp = rewriter.create<tensor::GenerateOp>(
|
||||
oldOp.getLoc(), resultTypes, adaptor.getOperands(), oldOp->getAttrs());
|
||||
|
||||
mlir::Block &oldBlock = oldOp.getBody().getBlocks().front();
|
||||
mlir::Block &newBlock = newGenerateOp.getBody().getBlocks().front();
|
||||
auto begin = oldBlock.begin();
|
||||
auto nOps = oldBlock.getOperations().size();
|
||||
|
||||
newBlock.getOperations().splice(newBlock.getOperations().begin(),
|
||||
oldBlock.getOperations(), begin,
|
||||
std::next(begin, nOps - 1));
|
||||
|
||||
for (auto argsPair : llvm::zip(oldOp.getRegion().getArguments(),
|
||||
newGenerateOp.getRegion().getArguments())) {
|
||||
replaceAllUsesInRegionWith(std::get<0>(argsPair), std::get<1>(argsPair),
|
||||
newGenerateOp.getRegion());
|
||||
}
|
||||
|
||||
rewriter.replaceOp(oldOp, newGenerateOp.getResult());
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include <concretelang/Conversion/Utils/FuncConstOpConversion.h>
|
||||
#include <concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h>
|
||||
#include <concretelang/Conversion/Utils/Legality.h>
|
||||
#include <llvm/IR/Instructions.h>
|
||||
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
|
||||
#include <mlir/Dialect/Bufferization/Transforms/Bufferize.h>
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
// CHECK-LABEL: func.func @add_eint(%arg0: !TFHE.glwe<{_,_,_}{7}>, %arg1: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
func.func @add_eint(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "TFHE.add_glwe"(%arg0, %arg1) {MANP = 2 : ui3} : (!TFHE.glwe<{_,_,_}{7}>, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: %[[V1:.*]] = "TFHE.add_glwe"(%arg0, %arg1) : (!TFHE.glwe<{_,_,_}{7}>, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: return %[[V1]] : !TFHE.glwe<{_,_,_}{7}>
|
||||
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
// CHECK-LABEL: func.func @neg_eint(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) {MANP = 1 : ui1} : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: return %0 : !TFHE.glwe<{_,_,_}{7}>
|
||||
|
||||
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
@@ -11,7 +11,7 @@ func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
|
||||
// CHECK-LABEL: func.func @not(%arg0: !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
func.func @not(%arg0: !FHE.ebool) -> !FHE.ebool {
|
||||
// CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) {MANP = 1 : ui1} : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
// CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
// CHECK-NEXT: return %0 : !TFHE.glwe<{_,_,_}{2}>
|
||||
%1 = "FHE.not"(%arg0) : (!FHE.ebool) -> !FHE.ebool
|
||||
return %1: !FHE.ebool
|
||||
|
||||
Reference in New Issue
Block a user