mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(compiler): ConcreteToBConcrete: Add patterns for batched keyswitch and bootstrap
This commit is contained in:
@@ -229,6 +229,51 @@ struct LowerKeySwitch : public mlir::OpRewritePattern<
|
||||
};
|
||||
};
|
||||
|
||||
struct LowerBatchedKeySwitch
|
||||
: public mlir::OpRewritePattern<
|
||||
mlir::concretelang::Concrete::BatchedKeySwitchLweOp> {
|
||||
LowerBatchedKeySwitch(::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<
|
||||
mlir::concretelang::Concrete::BatchedKeySwitchLweOp>(context,
|
||||
benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::Concrete::BatchedKeySwitchLweOp bksOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
ConcreteToBConcreteTypeConverter converter;
|
||||
|
||||
mlir::concretelang::Concrete::LweCiphertextType outType =
|
||||
bksOp.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
|
||||
|
||||
auto outDimAttr = rewriter.getI32IntegerAttr(outType.getDimension());
|
||||
auto inputType =
|
||||
bksOp.ciphertexts()
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
|
||||
|
||||
mlir::IntegerAttr inputDimAttr =
|
||||
rewriter.getI32IntegerAttr(inputType.getDimension());
|
||||
|
||||
mlir::Operation *bBatchedKeySwitchOp = rewriter.replaceOpWithNewOp<
|
||||
mlir::concretelang::BConcrete::BatchedKeySwitchLweBufferOp>(
|
||||
bksOp, bksOp.getType(), bksOp.ciphertexts(), bksOp.levelAttr(),
|
||||
bksOp.baseLogAttr(), inputDimAttr, outDimAttr);
|
||||
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, bBatchedKeySwitchOp, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
return converter.convertType(t);
|
||||
});
|
||||
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
struct LowerBootstrap : public mlir::OpRewritePattern<
|
||||
mlir::concretelang::Concrete::BootstrapLweOp> {
|
||||
LowerBootstrap(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
@@ -262,6 +307,51 @@ struct LowerBootstrap : public mlir::OpRewritePattern<
|
||||
};
|
||||
};
|
||||
|
||||
struct LowerBatchedBootstrap
|
||||
: public mlir::OpRewritePattern<
|
||||
mlir::concretelang::Concrete::BatchedBootstrapLweOp> {
|
||||
LowerBatchedBootstrap(::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<
|
||||
mlir::concretelang::Concrete::BatchedBootstrapLweOp>(context,
|
||||
benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::Concrete::BatchedBootstrapLweOp bbsOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
ConcreteToBConcreteTypeConverter converter;
|
||||
|
||||
mlir::concretelang::Concrete::LweCiphertextType outType =
|
||||
bbsOp.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
|
||||
|
||||
auto inputType =
|
||||
bbsOp.input_ciphertexts()
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
|
||||
|
||||
auto inputDimAttr = rewriter.getI32IntegerAttr(inputType.getDimension());
|
||||
auto outputPrecisionAttr = rewriter.getI32IntegerAttr(outType.getP());
|
||||
|
||||
mlir::Operation *bBatchedBootstrapOp = rewriter.replaceOpWithNewOp<
|
||||
mlir::concretelang::BConcrete::BatchedBootstrapLweBufferOp>(
|
||||
bbsOp, bbsOp.getType(), bbsOp.input_ciphertexts(), bbsOp.lookup_table(),
|
||||
inputDimAttr, bbsOp.polySizeAttr(), bbsOp.levelAttr(),
|
||||
bbsOp.baseLogAttr(), bbsOp.glweDimensionAttr(), outputPrecisionAttr);
|
||||
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, bBatchedBootstrapOp, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
return converter.convertType(t);
|
||||
});
|
||||
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
struct AddPlaintextLweCiphertextOpPattern
|
||||
: public mlir::OpRewritePattern<Concrete::AddPlaintextLweCiphertextOp> {
|
||||
AddPlaintextLweCiphertextOpPattern(::mlir::MLIRContext *context,
|
||||
@@ -926,7 +1016,8 @@ void ConcreteToBConcretePass::runOnOperation() {
|
||||
// Add patterns to trivialy convert Concrete op to the equivalent
|
||||
// BConcrete op
|
||||
patterns.insert<
|
||||
LowerBootstrap, LowerKeySwitch,
|
||||
LowerBootstrap, LowerBatchedBootstrap, LowerKeySwitch,
|
||||
LowerBatchedKeySwitch,
|
||||
LowToBConcrete<mlir::concretelang::Concrete::AddLweCiphertextsOp,
|
||||
mlir::concretelang::BConcrete::AddLweBuffersOp,
|
||||
BConcrete::AddCRTLweBuffersOp>,
|
||||
|
||||
Reference in New Issue
Block a user