From 9f153d21299d02514e48ac042fcd60ccf94a59ef Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Tue, 15 Nov 2022 16:52:00 +0100 Subject: [PATCH] feat(compiler): ConcreteToBConcrete: Add patterns for batched keyswitch and bootstrap --- .../ConcreteToBConcrete.cpp | 93 ++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) diff --git a/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp b/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp index dfa4c24ea..2513255a3 100644 --- a/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp +++ b/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp @@ -229,6 +229,51 @@ struct LowerKeySwitch : public mlir::OpRewritePattern< }; }; +struct LowerBatchedKeySwitch + : public mlir::OpRewritePattern< + mlir::concretelang::Concrete::BatchedKeySwitchLweOp> { + LowerBatchedKeySwitch(::mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern< + mlir::concretelang::Concrete::BatchedKeySwitchLweOp>(context, + benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(mlir::concretelang::Concrete::BatchedKeySwitchLweOp bksOp, + ::mlir::PatternRewriter &rewriter) const override { + ConcreteToBConcreteTypeConverter converter; + + mlir::concretelang::Concrete::LweCiphertextType outType = + bksOp.getType() + .cast() + .getElementType() + .cast(); + + auto outDimAttr = rewriter.getI32IntegerAttr(outType.getDimension()); + auto inputType = + bksOp.ciphertexts() + .getType() + .cast() + .getElementType() + .cast(); + + mlir::IntegerAttr inputDimAttr = + rewriter.getI32IntegerAttr(inputType.getDimension()); + + mlir::Operation *bBatchedKeySwitchOp = rewriter.replaceOpWithNewOp< + mlir::concretelang::BConcrete::BatchedKeySwitchLweBufferOp>( + bksOp, bksOp.getType(), bksOp.ciphertexts(), bksOp.levelAttr(), + bksOp.baseLogAttr(), inputDimAttr, outDimAttr); + + mlir::concretelang::convertOperandAndResultTypes( + rewriter, bBatchedKeySwitchOp, [&](mlir::MLIRContext *, mlir::Type t) { + return converter.convertType(t); + }); + + return ::mlir::success(); + }; +}; + struct LowerBootstrap : public mlir::OpRewritePattern< mlir::concretelang::Concrete::BootstrapLweOp> { LowerBootstrap(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) @@ -262,6 +307,51 @@ struct LowerBootstrap : public mlir::OpRewritePattern< }; }; +struct LowerBatchedBootstrap + : public mlir::OpRewritePattern< + mlir::concretelang::Concrete::BatchedBootstrapLweOp> { + LowerBatchedBootstrap(::mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern< + mlir::concretelang::Concrete::BatchedBootstrapLweOp>(context, + benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(mlir::concretelang::Concrete::BatchedBootstrapLweOp bbsOp, + ::mlir::PatternRewriter &rewriter) const override { + ConcreteToBConcreteTypeConverter converter; + + mlir::concretelang::Concrete::LweCiphertextType outType = + bbsOp.getType() + .cast() + .getElementType() + .cast(); + + auto inputType = + bbsOp.input_ciphertexts() + .getType() + .cast() + .getElementType() + .cast(); + + auto inputDimAttr = rewriter.getI32IntegerAttr(inputType.getDimension()); + auto outputPrecisionAttr = rewriter.getI32IntegerAttr(outType.getP()); + + mlir::Operation *bBatchedBootstrapOp = rewriter.replaceOpWithNewOp< + mlir::concretelang::BConcrete::BatchedBootstrapLweBufferOp>( + bbsOp, bbsOp.getType(), bbsOp.input_ciphertexts(), bbsOp.lookup_table(), + inputDimAttr, bbsOp.polySizeAttr(), bbsOp.levelAttr(), + bbsOp.baseLogAttr(), bbsOp.glweDimensionAttr(), outputPrecisionAttr); + + mlir::concretelang::convertOperandAndResultTypes( + rewriter, bBatchedBootstrapOp, [&](mlir::MLIRContext *, mlir::Type t) { + return converter.convertType(t); + }); + + return ::mlir::success(); + }; +}; + struct AddPlaintextLweCiphertextOpPattern : public mlir::OpRewritePattern { AddPlaintextLweCiphertextOpPattern(::mlir::MLIRContext *context, @@ -926,7 +1016,8 @@ void ConcreteToBConcretePass::runOnOperation() { // Add patterns to trivialy convert Concrete op to the equivalent // BConcrete op patterns.insert< - LowerBootstrap, LowerKeySwitch, + LowerBootstrap, LowerBatchedBootstrap, LowerKeySwitch, + LowerBatchedKeySwitch, LowToBConcrete,