mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-17 16:11:26 -05:00
fix(compiler): lower fhe.zero to either scalar or tensor variant based on encoding
When using crt encoding, some fhe.zero op results will be converted to tensors (crt encoded eint), so should be converted to tfhe.zero_tensor operations instead of tfhe.zero
This commit is contained in:
committed by
Ayoub Benaissa
parent
471ebc080b
commit
c4686c3631
@@ -915,6 +915,29 @@ struct InsertSliceOpPattern : public CrtOpPattern<mlir::tensor::InsertSliceOp> {
|
||||
};
|
||||
};
|
||||
|
||||
/// Zero op result can be a tensor after CRT encoding, and thus need to be
|
||||
/// rewritten as a ZeroTensor op
|
||||
struct ZeroOpPattern : public CrtOpPattern<FHE::ZeroEintOp> {
|
||||
ZeroOpPattern(mlir::MLIRContext *context,
|
||||
mlir::concretelang::CrtLoweringParameters params,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: CrtOpPattern<FHE::ZeroEintOp>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::ZeroEintOp op, FHE::ZeroEintOp::Adaptor adaptor,
|
||||
::mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
auto glweOrTensorType = converter->convertType(op.getResult().getType());
|
||||
if (mlir::dyn_cast<mlir::TensorType>(glweOrTensorType)) {
|
||||
rewriter.replaceOpWithNewOp<TFHE::ZeroTensorGLWEOp>(op, glweOrTensorType);
|
||||
} else {
|
||||
rewriter.replaceOpWithNewOp<TFHE::ZeroGLWEOp>(op, glweOrTensorType);
|
||||
}
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace lowering
|
||||
|
||||
struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
|
||||
@@ -983,10 +1006,8 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
|
||||
converter);
|
||||
|
||||
// Patterns for the `FHE` dialect operations
|
||||
patterns.add<lowering::ZeroOpPattern>(&getContext(), loweringParameters);
|
||||
patterns.add<
|
||||
// |_ `FHE::zero_eint`
|
||||
mlir::concretelang::GenericOneToOneOpConversionPattern<
|
||||
FHE::ZeroEintOp, TFHE::ZeroGLWEOp>,
|
||||
// |_ `FHE::zero_tensor`
|
||||
mlir::concretelang::GenericOneToOneOpConversionPattern<
|
||||
FHE::ZeroTensorOp, TFHE::ZeroTensorGLWEOp>>(&getContext(),
|
||||
|
||||
Reference in New Issue
Block a user