fix(concrete-compiler): reassociation maps are incorrect in crt mode

See #890.
This commit is contained in:
aPere3
2023-01-31 10:37:53 +01:00
committed by Alexandre Péré
parent c6a44e9091
commit 2fbcd1a792
3 changed files with 60 additions and 6 deletions

View File

@@ -705,6 +705,42 @@ struct TensorFromElementsOpPattern
}
};
// Generic template for tensor operations that have reassociation map
// attributes.
template <typename Op, bool inRank>
struct TensorReassociationOpPattern : public CrtOpPattern<Op> {
TensorReassociationOpPattern(mlir::MLIRContext *context,
concretelang::CrtLoweringParameters params,
mlir::PatternBenefit benefit = 1)
: CrtOpPattern<Op>(context, params, benefit) {}
::mlir::LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
::mlir::ConversionPatternRewriter &rewriter) const override {
mlir::TypeConverter *converter = this->getTypeConverter();
auto reassocVal = (inRank ? adaptor.src() : op.result());
auto reassocTy = reassocVal.getType();
auto newReassocType = converter->convertType(reassocTy);
mlir::SmallVector<mlir::ReassociationIndices> oldReassocs =
op.getReassociationIndices();
mlir::SmallVector<mlir::ReassociationIndices> newReassocs{oldReassocs};
mlir::ReassociationIndices newReassocEnd;
newReassocEnd.push_back(
newReassocType.template cast<mlir::RankedTensorType>().getRank() - 1);
newReassocs.push_back(newReassocEnd);
auto newOp = rewriter.create<Op>(
op.getLoc(), converter->convertType(op.getResult().getType()),
adaptor.src(), newReassocs);
rewriter.replaceOp(op, {newOp});
return mlir::success();
};
};
} // namespace lowering
struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
@@ -833,10 +869,12 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
mlir::tensor::ExtractSliceOp>>(patterns.getContext(), converter);
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
mlir::tensor::InsertSliceOp>>(patterns.getContext(), converter);
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
mlir::tensor::CollapseShapeOp>>(patterns.getContext(), converter);
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
mlir::tensor::ExpandShapeOp>>(patterns.getContext(), converter);
patterns.add<lowering::TensorReassociationOpPattern<
mlir::tensor::CollapseShapeOp, true>>(patterns.getContext(),
loweringParameters);
patterns.add<lowering::TensorReassociationOpPattern<
mlir::tensor::ExpandShapeOp, false>>(patterns.getContext(),
loweringParameters);
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
mlir::tensor::GenerateOp, true>>(&getContext(), converter);

View File

@@ -18,8 +18,12 @@ namespace serverlib {
/// Helper class template that yields an unsigned integer type given a
/// size in bytes
template <std::size_t size> struct int_type_of_size {};
template <> struct int_type_of_size<4> { typedef uint32_t type; };
template <> struct int_type_of_size<8> { typedef uint64_t type; };
template <> struct int_type_of_size<4> {
typedef uint32_t type;
};
template <> struct int_type_of_size<8> {
typedef uint64_t type;
};
/// Converts one function pointer into another
// TODO: Not sure this is valid in all implementations / on all

View File

@@ -0,0 +1,12 @@
// RUN: concretecompiler --action=dump-tfhe --force-encoding crt %s
func.func @main(%2: tensor<1x1x!FHE.eint<16>>) -> tensor<1x1x1x!FHE.eint<16>> {
%3 = tensor.expand_shape %2 [[0], [1, 2]] : tensor<1x1x!FHE.eint<16>> into tensor<1x1x1x!FHE.eint<16>>
return %3 : tensor<1x1x1x!FHE.eint<16>>
}
func.func @main2(%2: tensor<1x1x1x!FHE.eint<16>>) -> tensor<1x1x!FHE.eint<16>> {
%3 = tensor.collapse_shape %2 [[0], [1, 2]] : tensor<1x1x1x!FHE.eint<16>> into tensor<1x1x!FHE.eint<16>>
return %3 : tensor<1x1x!FHE.eint<16>>
}