From 32d67726e2ecf3f22b06258bd73ff19709e95b37 Mon Sep 17 00:00:00 2001 From: youben11 Date: Mon, 30 Aug 2021 09:27:42 +0100 Subject: [PATCH] feat(LowToCAPI): replace key alloc w getters from RT --- .../LowLFHEToConcreteCAPI.cpp | 242 ++++++++++-------- 1 file changed, 141 insertions(+), 101 deletions(-) diff --git a/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp b/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp index ae97f813a..5fb104fe5 100644 --- a/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp +++ b/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp @@ -226,8 +226,7 @@ struct GlweFromTableOpPattern matchAndRewrite(mlir::zamalang::LowLFHE::GlweFromTable op, mlir::PatternRewriter &rewriter) const override { LowLFHEToConcreteCAPITypeConverter typeConverter; - auto errType = - mlir::MemRefType::get({}, mlir::IndexType::get(rewriter.getContext())); + auto errType = mlir::IndexType::get(rewriter.getContext()); // Insert forward declaration of the alloc_glwe function { auto funcType = mlir::FunctionType::get( @@ -270,8 +269,8 @@ struct GlweFromTableOpPattern mlir::IntegerType::get(rewriter.getContext(), 64)}, {mlir::zamalang::LowLFHE::ForeignPlaintextListType::get( rewriter.getContext())}); - if (insertForwardDeclaration(op, rewriter, "foreign_plaintext_list_u64", - funcType) + if (insertForwardDeclaration( + op, rewriter, "runtime_foreign_plaintext_list_u64", funcType) .failed()) { return mlir::failure(); } @@ -312,7 +311,8 @@ struct GlweFromTableOpPattern return mlir::failure(); } } - auto errOp = rewriter.create(op.getLoc(), errType); + auto errOp = rewriter.create(op.getLoc(), + rewriter.getIndexAttr(0)); // allocate two glwe to build accumulator auto glweSizeOp = rewriter.create(op.getLoc(), op->getAttr("k")); @@ -351,7 +351,7 @@ struct GlweFromTableOpPattern mlir::SmallVector ForeignPlaintextListOperands{ errOp, op->getOperand(0), sizeOp}; auto foreignPlaintextListOp = rewriter.create( - op.getLoc(), "foreign_plaintext_list_u64", + op.getLoc(), "runtime_foreign_plaintext_list_u64", mlir::zamalang::LowLFHE::ForeignPlaintextListType::get( rewriter.getContext()), ForeignPlaintextListOperands); @@ -373,8 +373,6 @@ struct GlweFromTableOpPattern }; }; -// TODO: -// Get concrete key struct LowLFHEBootstrapLweOpPattern : public mlir::OpRewritePattern { LowLFHEBootstrapLweOpPattern(mlir::MLIRContext *context, @@ -385,30 +383,43 @@ struct LowLFHEBootstrapLweOpPattern mlir::LogicalResult matchAndRewrite(mlir::zamalang::LowLFHE::BootstrapLweOp op, mlir::PatternRewriter &rewriter) const override { - auto errType = - mlir::MemRefType::get({}, mlir::IndexType::get(rewriter.getContext())); + auto errType = mlir::IndexType::get(rewriter.getContext()); auto lweOperandType = op->getOperandTypes().front(); // Insert forward declaration of the allocate_bsk_key function + // { + // auto funcType = mlir::FunctionType::get( + // rewriter.getContext(), + // { + // errType, + // // level + // mlir::IntegerType::get(rewriter.getContext(), 32), + // // baselog + // mlir::IntegerType::get(rewriter.getContext(), 32), + // // glwe size + // mlir::IntegerType::get(rewriter.getContext(), 32), + // // lwe size + // mlir::IntegerType::get(rewriter.getContext(), 32), + // // polynomial size + // mlir::IntegerType::get(rewriter.getContext(), 32), + // }, + // {mlir::zamalang::LowLFHE::LweBootstrapKeyType::get( + // rewriter.getContext())}); + // if (insertForwardDeclaration(op, rewriter, + // "allocate_lwe_bootstrap_key_u64", + // funcType) + // .failed()) { + // return mlir::failure(); + // } + // } + + // Insert forward declaration of the getBsk function { auto funcType = mlir::FunctionType::get( - rewriter.getContext(), - { - errType, - // level - mlir::IntegerType::get(rewriter.getContext(), 32), - // baselog - mlir::IntegerType::get(rewriter.getContext(), 32), - // glwe size - mlir::IntegerType::get(rewriter.getContext(), 32), - // lwe size - mlir::IntegerType::get(rewriter.getContext(), 32), - // polynomial size - mlir::IntegerType::get(rewriter.getContext(), 32), - }, + rewriter.getContext(), {}, {mlir::zamalang::LowLFHE::LweBootstrapKeyType::get( rewriter.getContext())}); - if (insertForwardDeclaration(op, rewriter, - "allocate_lwe_bootstrap_key_u64", funcType) + if (insertForwardDeclaration(op, rewriter, "getGlobalBootstrapKey", + funcType) .failed()) { return mlir::failure(); } @@ -448,7 +459,8 @@ struct LowLFHEBootstrapLweOpPattern } } - auto errOp = rewriter.create(op.getLoc(), errType); + auto errOp = rewriter.create(op.getLoc(), + rewriter.getIndexAttr(0)); // allocate the result lwe ciphertext auto lweSizeOp = rewriter.create( op.getLoc(), mlir::IntegerAttr::get( @@ -458,36 +470,44 @@ struct LowLFHEBootstrapLweOpPattern auto allocateLweCtOp = rewriter.replaceOpWithNewOp( op, "allocate_lwe_ciphertext_u64", lweOperandType, allocLweCtOperands); // allocate bsk - auto decompLevelCountOp = rewriter.create( - op.getLoc(), - mlir::IntegerAttr::get( - mlir::IntegerType::get(rewriter.getContext(), 32), - op->getAttr("level").cast().getInt())); - auto decompBaseLogOp = rewriter.create( - op.getLoc(), - mlir::IntegerAttr::get( - mlir::IntegerType::get(rewriter.getContext(), 32), - op->getAttr("baseLog").cast().getInt())); - auto glweSizeOp = rewriter.create( - op.getLoc(), - mlir::IntegerAttr::get( - mlir::IntegerType::get(rewriter.getContext(), 32), -1)); - auto polySizeOp = rewriter.create( - op.getLoc(), - mlir::IntegerAttr::get( - mlir::IntegerType::get(rewriter.getContext(), 32), - op->getAttr("polynomialSize").cast().getInt())); - mlir::SmallVector allocBskOperands{ - errOp, decompLevelCountOp, decompBaseLogOp, - glweSizeOp, lweSizeOp, polySizeOp}; - auto allocateBskOp = rewriter.create( - op.getLoc(), "allocate_lwe_bootstrap_key_u64", + // auto decompLevelCountOp = rewriter.create( + // op.getLoc(), + // mlir::IntegerAttr::get( + // mlir::IntegerType::get(rewriter.getContext(), 32), + // op->getAttr("level").cast().getInt())); + // auto decompBaseLogOp = rewriter.create( + // op.getLoc(), + // mlir::IntegerAttr::get( + // mlir::IntegerType::get(rewriter.getContext(), 32), + // op->getAttr("baseLog").cast().getInt())); + // auto glweSizeOp = rewriter.create( + // op.getLoc(), + // mlir::IntegerAttr::get( + // mlir::IntegerType::get(rewriter.getContext(), 32), -1)); + // auto polySizeOp = rewriter.create( + // op.getLoc(), + // mlir::IntegerAttr::get( + // mlir::IntegerType::get(rewriter.getContext(), 32), + // op->getAttr("polynomialSize").cast().getInt())); + // mlir::SmallVector allocBskOperands{ + // errOp, decompLevelCountOp, decompBaseLogOp, + // glweSizeOp, lweSizeOp, polySizeOp}; + // auto allocateBskOp = rewriter.create( + // op.getLoc(), "allocate_lwe_bootstrap_key_u64", + // mlir::zamalang::LowLFHE::LweBootstrapKeyType::get( + // rewriter.getContext()), + // allocBskOperands); + + // get bsk + mlir::SmallVector getBskOperands{}; + auto getBskOp = rewriter.create( + op.getLoc(), "getGlobalBootstrapKey", mlir::zamalang::LowLFHE::LweBootstrapKeyType::get( rewriter.getContext()), - allocBskOperands); + getBskOperands); // bootstrap mlir::SmallVector bootstrapOperands{ - errOp, allocateBskOp.getResult(0), allocateLweCtOp.getResult(0), + errOp, getBskOp.getResult(0), allocateLweCtOp.getResult(0), op->getOperand(0), op->getOperand(1)}; rewriter.create(op.getLoc(), "bootstrap_lwe_u64", mlir::TypeRange({}), bootstrapOperands); @@ -496,9 +516,6 @@ struct LowLFHEBootstrapLweOpPattern }; }; -// TODO: -// Parameterization -// Get concrete key struct LowLFHEKeySwitchLweOpPattern : public mlir::OpRewritePattern { LowLFHEKeySwitchLweOpPattern(mlir::MLIRContext *context, @@ -509,28 +526,41 @@ struct LowLFHEKeySwitchLweOpPattern mlir::LogicalResult matchAndRewrite(mlir::zamalang::LowLFHE::KeySwitchLweOp op, mlir::PatternRewriter &rewriter) const override { - auto errType = - mlir::MemRefType::get({}, mlir::IndexType::get(rewriter.getContext())); + auto errType = mlir::IndexType::get(rewriter.getContext()); auto lweOperandType = op->getOperandTypes().front(); // Insert forward declaration of the allocate_ksk_key function + // { + // auto funcType = mlir::FunctionType::get( + // rewriter.getContext(), + // { + // errType, + // // level + // mlir::IntegerType::get(rewriter.getContext(), 32), + // // baselog + // mlir::IntegerType::get(rewriter.getContext(), 32), + // // input lwe size + // mlir::IntegerType::get(rewriter.getContext(), 32), + // // output lwe size + // mlir::IntegerType::get(rewriter.getContext(), 32), + // }, + // {mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get( + // rewriter.getContext())}); + // if (insertForwardDeclaration(op, rewriter, + // "allocate_lwe_keyswitch_key_u64", + // funcType) + // .failed()) { + // return mlir::failure(); + // } + // } + + // Insert forward declaration of the getKsk function { auto funcType = mlir::FunctionType::get( - rewriter.getContext(), - { - errType, - // level - mlir::IntegerType::get(rewriter.getContext(), 32), - // baselog - mlir::IntegerType::get(rewriter.getContext(), 32), - // input lwe size - mlir::IntegerType::get(rewriter.getContext(), 32), - // output lwe size - mlir::IntegerType::get(rewriter.getContext(), 32), - }, + rewriter.getContext(), {}, {mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get( rewriter.getContext())}); - if (insertForwardDeclaration(op, rewriter, - "allocate_lwe_keyswitch_key_u64", funcType) + if (insertForwardDeclaration(op, rewriter, "getGlobalKeyswitchKey", + funcType) .failed()) { return mlir::failure(); } @@ -573,7 +603,8 @@ struct LowLFHEKeySwitchLweOpPattern } } - auto errOp = rewriter.create(op.getLoc(), errType); + auto errOp = rewriter.create(op.getLoc(), + rewriter.getIndexAttr(0)); // allocate the result lwe ciphertext auto lweSizeOp = rewriter.create( op.getLoc(), @@ -584,37 +615,46 @@ struct LowLFHEKeySwitchLweOpPattern auto allocateLweCtOp = rewriter.replaceOpWithNewOp( op, "allocate_lwe_ciphertext_u64", lweOutputType, allocLweCtOperands); // allocate ksk - auto decompLevelCountOp = rewriter.create( - op.getLoc(), - mlir::IntegerAttr::get( - mlir::IntegerType::get(rewriter.getContext(), 32), - op->getAttr("level").cast().getInt())); - auto decompBaseLogOp = rewriter.create( - op.getLoc(), - mlir::IntegerAttr::get( - mlir::IntegerType::get(rewriter.getContext(), 32), - op->getAttr("baseLog").cast().getInt())); - auto inputLweSizeOp = rewriter.create( - op.getLoc(), - mlir::IntegerAttr::get( - mlir::IntegerType::get(rewriter.getContext(), 32), - op->getAttr("inputLweSize").cast().getInt())); - auto outputLweSizeOp = rewriter.create( - op.getLoc(), - mlir::IntegerAttr::get( - mlir::IntegerType::get(rewriter.getContext(), 32), - op->getAttr("outputLweSize").cast().getInt())); - mlir::SmallVector allockskOperands{ - errOp, decompLevelCountOp, decompBaseLogOp, inputLweSizeOp, - outputLweSizeOp}; - auto allocateKskOp = rewriter.create( - op.getLoc(), "allocate_lwe_keyswitch_key_u64", + // auto decompLevelCountOp = rewriter.create( + // op.getLoc(), + // mlir::IntegerAttr::get( + // mlir::IntegerType::get(rewriter.getContext(), 32), + // op->getAttr("level").cast().getInt())); + // auto decompBaseLogOp = rewriter.create( + // op.getLoc(), + // mlir::IntegerAttr::get( + // mlir::IntegerType::get(rewriter.getContext(), 32), + // op->getAttr("baseLog").cast().getInt())); + // auto inputLweSizeOp = rewriter.create( + // op.getLoc(), + // mlir::IntegerAttr::get( + // mlir::IntegerType::get(rewriter.getContext(), 32), + // op->getAttr("inputLweSize").cast().getInt())); + // auto outputLweSizeOp = rewriter.create( + // op.getLoc(), + // mlir::IntegerAttr::get( + // mlir::IntegerType::get(rewriter.getContext(), 32), + // op->getAttr("outputLweSize").cast().getInt())); + // mlir::SmallVector allockskOperands{ + // errOp, decompLevelCountOp, decompBaseLogOp, inputLweSizeOp, + // outputLweSizeOp}; + // auto allocateKskOp = rewriter.create( + // op.getLoc(), "allocate_lwe_keyswitch_key_u64", + // mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get( + // rewriter.getContext()), + // allockskOperands); + + // get ksk + mlir::SmallVector getkskOperands{}; + auto getKskOp = rewriter.create( + op.getLoc(), "getGlobalKeyswitchKey", mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get( rewriter.getContext()), - allockskOperands); - // bootstrap + getkskOperands); + + // keyswitch mlir::SmallVector keyswitchOperands{ - errOp, allocateKskOp.getResult(0), allocateLweCtOp.getResult(0), + errOp, getKskOp.getResult(0), allocateLweCtOp.getResult(0), op->getOperand(0)}; rewriter.create(op.getLoc(), "keyswitch_lwe_u64", mlir::TypeRange({}), keyswitchOperands);