diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp index abef451c7..3bcb127f0 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp @@ -915,6 +915,29 @@ struct InsertSliceOpPattern : public CrtOpPattern { }; }; +/// Zero op result can be a tensor after CRT encoding, and thus need to be +/// rewritten as a ZeroTensor op +struct ZeroOpPattern : public CrtOpPattern { + ZeroOpPattern(mlir::MLIRContext *context, + mlir::concretelang::CrtLoweringParameters params, + mlir::PatternBenefit benefit = 1) + : CrtOpPattern(context, params, benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(FHE::ZeroEintOp op, FHE::ZeroEintOp::Adaptor adaptor, + ::mlir::ConversionPatternRewriter &rewriter) const override { + + mlir::TypeConverter *converter = this->getTypeConverter(); + auto glweOrTensorType = converter->convertType(op.getResult().getType()); + if (mlir::dyn_cast(glweOrTensorType)) { + rewriter.replaceOpWithNewOp(op, glweOrTensorType); + } else { + rewriter.replaceOpWithNewOp(op, glweOrTensorType); + } + return mlir::success(); + }; +}; + } // namespace lowering struct FHEToTFHECrtPass : public FHEToTFHECrtBase { @@ -983,10 +1006,8 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase { converter); // Patterns for the `FHE` dialect operations + patterns.add(&getContext(), loweringParameters); patterns.add< - // |_ `FHE::zero_eint` - mlir::concretelang::GenericOneToOneOpConversionPattern< - FHE::ZeroEintOp, TFHE::ZeroGLWEOp>, // |_ `FHE::zero_tensor` mlir::concretelang::GenericOneToOneOpConversionPattern< FHE::ZeroTensorOp, TFHE::ZeroTensorGLWEOp>>(&getContext(),