Files
concrete/compilers/concrete-compiler/compiler/lib/Conversion/Utils/Dialects/Tensor.cpp
Andi Drebes 3ad3dcb08f refactor(compiler): Use signature conversion for conversion of ops with nested blocks
The current scheme used by reinstantiating conversion patterns in
`lib/Conversion/Utils/Dialects` for operations with blocks is to
create a new operation with empty blocks, to move the operations from
the old blocks and then to replace any references to block
arguments. However, such in-place updates of the types of block
arguments leave conversion patterns for operations nested in the
blocks without the ability to determine the original types of values
from before the update.

This change uses proper signature conversion for block arguments, such
that the original types of block arguments with converted types are
preserved, while the new types are made available through the dialect
conversion infrastructure via the respective adaptors.
2024-04-08 15:50:48 +02:00

89 lines
2.7 KiB
C++

// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete/blob/main/LICENSE.txt
// for license information.
#include "concretelang/Conversion/Utils/Dialects/Tensor.h"
#include "concretelang/Conversion/Utils/Utils.h"
namespace mlir::concretelang {

//===----------------------------------------------------------------------===//
// tensor.collapse_shape
//===----------------------------------------------------------------------===//

// No specialization that copies attributes is needed for this op, as the base
// template already works correctly for it.

/// Re-creates a `tensor.collapse_shape` whose result types have been run
/// through the type converter. The source operand is taken from the adaptor
/// (i.e. the remapped value provided by the conversion framework) and the
/// reassociation attribute is carried over unchanged from the original op.
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<tensor::CollapseShapeOp, false>::
    matchAndRewrite(
        tensor::CollapseShapeOp oldOp,
        mlir::OpConversionPattern<tensor::CollapseShapeOp>::OpAdaptor adaptor,
        mlir::ConversionPatternRewriter &rewriter) const {
  mlir::SmallVector<mlir::Type> convertedTypes = convertResultTypes(oldOp);
  auto convertedSrc = adaptor.getSrc();
  auto reassociation = oldOp.getReassociation();

  rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
      oldOp, mlir::TypeRange(convertedTypes), convertedSrc, reassociation);
  return mlir::success();
}

//===----------------------------------------------------------------------===//
// tensor.from_elements
//===----------------------------------------------------------------------===//

/// Re-creates a `tensor.from_elements` with the converted (single) result type
/// and the element operands supplied by the adaptor.
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<mlir::tensor::FromElementsOp, false>::
    matchAndRewrite(
        tensor::FromElementsOp oldOp,
        mlir::OpConversionPattern<mlir::tensor::FromElementsOp>::OpAdaptor
            adaptor,
        mlir::ConversionPatternRewriter &rewriter) const {
  mlir::Type convertedType = convertResultType(oldOp);
  auto convertedElements = adaptor.getElements();

  rewriter.replaceOpWithNewOp<mlir::tensor::FromElementsOp>(
      oldOp, convertedType, convertedElements);
  return mlir::success();
}

//===----------------------------------------------------------------------===//
// tensor.expand_shape
//===----------------------------------------------------------------------===//

// No specialization that copies attributes is needed for this op, as the base
// template already works correctly for it.

/// Re-creates a `tensor.expand_shape` whose result types have been run through
/// the type converter. The source operand is taken from the adaptor and the
/// reassociation attribute is carried over unchanged from the original op.
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<tensor::ExpandShapeOp, false>::
    matchAndRewrite(
        tensor::ExpandShapeOp oldOp,
        mlir::OpConversionPattern<tensor::ExpandShapeOp>::OpAdaptor adaptor,
        mlir::ConversionPatternRewriter &rewriter) const {
  mlir::SmallVector<mlir::Type> convertedTypes = convertResultTypes(oldOp);
  auto convertedSrc = adaptor.getSrc();
  auto reassociation = oldOp.getReassociation();

  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      oldOp, mlir::TypeRange(convertedTypes), convertedSrc, reassociation);
  return mlir::success();
}

//===----------------------------------------------------------------------===//
// tensor.generate
//===----------------------------------------------------------------------===//

/// `tensor.generate` owns a body region, so instead of building the new op
/// directly this defers to `convertOpWithBlocks`. That helper receives the
/// adaptor's operands, the converted result types and this pattern's type
/// converter (presumably so that the body's block signature can be converted
/// as well — see the helper's definition).
///
/// NOTE(review): whatever `convertOpWithBlocks` returns is not inspected and
/// the pattern always reports success; confirm the helper cannot fail.
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern<tensor::GenerateOp, true>::matchAndRewrite(
    tensor::GenerateOp oldOp,
    mlir::OpConversionPattern<tensor::GenerateOp>::OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  mlir::SmallVector<mlir::Type> convertedTypes = convertResultTypes(oldOp);
  // Named `converter` rather than `typeConverter` to avoid shadowing the
  // protected member of the same name in mlir::ConversionPattern.
  auto &converter = *getTypeConverter();

  convertOpWithBlocks(oldOp, adaptor.getOperands(), convertedTypes, converter,
                      rewriter);
  return mlir::success();
}

} // namespace mlir::concretelang