mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix(concrete-compiler): fix bug in crt slice ops
This commit is contained in:
@@ -741,6 +741,82 @@ struct TensorReassociationOpPattern : public CrtOpPattern<Op> {
|
||||
};
|
||||
};
|
||||
|
||||
struct ExtractSliceOpPattern
|
||||
: public CrtOpPattern<mlir::tensor::ExtractSliceOp> {
|
||||
ExtractSliceOpPattern(mlir::MLIRContext *context,
|
||||
concretelang::CrtLoweringParameters params,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: CrtOpPattern<mlir::tensor::ExtractSliceOp>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::tensor::ExtractSliceOp op,
|
||||
mlir::tensor::ExtractSliceOp::Adaptor adaptor,
|
||||
::mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
|
||||
mlir::SmallVector<mlir::Attribute> newStaticOffsets{
|
||||
op.static_offsets().template getAsRange<mlir::IntegerAttr>()};
|
||||
mlir::SmallVector<mlir::Attribute> newStaticSizes{
|
||||
op.static_sizes().template getAsRange<mlir::IntegerAttr>()};
|
||||
mlir::SmallVector<mlir::Attribute> newStaticStrides{
|
||||
op.static_strides().template getAsRange<mlir::IntegerAttr>()};
|
||||
newStaticOffsets.push_back(rewriter.getI64IntegerAttr(0));
|
||||
newStaticSizes.push_back(
|
||||
rewriter.getI64IntegerAttr(this->loweringParameters.nMods));
|
||||
newStaticStrides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
|
||||
mlir::RankedTensorType newType =
|
||||
converter->convertType(op.getResult().getType())
|
||||
.template cast<mlir::RankedTensorType>();
|
||||
rewriter.replaceOpWithNewOp<mlir::tensor::ExtractSliceOp>(
|
||||
op, newType, adaptor.source(), adaptor.getOffsets(), adaptor.getSizes(),
|
||||
adaptor.getStrides(), rewriter.getArrayAttr(newStaticOffsets),
|
||||
rewriter.getArrayAttr(newStaticSizes),
|
||||
rewriter.getArrayAttr(newStaticStrides));
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
struct InsertSliceOpPattern : public CrtOpPattern<mlir::tensor::InsertSliceOp> {
|
||||
InsertSliceOpPattern(mlir::MLIRContext *context,
|
||||
concretelang::CrtLoweringParameters params,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: CrtOpPattern<mlir::tensor::InsertSliceOp>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::tensor::InsertSliceOp op,
|
||||
mlir::tensor::InsertSliceOp::Adaptor adaptor,
|
||||
::mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
|
||||
mlir::SmallVector<mlir::Attribute> newStaticOffsets{
|
||||
op.static_offsets().template getAsRange<mlir::IntegerAttr>()};
|
||||
mlir::SmallVector<mlir::Attribute> newStaticSizes{
|
||||
op.static_sizes().template getAsRange<mlir::IntegerAttr>()};
|
||||
mlir::SmallVector<mlir::Attribute> newStaticStrides{
|
||||
op.static_strides().template getAsRange<mlir::IntegerAttr>()};
|
||||
newStaticOffsets.push_back(rewriter.getI64IntegerAttr(0));
|
||||
newStaticSizes.push_back(
|
||||
rewriter.getI64IntegerAttr(this->loweringParameters.nMods));
|
||||
newStaticStrides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
|
||||
mlir::RankedTensorType newType =
|
||||
converter->convertType(op.getResult().getType())
|
||||
.template cast<mlir::RankedTensorType>();
|
||||
rewriter.replaceOpWithNewOp<mlir::tensor::InsertSliceOp>(
|
||||
op, newType, adaptor.source(), adaptor.dest(), adaptor.getOffsets(),
|
||||
adaptor.getSizes(), adaptor.getStrides(),
|
||||
rewriter.getArrayAttr(newStaticOffsets),
|
||||
rewriter.getArrayAttr(newStaticSizes),
|
||||
rewriter.getArrayAttr(newStaticStrides));
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace lowering
|
||||
|
||||
struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
|
||||
@@ -865,16 +941,16 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
|
||||
loweringParameters);
|
||||
patterns.add<lowering::TensorInsertOpPattern>(&getContext(),
|
||||
loweringParameters);
|
||||
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::tensor::ExtractSliceOp>>(patterns.getContext(), converter);
|
||||
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::tensor::InsertSliceOp>>(patterns.getContext(), converter);
|
||||
patterns.add<lowering::TensorReassociationOpPattern<
|
||||
mlir::tensor::CollapseShapeOp, true>>(patterns.getContext(),
|
||||
loweringParameters);
|
||||
patterns.add<lowering::TensorReassociationOpPattern<
|
||||
mlir::tensor::ExpandShapeOp, false>>(patterns.getContext(),
|
||||
loweringParameters);
|
||||
patterns.add<lowering::ExtractSliceOpPattern>(patterns.getContext(),
|
||||
loweringParameters);
|
||||
patterns.add<lowering::InsertSliceOpPattern>(patterns.getContext(),
|
||||
loweringParameters);
|
||||
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::tensor::GenerateOp, true>>(&getContext(), converter);
|
||||
|
||||
|
||||
10
compiler/tests/check_tests/BugReport/bug_report_858.mlir
Normal file
10
compiler/tests/check_tests/BugReport/bug_report_858.mlir
Normal file
@@ -0,0 +1,10 @@
|
||||
// RUN: concretecompiler --action=dump-tfhe --force-encoding crt %s
|
||||
func.func @main(%arg0: tensor<32x!FHE.eint<8>>) -> tensor<16x!FHE.eint<8>>{
|
||||
%0 = tensor.extract_slice %arg0[16] [16] [1] : tensor<32x!FHE.eint<8>> to tensor<16x!FHE.eint<8>>
|
||||
return %0 : tensor<16x!FHE.eint<8>>
|
||||
}
|
||||
|
||||
func.func @main2(%t0: tensor<2x10x!FHE.eint<6>>, %t1: tensor<2x2x!FHE.eint<6>>) -> tensor<2x10x!FHE.eint<6>> {
|
||||
%r = tensor.insert_slice %t1 into %t0[0, 5][2, 2][1, 1] : tensor<2x2x!FHE.eint<6>> into tensor<2x10x!FHE.eint<6>>
|
||||
return %r : tensor<2x10x!FHE.eint<6>>
|
||||
}
|
||||
Reference in New Issue
Block a user