fix(compiler): lower fhe.zero to either scalar or tensor variant based on encoding

When using crt encoding, some fhe.zero op results will be converted to tensors (crt encoded eint), so should be converted to tfhe.zero_tensor operations instead of tfhe.zero
This commit is contained in:
Ayoub Benaissa
2023-08-11 15:58:39 +01:00
committed by Ayoub Benaissa
parent 471ebc080b
commit c4686c3631

View File

@@ -915,6 +915,29 @@ struct InsertSliceOpPattern : public CrtOpPattern<mlir::tensor::InsertSliceOp> {
};
};
/// Zero op result can be a tensor after CRT encoding, and thus need to be
/// rewritten as a ZeroTensor op
struct ZeroOpPattern : public CrtOpPattern<FHE::ZeroEintOp> {
ZeroOpPattern(mlir::MLIRContext *context,
mlir::concretelang::CrtLoweringParameters params,
mlir::PatternBenefit benefit = 1)
: CrtOpPattern<FHE::ZeroEintOp>(context, params, benefit) {}
::mlir::LogicalResult
matchAndRewrite(FHE::ZeroEintOp op, FHE::ZeroEintOp::Adaptor adaptor,
::mlir::ConversionPatternRewriter &rewriter) const override {
mlir::TypeConverter *converter = this->getTypeConverter();
auto glweOrTensorType = converter->convertType(op.getResult().getType());
if (mlir::dyn_cast<mlir::TensorType>(glweOrTensorType)) {
rewriter.replaceOpWithNewOp<TFHE::ZeroTensorGLWEOp>(op, glweOrTensorType);
} else {
rewriter.replaceOpWithNewOp<TFHE::ZeroGLWEOp>(op, glweOrTensorType);
}
return mlir::success();
};
};
} // namespace lowering
struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
@@ -983,10 +1006,8 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
converter);
// Patterns for the `FHE` dialect operations
patterns.add<lowering::ZeroOpPattern>(&getContext(), loweringParameters);
patterns.add<
// |_ `FHE::zero_eint`
mlir::concretelang::GenericOneToOneOpConversionPattern<
FHE::ZeroEintOp, TFHE::ZeroGLWEOp>,
// |_ `FHE::zero_tensor`
mlir::concretelang::GenericOneToOneOpConversionPattern<
FHE::ZeroTensorOp, TFHE::ZeroTensorGLWEOp>>(&getContext(),