From c4686c363184105160d8533edd7198a21b289f62 Mon Sep 17 00:00:00 2001 From: Ayoub Benaissa Date: Fri, 11 Aug 2023 15:58:39 +0100 Subject: [PATCH] fix(compiler): lower fhe.zero to either scalar or tensor variant based on encoding When using crt encoding, some fhe.zero op results will be converted to tensors (crt encoded eint), so should be converted to tfhe.zero_tensor operations instead of tfhe.zero --- .../Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp index abef451c7..3bcb127f0 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp @@ -915,6 +915,29 @@ struct InsertSliceOpPattern : public CrtOpPattern { }; }; +/// Zero op result can be a tensor after CRT encoding, and thus need to be +/// rewritten as a ZeroTensor op +struct ZeroOpPattern : public CrtOpPattern { + ZeroOpPattern(mlir::MLIRContext *context, + mlir::concretelang::CrtLoweringParameters params, + mlir::PatternBenefit benefit = 1) + : CrtOpPattern(context, params, benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(FHE::ZeroEintOp op, FHE::ZeroEintOp::Adaptor adaptor, + ::mlir::ConversionPatternRewriter &rewriter) const override { + + mlir::TypeConverter *converter = this->getTypeConverter(); + auto glweOrTensorType = converter->convertType(op.getResult().getType()); + if (mlir::dyn_cast(glweOrTensorType)) { + rewriter.replaceOpWithNewOp(op, glweOrTensorType); + } else { + rewriter.replaceOpWithNewOp(op, glweOrTensorType); + } + return mlir::success(); + }; +}; + } // namespace lowering struct FHEToTFHECrtPass : public FHEToTFHECrtBase { @@ -983,10 +1006,8 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase { converter); // Patterns for the `FHE` dialect operations + patterns.add(&getContext(), loweringParameters); patterns.add< - // |_ `FHE::zero_eint` - mlir::concretelang::GenericOneToOneOpConversionPattern< - FHE::ZeroEintOp, TFHE::ZeroGLWEOp>, // |_ `FHE::zero_tensor` mlir::concretelang::GenericOneToOneOpConversionPattern< FHE::ZeroTensorOp, TFHE::ZeroTensorGLWEOp>>(&getContext(),