From 2fbcd1a792b053eca643cdb976e0813310ac7803 Mon Sep 17 00:00:00 2001 From: aPere3 Date: Tue, 31 Jan 2023 10:37:53 +0100 Subject: [PATCH] fix(concrete-compiler): reassociation maps are incorrect in crt mode See #890. --- .../Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp | 46 +++++++++++++++++-- compiler/lib/ServerLib/DynamicRankCall.cpp | 8 +++- .../check_tests/BugReport/bug_report_890.mlir | 12 +++++ 3 files changed, 60 insertions(+), 6 deletions(-) create mode 100644 compiler/tests/check_tests/BugReport/bug_report_890.mlir diff --git a/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp b/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp index c57ec5db6..8fc778e07 100644 --- a/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp +++ b/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp @@ -705,6 +705,42 @@ struct TensorFromElementsOpPattern } }; +// Generic template for tensor operations that have reassociation map +// attributes. +template +struct TensorReassociationOpPattern : public CrtOpPattern { + TensorReassociationOpPattern(mlir::MLIRContext *context, + concretelang::CrtLoweringParameters params, + mlir::PatternBenefit benefit = 1) + : CrtOpPattern(context, params, benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ::mlir::ConversionPatternRewriter &rewriter) const override { + + mlir::TypeConverter *converter = this->getTypeConverter(); + + auto reassocVal = (inRank ? adaptor.src() : op.result()); + auto reassocTy = reassocVal.getType(); + auto newReassocType = converter->convertType(reassocTy); + + mlir::SmallVector oldReassocs = + op.getReassociationIndices(); + mlir::SmallVector newReassocs{oldReassocs}; + mlir::ReassociationIndices newReassocEnd; + newReassocEnd.push_back( + newReassocType.template cast().getRank() - 1); + newReassocs.push_back(newReassocEnd); + + auto newOp = rewriter.create( + op.getLoc(), converter->convertType(op.getResult().getType()), + adaptor.src(), newReassocs); + rewriter.replaceOp(op, {newOp}); + + return mlir::success(); + }; +}; + } // namespace lowering struct FHEToTFHECrtPass : public FHEToTFHECrtBase { @@ -833,10 +869,12 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase { mlir::tensor::ExtractSliceOp>>(patterns.getContext(), converter); patterns.add>(patterns.getContext(), converter); - patterns.add>(patterns.getContext(), converter); - patterns.add>(patterns.getContext(), converter); + patterns.add>(patterns.getContext(), + loweringParameters); + patterns.add>(patterns.getContext(), + loweringParameters); patterns.add>(&getContext(), converter); diff --git a/compiler/lib/ServerLib/DynamicRankCall.cpp b/compiler/lib/ServerLib/DynamicRankCall.cpp index 540699859..faff4ee4f 100644 --- a/compiler/lib/ServerLib/DynamicRankCall.cpp +++ b/compiler/lib/ServerLib/DynamicRankCall.cpp @@ -18,8 +18,12 @@ namespace serverlib { /// Helper class template that yields an unsigned integer type given a /// size in bytes template struct int_type_of_size {}; -template <> struct int_type_of_size<4> { typedef uint32_t type; }; -template <> struct int_type_of_size<8> { typedef uint64_t type; }; +template <> struct int_type_of_size<4> { + typedef uint32_t type; +}; +template <> struct int_type_of_size<8> { + typedef uint64_t type; +}; /// Converts one function pointer into another // TODO: Not sure this is valid in all implementations / on all diff --git a/compiler/tests/check_tests/BugReport/bug_report_890.mlir b/compiler/tests/check_tests/BugReport/bug_report_890.mlir new file mode 100644 index 000000000..8690450bd --- /dev/null +++ b/compiler/tests/check_tests/BugReport/bug_report_890.mlir @@ -0,0 +1,12 @@ +// RUN: concretecompiler --action=dump-tfhe --force-encoding crt %s +func.func @main(%2: tensor<1x1x!FHE.eint<16>>) -> tensor<1x1x1x!FHE.eint<16>> { + %3 = tensor.expand_shape %2 [[0], [1, 2]] : tensor<1x1x!FHE.eint<16>> into tensor<1x1x1x!FHE.eint<16>> + return %3 : tensor<1x1x1x!FHE.eint<16>> +} + +func.func @main2(%2: tensor<1x1x1x!FHE.eint<16>>) -> tensor<1x1x!FHE.eint<16>> { + %3 = tensor.collapse_shape %2 [[0], [1, 2]] : tensor<1x1x1x!FHE.eint<16>> into tensor<1x1x!FHE.eint<16>> + return %3 : tensor<1x1x!FHE.eint<16>> +} + +