// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete/blob/main/LICENSE.txt
// for license information.

// Explicit specializations of
// `TypeConvertingReinstantiationPattern::matchAndRewrite` for the tensor
// dialect operations CollapseShapeOp, FromElementsOp, ExpandShapeOp and
// GenerateOp. Each one replaces the matched operation with a new one whose
// result types have been converted by the pattern's type converter.

#include "concretelang/Conversion/Utils/Dialects/Tensor.h"
#include "concretelang/Conversion/Utils/Utils.h"

namespace mlir {
namespace concretelang {

//
// Specializations for CollapseShapeOp
//

// A specialization that copies attributes is not necessary, as the base
// template works correctly.
//
// Rewrites `oldOp` with converted result types: the result types are
// converted with `convertResultTypes`, and `oldOp` is replaced (via
// `rewriter.replaceOpWithNewOp`) by a newly built op constructed from
//   - the converted result types,
//   - the converted source tensor (`adaptor.getSrc()`), and
//   - the reassociation of the original op (`oldOp.getReassociation()`),
//     which is passed through unchanged.
//
// Always returns success.
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern::
    matchAndRewrite(
        tensor::CollapseShapeOp oldOp,
        mlir::OpConversionPattern::OpAdaptor adaptor,
        mlir::ConversionPatternRewriter &rewriter) const {
  mlir::SmallVector resultTypes = convertResultTypes(oldOp);

  // The reassociation is read from the *original* op, while the source
  // operand comes from the adaptor so that it is the already-converted value.
  rewriter.replaceOpWithNewOp(
      oldOp, mlir::TypeRange{resultTypes}, adaptor.getSrc(),
      oldOp.getReassociation());

  return mlir::success();
}

//
// Specializations for FromElementsOp
//

// Rewrites `oldOp` with a converted result type.
//
// Unlike the shape-manipulating specializations in this file, a single
// result type is obtained from `convertResultType`. `oldOp` is replaced by a
// newly built op constructed from that type and the converted element values
// (`adaptor.getElements()`). No attributes of `oldOp` are passed to the
// builder here.
//
// Always returns success.
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern::
    matchAndRewrite(
        tensor::FromElementsOp oldOp,
        mlir::OpConversionPattern::OpAdaptor adaptor,
        mlir::ConversionPatternRewriter &rewriter) const {
  mlir::Type resultType = convertResultType(oldOp);

  rewriter.replaceOpWithNewOp(
      oldOp, resultType, adaptor.getElements());

  return mlir::success();
}

//
// Specializations for ExpandShapeOp
//

// A specialization that copies attributes is not necessary, as the base
// template works correctly.
//
// Rewrites `oldOp` with converted result types: the result types are
// converted with `convertResultTypes`, and `oldOp` is replaced (via
// `rewriter.replaceOpWithNewOp`) by a newly built op constructed from
//   - the converted result types,
//   - the converted source tensor (`adaptor.getSrc()`), and
//   - the reassociation of the original op (`oldOp.getReassociation()`),
//     which is passed through unchanged.
//
// Always returns success.
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern::
    matchAndRewrite(
        tensor::ExpandShapeOp oldOp,
        mlir::OpConversionPattern::OpAdaptor adaptor,
        mlir::ConversionPatternRewriter &rewriter) const {
  mlir::SmallVector resultTypes = convertResultTypes(oldOp);

  // The reassociation is read from the *original* op, while the source
  // operand comes from the adaptor so that it is the already-converted value.
  rewriter.replaceOpWithNewOp(
      oldOp, mlir::TypeRange{resultTypes}, adaptor.getSrc(),
      oldOp.getReassociation());

  return mlir::success();
}

//
// Specializations for GenerateOp
//

// Rewrites `oldOp` with converted result types.
//
// Rather than building the replacement op directly, this delegates to
// `convertOpWithBlocks`. It passes the converted operands
// (`adaptor.getOperands()`), the converted result types, the pattern's type
// converter and the rewriter. Presumably `convertOpWithBlocks` also carries
// over and converts the op's body block(s), which the other specializations
// in this file do not have to handle. TODO confirm against its definition.
//
// NOTE(review): if `convertOpWithBlocks` can fail, that failure is not
// propagated. This function unconditionally returns success.
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern::matchAndRewrite(
    tensor::GenerateOp oldOp,
    mlir::OpConversionPattern::OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  mlir::SmallVector resultTypes = convertResultTypes(oldOp);

  // `getTypeConverter()` is dereferenced without a null check, so this
  // assumes the pattern was constructed with a type converter.
  convertOpWithBlocks(oldOp, adaptor.getOperands(), resultTypes,
                      *getTypeConverter(), rewriter);

  return mlir::success();
}

} // namespace concretelang
} // namespace mlir