mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix(concrete-compiler): reassociation maps are incorrect in crt mode
See #890.
This commit is contained in:
@@ -705,6 +705,42 @@ struct TensorFromElementsOpPattern
|
||||
}
|
||||
};
|
||||
|
||||
// Generic template for tensor operations that have reassociation map
|
||||
// attributes.
|
||||
template <typename Op, bool inRank>
|
||||
struct TensorReassociationOpPattern : public CrtOpPattern<Op> {
|
||||
TensorReassociationOpPattern(mlir::MLIRContext *context,
|
||||
concretelang::CrtLoweringParameters params,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: CrtOpPattern<Op>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||
::mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
|
||||
auto reassocVal = (inRank ? adaptor.src() : op.result());
|
||||
auto reassocTy = reassocVal.getType();
|
||||
auto newReassocType = converter->convertType(reassocTy);
|
||||
|
||||
mlir::SmallVector<mlir::ReassociationIndices> oldReassocs =
|
||||
op.getReassociationIndices();
|
||||
mlir::SmallVector<mlir::ReassociationIndices> newReassocs{oldReassocs};
|
||||
mlir::ReassociationIndices newReassocEnd;
|
||||
newReassocEnd.push_back(
|
||||
newReassocType.template cast<mlir::RankedTensorType>().getRank() - 1);
|
||||
newReassocs.push_back(newReassocEnd);
|
||||
|
||||
auto newOp = rewriter.create<Op>(
|
||||
op.getLoc(), converter->convertType(op.getResult().getType()),
|
||||
adaptor.src(), newReassocs);
|
||||
rewriter.replaceOp(op, {newOp});
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace lowering
|
||||
|
||||
struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
|
||||
@@ -833,10 +869,12 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
|
||||
mlir::tensor::ExtractSliceOp>>(patterns.getContext(), converter);
|
||||
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::tensor::InsertSliceOp>>(patterns.getContext(), converter);
|
||||
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::tensor::CollapseShapeOp>>(patterns.getContext(), converter);
|
||||
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::tensor::ExpandShapeOp>>(patterns.getContext(), converter);
|
||||
patterns.add<lowering::TensorReassociationOpPattern<
|
||||
mlir::tensor::CollapseShapeOp, true>>(patterns.getContext(),
|
||||
loweringParameters);
|
||||
patterns.add<lowering::TensorReassociationOpPattern<
|
||||
mlir::tensor::ExpandShapeOp, false>>(patterns.getContext(),
|
||||
loweringParameters);
|
||||
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::tensor::GenerateOp, true>>(&getContext(), converter);
|
||||
|
||||
|
||||
@@ -18,8 +18,12 @@ namespace serverlib {
|
||||
/// Helper class template that yields an unsigned integer type given a
|
||||
/// size in bytes
|
||||
template <std::size_t size> struct int_type_of_size {};
|
||||
template <> struct int_type_of_size<4> { typedef uint32_t type; };
|
||||
template <> struct int_type_of_size<8> { typedef uint64_t type; };
|
||||
template <> struct int_type_of_size<4> {
|
||||
typedef uint32_t type;
|
||||
};
|
||||
template <> struct int_type_of_size<8> {
|
||||
typedef uint64_t type;
|
||||
};
|
||||
|
||||
/// Converts one function pointer into another
|
||||
// TODO: Not sure this is valid in all implementations / on all
|
||||
|
||||
12
compiler/tests/check_tests/BugReport/bug_report_890.mlir
Normal file
12
compiler/tests/check_tests/BugReport/bug_report_890.mlir
Normal file
@@ -0,0 +1,12 @@
|
||||
// RUN: concretecompiler --action=dump-tfhe --force-encoding crt %s
|
||||
func.func @main(%2: tensor<1x1x!FHE.eint<16>>) -> tensor<1x1x1x!FHE.eint<16>> {
|
||||
%3 = tensor.expand_shape %2 [[0], [1, 2]] : tensor<1x1x!FHE.eint<16>> into tensor<1x1x1x!FHE.eint<16>>
|
||||
return %3 : tensor<1x1x1x!FHE.eint<16>>
|
||||
}
|
||||
|
||||
func.func @main2(%2: tensor<1x1x1x!FHE.eint<16>>) -> tensor<1x1x!FHE.eint<16>> {
|
||||
%3 = tensor.collapse_shape %2 [[0], [1, 2]] : tensor<1x1x1x!FHE.eint<16>> into tensor<1x1x!FHE.eint<16>>
|
||||
return %3 : tensor<1x1x!FHE.eint<16>>
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user