Files
concrete/compiler/lib/Conversion/Utils/Dialects/Tensor.cpp
Andi Drebes 73fd6c5fe7 refactor(compiler): FHE to TFHE: Use OpConversionPattern for dialect conversion
Use `OpConversionPattern` instead of `OpRewritePattern` for operation
conversion during dialect conversion. This makes explicit and in-place
type conversions unnecessary, since `OpConversionPattern` already
properly converts operand types and provides them to the rewrite rule
through an operation adaptor.

The main contributions of this commit are the two class templates
`TypeConvertingReinstantiationPattern` and
`GenericOneToOneOpConversionPattern`.

The former allows for the definition of a simple replacement rule that
re-instantiates an operation after the types of its operands have been
converted. This is especially useful for type-polymorphic operations
during dialect conversion.

The latter allows for the definition of patterns in which one operation
needs to be replaced with a different operation after conversion of
its operands.

The default implementations for the class templates provide
conversion rules for operations that have a generic builder method
that takes the desired return type(s), the operands and (optionally) a
set of attributes. How attributes are discarded during a conversion
(either by omitting the builder argument or by passing an empty set of
attributes) can be defined through specialization of
`ReinstantiationAttributeDismissalStrategy`.

Custom replacement rules that deviate from the scheme above should be
implemented by specializing
`TypeConvertingReinstantiationPattern::matchAndRewrite()` and
`GenericOneToOneOpConversionPattern::matchAndRewrite()`.
2023-02-01 14:27:10 +01:00

108 lines
3.6 KiB
C++

// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang/Conversion/Utils/Dialects/Tensor.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
namespace mlir {
namespace concretelang {
//
// Specializations for CollapseShapeOp
//
// Only the attribute-dropping variant (`false`) is specialized here; a
// specialization for the attribute-copying variant is not necessary, as the
// base template works correctly for it.

/// Re-creates a `tensor.collapse_shape` once its operand types are converted.
///
/// The source operand is read from `adaptor` (i.e. with its converted type),
/// the result types are obtained from `convertResultTypes`, and the
/// reassociation of `oldOp` is forwarded unchanged. No other attribute of
/// `oldOp` is passed to the builder.
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<tensor::CollapseShapeOp, false>::
    matchAndRewrite(
        tensor::CollapseShapeOp oldOp,
        mlir::OpConversionPattern<tensor::CollapseShapeOp>::OpAdaptor adaptor,
        mlir::ConversionPatternRewriter &rewriter) const {
  mlir::SmallVector<mlir::Type> newResultTypes = convertResultTypes(oldOp);
  auto convertedSrc = adaptor.getSrc();
  auto reassociation = oldOp.getReassociation();

  rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
      oldOp, mlir::TypeRange{newResultTypes}, convertedSrc, reassociation);
  return mlir::success();
}
//
// Specializations for FromElementsOp
//

/// Re-creates a `tensor.from_elements` once its operand types are converted.
///
/// The element operands are read from `adaptor` (i.e. with their converted
/// types) and the single result type is obtained from `convertResultType`.
/// No attribute of `oldOp` is passed to the builder.
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<mlir::tensor::FromElementsOp, false>::
    matchAndRewrite(
        tensor::FromElementsOp oldOp,
        mlir::OpConversionPattern<mlir::tensor::FromElementsOp>::OpAdaptor
            adaptor,
        mlir::ConversionPatternRewriter &rewriter) const {
  mlir::Type newResultType = convertResultType(oldOp);
  auto convertedElements = adaptor.getElements();

  rewriter.replaceOpWithNewOp<mlir::tensor::FromElementsOp>(
      oldOp, newResultType, convertedElements);
  return mlir::success();
}
//
// Specializations for ExpandShapeOp
//
// Only the attribute-dropping variant (`false`) is specialized here; a
// specialization for the attribute-copying variant is not necessary, as the
// base template works correctly for it.

/// Re-creates a `tensor.expand_shape` once its operand types are converted.
///
/// The source operand is read from `adaptor` (i.e. with its converted type),
/// the result types are obtained from `convertResultTypes`, and the
/// reassociation of `oldOp` is forwarded unchanged. No other attribute of
/// `oldOp` is passed to the builder.
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<tensor::ExpandShapeOp, false>::
    matchAndRewrite(
        tensor::ExpandShapeOp oldOp,
        mlir::OpConversionPattern<tensor::ExpandShapeOp>::OpAdaptor adaptor,
        mlir::ConversionPatternRewriter &rewriter) const {
  mlir::SmallVector<mlir::Type> newResultTypes = convertResultTypes(oldOp);
  auto convertedSrc = adaptor.getSrc();
  auto reassociation = oldOp.getReassociation();

  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      oldOp, mlir::TypeRange{newResultTypes}, convertedSrc, reassociation);
  return mlir::success();
}
/// Re-creates a `tensor.generate` with converted operand and result types and
/// moves the body of the original operation into the new one.
///
/// This is the attribute-copying variant (`true`): all attributes of `oldOp`
/// are passed to the new operation. The operands (dynamic extents) are read
/// from `adaptor`, i.e. with their types already converted. The types of the
/// body's block arguments are left unchanged.
///
/// The body is transferred as a whole region via the rewriter instead of
/// splicing operations by hand. In a conversion pattern every IR mutation
/// must go through the `ConversionPatternRewriter`, so that the framework can
/// track it and roll it back if the conversion fails. Moving the entire block
/// also keeps its arguments and its terminator, so uses of the block
/// arguments and the yielded value stay intact without any remapping.
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<tensor::GenerateOp, true>::matchAndRewrite(
    tensor::GenerateOp oldOp,
    mlir::OpConversionPattern<tensor::GenerateOp>::OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  mlir::SmallVector<mlir::Type> resultTypes = convertResultTypes(oldOp);

  rewriter.setInsertionPointAfter(oldOp);

  tensor::GenerateOp newGenerateOp = rewriter.create<tensor::GenerateOp>(
      oldOp.getLoc(), resultTypes, adaptor.getOperands(), oldOp->getAttrs());

  // NOTE(review): the generic (types, operands, attributes) builder is
  // expected to add the body region without creating any block in it; the
  // body is supplied below by moving the region of `oldOp`.
  mlir::Region &newBody = newGenerateOp.getBody();
  assert(newBody.empty() &&
         "expected the generic builder to create an empty body region");

  rewriter.inlineRegionBefore(oldOp.getBody(), newBody, newBody.end());

  rewriter.replaceOp(oldOp, newGenerateOp.getResult());
  return mlir::success();
}
} // namespace concretelang
} // namespace mlir