// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

// Explicit specializations of
// `TypeConvertingReinstantiationPattern::matchAndRewrite` for operations of
// the tensor dialect. Each specialization replaces the matched operation with
// a newly created one that uses converted result types and the operands
// provided by the conversion adaptor.
//
// NOTE(review): no template argument lists are visible in this copy of the
// file (e.g. after `TypeConvertingReinstantiationPattern`,
// `mlir::OpConversionPattern`, `mlir::SmallVector`,
// `rewriter.replaceOpWithNewOp` and `rewriter.create`) -- confirm them
// against the upstream source before relying on the exact spelling below.

#include "concretelang/Conversion/Utils/Dialects/Tensor.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
namespace concretelang {

//
// Specializations for CollapseShapeOp
//

// Specialization copying attributes not necessary, as the base
// template works correctly
//
// Replaces `oldOp` with a new operation built from the converted result
// types of `oldOp`, the source operand taken from `adaptor`, and the
// reassociation read from `oldOp`. Always returns `mlir::success()`.
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern::
    matchAndRewrite(
        tensor::CollapseShapeOp oldOp,
        mlir::OpConversionPattern::OpAdaptor adaptor,
        mlir::ConversionPatternRewriter &rewriter) const {
  // Result types of `oldOp` as produced by `convertResultTypes`.
  mlir::SmallVector resultTypes = convertResultTypes(oldOp);

  // The source comes from the adaptor; the reassociation is taken unchanged
  // from the old operation.
  rewriter.replaceOpWithNewOp(
      oldOp, mlir::TypeRange{resultTypes}, adaptor.getSrc(),
      oldOp.getReassociation());

  return mlir::success();
}

//
// Specializations for FromElementsOp
//

// Replaces `oldOp` with a new operation whose result type is the converted
// result type of `oldOp` and whose operands are the elements taken from
// `adaptor`. Always returns `mlir::success()`.
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern::
    matchAndRewrite(
        tensor::FromElementsOp oldOp,
        mlir::OpConversionPattern::OpAdaptor adaptor,
        mlir::ConversionPatternRewriter &rewriter) const {
  // Single result type of `oldOp` as produced by `convertResultType`.
  mlir::Type resultType = convertResultType(oldOp);

  rewriter.replaceOpWithNewOp(
      oldOp, resultType, adaptor.getElements());

  return mlir::success();
}

//
// Specializations for ExpandShapeOp
//

// Specialization copying attributes not necessary, as the base
// template works correctly
//
// Replaces `oldOp` with a new operation built from the converted result
// types of `oldOp`, the source operand taken from `adaptor`, and the
// reassociation read from `oldOp`. Always returns `mlir::success()`.
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern::
    matchAndRewrite(
        tensor::ExpandShapeOp oldOp,
        mlir::OpConversionPattern::OpAdaptor adaptor,
        mlir::ConversionPatternRewriter &rewriter) const {
  // Result types of `oldOp` as produced by `convertResultTypes`.
  mlir::SmallVector resultTypes = convertResultTypes(oldOp);

  // The source comes from the adaptor; the reassociation is taken unchanged
  // from the old operation.
  rewriter.replaceOpWithNewOp(
      oldOp, mlir::TypeRange{resultTypes}, adaptor.getSrc(),
      oldOp.getReassociation());

  return mlir::success();
}

//
// Specializations for GenerateOp
//

// Creates a new `tensor::GenerateOp` right after `oldOp` with converted
// result types, the operands from `adaptor` and the attributes of `oldOp`,
// moves the operations of the old body into the new body, redirects uses of
// the old region arguments to the new ones, and finally replaces `oldOp`
// with the result of the new operation. Always returns `mlir::success()`.
template <>
mlir::LogicalResult
TypeConvertingReinstantiationPattern::matchAndRewrite(
    tensor::GenerateOp oldOp,
    mlir::OpConversionPattern::OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  // Result types of `oldOp` as produced by `convertResultTypes`.
  mlir::SmallVector resultTypes = convertResultTypes(oldOp);

  // The replacement operation is inserted directly after the old one.
  rewriter.setInsertionPointAfter(oldOp);

  tensor::GenerateOp newGenerateOp = rewriter.create(
      oldOp.getLoc(), resultTypes, adaptor.getOperands(), oldOp->getAttrs());

  // First block of the body of each operation.
  // NOTE(review): `front()` assumes the body of the freshly created
  // operation already contains a block -- confirm.
  mlir::Block &oldBlock = oldOp.getBody().getBlocks().front();
  mlir::Block &newBlock = newGenerateOp.getBody().getBlocks().front();

  auto begin = oldBlock.begin();
  auto nOps = oldBlock.getOperations().size();

  // Move every operation of the old block except the last one to the
  // beginning of the new block; the last operation stays in the old block.
  // NOTE(review): presumably the operation left behind is the block
  // terminator, and the old block is presumably never empty (otherwise
  // `nOps - 1` would not be a valid distance) -- confirm.
  newBlock.getOperations().splice(newBlock.getOperations().begin(),
                                  oldBlock.getOperations(), begin,
                                  std::next(begin, nOps - 1));

  // Pair up the arguments of the old and the new region position by
  // position and, inside the new region, replace uses of each old argument
  // with the corresponding new argument (the operations moved above still
  // refer to the old arguments).
  for (auto argsPair : llvm::zip(oldOp.getRegion().getArguments(),
                                 newGenerateOp.getRegion().getArguments())) {
    replaceAllUsesInRegionWith(std::get<0>(argsPair), std::get<1>(argsPair),
                               newGenerateOp.getRegion());
  }

  // Let the rewriter substitute the old operation with the result of the
  // new one.
  rewriter.replaceOp(oldOp, newGenerateOp.getResult());

  return mlir::success();
}

} // namespace concretelang
} // namespace mlir