diff --git a/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp b/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp index 8fc778e07..e038dfee3 100644 --- a/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp +++ b/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp @@ -741,6 +741,82 @@ struct TensorReassociationOpPattern : public CrtOpPattern { }; }; +struct ExtractSliceOpPattern + : public CrtOpPattern { + ExtractSliceOpPattern(mlir::MLIRContext *context, + concretelang::CrtLoweringParameters params, + mlir::PatternBenefit benefit = 1) + : CrtOpPattern(context, params, benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(mlir::tensor::ExtractSliceOp op, + mlir::tensor::ExtractSliceOp::Adaptor adaptor, + ::mlir::ConversionPatternRewriter &rewriter) const override { + + mlir::TypeConverter *converter = this->getTypeConverter(); + + mlir::SmallVector newStaticOffsets{ + op.static_offsets().template getAsRange()}; + mlir::SmallVector newStaticSizes{ + op.static_sizes().template getAsRange()}; + mlir::SmallVector newStaticStrides{ + op.static_strides().template getAsRange()}; + newStaticOffsets.push_back(rewriter.getI64IntegerAttr(0)); + newStaticSizes.push_back( + rewriter.getI64IntegerAttr(this->loweringParameters.nMods)); + newStaticStrides.push_back(rewriter.getI64IntegerAttr(1)); + + mlir::RankedTensorType newType = + converter->convertType(op.getResult().getType()) + .template cast(); + rewriter.replaceOpWithNewOp( + op, newType, adaptor.source(), adaptor.getOffsets(), adaptor.getSizes(), + adaptor.getStrides(), rewriter.getArrayAttr(newStaticOffsets), + rewriter.getArrayAttr(newStaticSizes), + rewriter.getArrayAttr(newStaticStrides)); + + return mlir::success(); + }; +}; + +struct InsertSliceOpPattern : public CrtOpPattern { + InsertSliceOpPattern(mlir::MLIRContext *context, + concretelang::CrtLoweringParameters params, + mlir::PatternBenefit benefit = 1) + : CrtOpPattern(context, params, benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(mlir::tensor::InsertSliceOp op, + mlir::tensor::InsertSliceOp::Adaptor adaptor, + ::mlir::ConversionPatternRewriter &rewriter) const override { + + mlir::TypeConverter *converter = this->getTypeConverter(); + + mlir::SmallVector newStaticOffsets{ + op.static_offsets().template getAsRange()}; + mlir::SmallVector newStaticSizes{ + op.static_sizes().template getAsRange()}; + mlir::SmallVector newStaticStrides{ + op.static_strides().template getAsRange()}; + newStaticOffsets.push_back(rewriter.getI64IntegerAttr(0)); + newStaticSizes.push_back( + rewriter.getI64IntegerAttr(this->loweringParameters.nMods)); + newStaticStrides.push_back(rewriter.getI64IntegerAttr(1)); + + mlir::RankedTensorType newType = + converter->convertType(op.getResult().getType()) + .template cast(); + rewriter.replaceOpWithNewOp( + op, newType, adaptor.source(), adaptor.dest(), adaptor.getOffsets(), + adaptor.getSizes(), adaptor.getStrides(), + rewriter.getArrayAttr(newStaticOffsets), + rewriter.getArrayAttr(newStaticSizes), + rewriter.getArrayAttr(newStaticStrides)); + + return mlir::success(); + }; +}; + } // namespace lowering struct FHEToTFHECrtPass : public FHEToTFHECrtBase { @@ -865,16 +941,16 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase { loweringParameters); patterns.add(&getContext(), loweringParameters); - patterns.add>(patterns.getContext(), converter); - patterns.add>(patterns.getContext(), converter); patterns.add>(patterns.getContext(), loweringParameters); patterns.add>(patterns.getContext(), loweringParameters); + patterns.add(patterns.getContext(), + loweringParameters); + patterns.add(patterns.getContext(), + loweringParameters); patterns.add>(&getContext(), converter); diff --git a/compiler/tests/check_tests/BugReport/bug_report_858.mlir b/compiler/tests/check_tests/BugReport/bug_report_858.mlir new file mode 100644 index 000000000..b75f3da50 --- /dev/null +++ b/compiler/tests/check_tests/BugReport/bug_report_858.mlir @@ -0,0 +1,10 @@ +// RUN: concretecompiler --action=dump-tfhe --force-encoding crt %s +func.func @main(%arg0: tensor<32x!FHE.eint<8>>) -> tensor<16x!FHE.eint<8>>{ + %0 = tensor.extract_slice %arg0[16] [16] [1] : tensor<32x!FHE.eint<8>> to tensor<16x!FHE.eint<8>> + return %0 : tensor<16x!FHE.eint<8>> +} + +func.func @main2(%t0: tensor<2x10x!FHE.eint<6>>, %t1: tensor<2x2x!FHE.eint<6>>) -> tensor<2x10x!FHE.eint<6>> { + %r = tensor.insert_slice %t1 into %t0[0, 5][2, 2][1, 1] : tensor<2x2x!FHE.eint<6>> into tensor<2x10x!FHE.eint<6>> + return %r : tensor<2x10x!FHE.eint<6>> +}