feat(compiler): ConcreteToBConcrete: Add patterns for batched keyswitch and bootstrap

This commit is contained in:
Andi Drebes
2022-11-15 16:52:00 +01:00
parent d46db1bf69
commit 9f153d2129

View File

@@ -229,6 +229,51 @@ struct LowerKeySwitch : public mlir::OpRewritePattern<
};
};
struct LowerBatchedKeySwitch
: public mlir::OpRewritePattern<
mlir::concretelang::Concrete::BatchedKeySwitchLweOp> {
LowerBatchedKeySwitch(::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<
mlir::concretelang::Concrete::BatchedKeySwitchLweOp>(context,
benefit) {}
::mlir::LogicalResult
matchAndRewrite(mlir::concretelang::Concrete::BatchedKeySwitchLweOp bksOp,
::mlir::PatternRewriter &rewriter) const override {
ConcreteToBConcreteTypeConverter converter;
mlir::concretelang::Concrete::LweCiphertextType outType =
bksOp.getType()
.cast<mlir::TensorType>()
.getElementType()
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
auto outDimAttr = rewriter.getI32IntegerAttr(outType.getDimension());
auto inputType =
bksOp.ciphertexts()
.getType()
.cast<mlir::TensorType>()
.getElementType()
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
mlir::IntegerAttr inputDimAttr =
rewriter.getI32IntegerAttr(inputType.getDimension());
mlir::Operation *bBatchedKeySwitchOp = rewriter.replaceOpWithNewOp<
mlir::concretelang::BConcrete::BatchedKeySwitchLweBufferOp>(
bksOp, bksOp.getType(), bksOp.ciphertexts(), bksOp.levelAttr(),
bksOp.baseLogAttr(), inputDimAttr, outDimAttr);
mlir::concretelang::convertOperandAndResultTypes(
rewriter, bBatchedKeySwitchOp, [&](mlir::MLIRContext *, mlir::Type t) {
return converter.convertType(t);
});
return ::mlir::success();
};
};
struct LowerBootstrap : public mlir::OpRewritePattern<
mlir::concretelang::Concrete::BootstrapLweOp> {
LowerBootstrap(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
@@ -262,6 +307,51 @@ struct LowerBootstrap : public mlir::OpRewritePattern<
};
};
struct LowerBatchedBootstrap
: public mlir::OpRewritePattern<
mlir::concretelang::Concrete::BatchedBootstrapLweOp> {
LowerBatchedBootstrap(::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<
mlir::concretelang::Concrete::BatchedBootstrapLweOp>(context,
benefit) {}
::mlir::LogicalResult
matchAndRewrite(mlir::concretelang::Concrete::BatchedBootstrapLweOp bbsOp,
::mlir::PatternRewriter &rewriter) const override {
ConcreteToBConcreteTypeConverter converter;
mlir::concretelang::Concrete::LweCiphertextType outType =
bbsOp.getType()
.cast<mlir::TensorType>()
.getElementType()
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
auto inputType =
bbsOp.input_ciphertexts()
.getType()
.cast<mlir::TensorType>()
.getElementType()
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
auto inputDimAttr = rewriter.getI32IntegerAttr(inputType.getDimension());
auto outputPrecisionAttr = rewriter.getI32IntegerAttr(outType.getP());
mlir::Operation *bBatchedBootstrapOp = rewriter.replaceOpWithNewOp<
mlir::concretelang::BConcrete::BatchedBootstrapLweBufferOp>(
bbsOp, bbsOp.getType(), bbsOp.input_ciphertexts(), bbsOp.lookup_table(),
inputDimAttr, bbsOp.polySizeAttr(), bbsOp.levelAttr(),
bbsOp.baseLogAttr(), bbsOp.glweDimensionAttr(), outputPrecisionAttr);
mlir::concretelang::convertOperandAndResultTypes(
rewriter, bBatchedBootstrapOp, [&](mlir::MLIRContext *, mlir::Type t) {
return converter.convertType(t);
});
return ::mlir::success();
};
};
struct AddPlaintextLweCiphertextOpPattern
: public mlir::OpRewritePattern<Concrete::AddPlaintextLweCiphertextOp> {
AddPlaintextLweCiphertextOpPattern(::mlir::MLIRContext *context,
@@ -926,7 +1016,8 @@ void ConcreteToBConcretePass::runOnOperation() {
// Add patterns to trivialy convert Concrete op to the equivalent
// BConcrete op
patterns.insert<
LowerBootstrap, LowerKeySwitch,
LowerBootstrap, LowerBatchedBootstrap, LowerKeySwitch,
LowerBatchedKeySwitch,
LowToBConcrete<mlir::concretelang::Concrete::AddLweCiphertextsOp,
mlir::concretelang::BConcrete::AddLweBuffersOp,
BConcrete::AddCRTLweBuffersOp>,