fix(concrete-compiler): fix bug in crt slice ops

This commit is contained in:
aPere3
2023-01-31 14:20:08 +01:00
committed by Alexandre Péré
parent 2fbcd1a792
commit 002be243be
2 changed files with 90 additions and 4 deletions

View File

@@ -741,6 +741,82 @@ struct TensorReassociationOpPattern : public CrtOpPattern<Op> {
};
};
struct ExtractSliceOpPattern
: public CrtOpPattern<mlir::tensor::ExtractSliceOp> {
ExtractSliceOpPattern(mlir::MLIRContext *context,
concretelang::CrtLoweringParameters params,
mlir::PatternBenefit benefit = 1)
: CrtOpPattern<mlir::tensor::ExtractSliceOp>(context, params, benefit) {}
::mlir::LogicalResult
matchAndRewrite(mlir::tensor::ExtractSliceOp op,
mlir::tensor::ExtractSliceOp::Adaptor adaptor,
::mlir::ConversionPatternRewriter &rewriter) const override {
mlir::TypeConverter *converter = this->getTypeConverter();
mlir::SmallVector<mlir::Attribute> newStaticOffsets{
op.static_offsets().template getAsRange<mlir::IntegerAttr>()};
mlir::SmallVector<mlir::Attribute> newStaticSizes{
op.static_sizes().template getAsRange<mlir::IntegerAttr>()};
mlir::SmallVector<mlir::Attribute> newStaticStrides{
op.static_strides().template getAsRange<mlir::IntegerAttr>()};
newStaticOffsets.push_back(rewriter.getI64IntegerAttr(0));
newStaticSizes.push_back(
rewriter.getI64IntegerAttr(this->loweringParameters.nMods));
newStaticStrides.push_back(rewriter.getI64IntegerAttr(1));
mlir::RankedTensorType newType =
converter->convertType(op.getResult().getType())
.template cast<mlir::RankedTensorType>();
rewriter.replaceOpWithNewOp<mlir::tensor::ExtractSliceOp>(
op, newType, adaptor.source(), adaptor.getOffsets(), adaptor.getSizes(),
adaptor.getStrides(), rewriter.getArrayAttr(newStaticOffsets),
rewriter.getArrayAttr(newStaticSizes),
rewriter.getArrayAttr(newStaticStrides));
return mlir::success();
};
};
struct InsertSliceOpPattern : public CrtOpPattern<mlir::tensor::InsertSliceOp> {
InsertSliceOpPattern(mlir::MLIRContext *context,
concretelang::CrtLoweringParameters params,
mlir::PatternBenefit benefit = 1)
: CrtOpPattern<mlir::tensor::InsertSliceOp>(context, params, benefit) {}
::mlir::LogicalResult
matchAndRewrite(mlir::tensor::InsertSliceOp op,
mlir::tensor::InsertSliceOp::Adaptor adaptor,
::mlir::ConversionPatternRewriter &rewriter) const override {
mlir::TypeConverter *converter = this->getTypeConverter();
mlir::SmallVector<mlir::Attribute> newStaticOffsets{
op.static_offsets().template getAsRange<mlir::IntegerAttr>()};
mlir::SmallVector<mlir::Attribute> newStaticSizes{
op.static_sizes().template getAsRange<mlir::IntegerAttr>()};
mlir::SmallVector<mlir::Attribute> newStaticStrides{
op.static_strides().template getAsRange<mlir::IntegerAttr>()};
newStaticOffsets.push_back(rewriter.getI64IntegerAttr(0));
newStaticSizes.push_back(
rewriter.getI64IntegerAttr(this->loweringParameters.nMods));
newStaticStrides.push_back(rewriter.getI64IntegerAttr(1));
mlir::RankedTensorType newType =
converter->convertType(op.getResult().getType())
.template cast<mlir::RankedTensorType>();
rewriter.replaceOpWithNewOp<mlir::tensor::InsertSliceOp>(
op, newType, adaptor.source(), adaptor.dest(), adaptor.getOffsets(),
adaptor.getSizes(), adaptor.getStrides(),
rewriter.getArrayAttr(newStaticOffsets),
rewriter.getArrayAttr(newStaticSizes),
rewriter.getArrayAttr(newStaticStrides));
return mlir::success();
};
};
} // namespace lowering
struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
@@ -865,16 +941,16 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
loweringParameters);
patterns.add<lowering::TensorInsertOpPattern>(&getContext(),
loweringParameters);
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
mlir::tensor::ExtractSliceOp>>(patterns.getContext(), converter);
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
mlir::tensor::InsertSliceOp>>(patterns.getContext(), converter);
patterns.add<lowering::TensorReassociationOpPattern<
mlir::tensor::CollapseShapeOp, true>>(patterns.getContext(),
loweringParameters);
patterns.add<lowering::TensorReassociationOpPattern<
mlir::tensor::ExpandShapeOp, false>>(patterns.getContext(),
loweringParameters);
patterns.add<lowering::ExtractSliceOpPattern>(patterns.getContext(),
loweringParameters);
patterns.add<lowering::InsertSliceOpPattern>(patterns.getContext(),
loweringParameters);
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
mlir::tensor::GenerateOp, true>>(&getContext(), converter);

View File

@@ -0,0 +1,10 @@
// RUN: concretecompiler --action=dump-tfhe --force-encoding crt %s
func.func @main(%arg0: tensor<32x!FHE.eint<8>>) -> tensor<16x!FHE.eint<8>>{
%0 = tensor.extract_slice %arg0[16] [16] [1] : tensor<32x!FHE.eint<8>> to tensor<16x!FHE.eint<8>>
return %0 : tensor<16x!FHE.eint<8>>
}
func.func @main2(%t0: tensor<2x10x!FHE.eint<6>>, %t1: tensor<2x2x!FHE.eint<6>>) -> tensor<2x10x!FHE.eint<6>> {
%r = tensor.insert_slice %t1 into %t0[0, 5][2, 2][1, 1] : tensor<2x2x!FHE.eint<6>> into tensor<2x10x!FHE.eint<6>>
return %r : tensor<2x10x!FHE.eint<6>>
}