feat(LowToCAPI): replace key alloc w getters from RT

This commit is contained in:
youben11
2021-08-30 09:27:42 +01:00
committed by Quentin Bourgerie
parent c6b1480cc6
commit 32d67726e2

View File

@@ -226,8 +226,7 @@ struct GlweFromTableOpPattern
matchAndRewrite(mlir::zamalang::LowLFHE::GlweFromTable op,
mlir::PatternRewriter &rewriter) const override {
LowLFHEToConcreteCAPITypeConverter typeConverter;
auto errType =
mlir::MemRefType::get({}, mlir::IndexType::get(rewriter.getContext()));
auto errType = mlir::IndexType::get(rewriter.getContext());
// Insert forward declaration of the alloc_glwe function
{
auto funcType = mlir::FunctionType::get(
@@ -270,8 +269,8 @@ struct GlweFromTableOpPattern
mlir::IntegerType::get(rewriter.getContext(), 64)},
{mlir::zamalang::LowLFHE::ForeignPlaintextListType::get(
rewriter.getContext())});
if (insertForwardDeclaration(op, rewriter, "foreign_plaintext_list_u64",
funcType)
if (insertForwardDeclaration(
op, rewriter, "runtime_foreign_plaintext_list_u64", funcType)
.failed()) {
return mlir::failure();
}
@@ -312,7 +311,8 @@ struct GlweFromTableOpPattern
return mlir::failure();
}
}
auto errOp = rewriter.create<mlir::memref::AllocaOp>(op.getLoc(), errType);
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
rewriter.getIndexAttr(0));
// allocate two glwe to build accumulator
auto glweSizeOp =
rewriter.create<mlir::ConstantOp>(op.getLoc(), op->getAttr("k"));
@@ -351,7 +351,7 @@ struct GlweFromTableOpPattern
mlir::SmallVector<mlir::Value> ForeignPlaintextListOperands{
errOp, op->getOperand(0), sizeOp};
auto foreignPlaintextListOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "foreign_plaintext_list_u64",
op.getLoc(), "runtime_foreign_plaintext_list_u64",
mlir::zamalang::LowLFHE::ForeignPlaintextListType::get(
rewriter.getContext()),
ForeignPlaintextListOperands);
@@ -373,8 +373,6 @@ struct GlweFromTableOpPattern
};
};
// TODO:
// Get concrete key
struct LowLFHEBootstrapLweOpPattern
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::BootstrapLweOp> {
LowLFHEBootstrapLweOpPattern(mlir::MLIRContext *context,
@@ -385,30 +383,43 @@ struct LowLFHEBootstrapLweOpPattern
mlir::LogicalResult
matchAndRewrite(mlir::zamalang::LowLFHE::BootstrapLweOp op,
mlir::PatternRewriter &rewriter) const override {
auto errType =
mlir::MemRefType::get({}, mlir::IndexType::get(rewriter.getContext()));
auto errType = mlir::IndexType::get(rewriter.getContext());
auto lweOperandType = op->getOperandTypes().front();
// Insert forward declaration of the allocate_bsk_key function
// {
// auto funcType = mlir::FunctionType::get(
// rewriter.getContext(),
// {
// errType,
// // level
// mlir::IntegerType::get(rewriter.getContext(), 32),
// // baselog
// mlir::IntegerType::get(rewriter.getContext(), 32),
// // glwe size
// mlir::IntegerType::get(rewriter.getContext(), 32),
// // lwe size
// mlir::IntegerType::get(rewriter.getContext(), 32),
// // polynomial size
// mlir::IntegerType::get(rewriter.getContext(), 32),
// },
// {mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
// rewriter.getContext())});
// if (insertForwardDeclaration(op, rewriter,
// "allocate_lwe_bootstrap_key_u64",
// funcType)
// .failed()) {
// return mlir::failure();
// }
// }
// Insert forward declaration of the getBsk function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{
errType,
// level
mlir::IntegerType::get(rewriter.getContext(), 32),
// baselog
mlir::IntegerType::get(rewriter.getContext(), 32),
// glwe size
mlir::IntegerType::get(rewriter.getContext(), 32),
// lwe size
mlir::IntegerType::get(rewriter.getContext(), 32),
// polynomial size
mlir::IntegerType::get(rewriter.getContext(), 32),
},
rewriter.getContext(), {},
{mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
rewriter.getContext())});
if (insertForwardDeclaration(op, rewriter,
"allocate_lwe_bootstrap_key_u64", funcType)
if (insertForwardDeclaration(op, rewriter, "getGlobalBootstrapKey",
funcType)
.failed()) {
return mlir::failure();
}
@@ -448,7 +459,8 @@ struct LowLFHEBootstrapLweOpPattern
}
}
auto errOp = rewriter.create<mlir::memref::AllocaOp>(op.getLoc(), errType);
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
rewriter.getIndexAttr(0));
// allocate the result lwe ciphertext
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(), mlir::IntegerAttr::get(
@@ -458,36 +470,44 @@ struct LowLFHEBootstrapLweOpPattern
auto allocateLweCtOp = rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, "allocate_lwe_ciphertext_u64", lweOperandType, allocLweCtOperands);
// allocate bsk
auto decompLevelCountOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(),
mlir::IntegerAttr::get(
mlir::IntegerType::get(rewriter.getContext(), 32),
op->getAttr("level").cast<mlir::IntegerAttr>().getInt()));
auto decompBaseLogOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(),
mlir::IntegerAttr::get(
mlir::IntegerType::get(rewriter.getContext(), 32),
op->getAttr("baseLog").cast<mlir::IntegerAttr>().getInt()));
auto glweSizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(),
mlir::IntegerAttr::get(
mlir::IntegerType::get(rewriter.getContext(), 32), -1));
auto polySizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(),
mlir::IntegerAttr::get(
mlir::IntegerType::get(rewriter.getContext(), 32),
op->getAttr("polynomialSize").cast<mlir::IntegerAttr>().getInt()));
mlir::SmallVector<mlir::Value> allocBskOperands{
errOp, decompLevelCountOp, decompBaseLogOp,
glweSizeOp, lweSizeOp, polySizeOp};
auto allocateBskOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "allocate_lwe_bootstrap_key_u64",
// auto decompLevelCountOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32),
// op->getAttr("level").cast<mlir::IntegerAttr>().getInt()));
// auto decompBaseLogOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32),
// op->getAttr("baseLog").cast<mlir::IntegerAttr>().getInt()));
// auto glweSizeOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32), -1));
// auto polySizeOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32),
// op->getAttr("polynomialSize").cast<mlir::IntegerAttr>().getInt()));
// mlir::SmallVector<mlir::Value> allocBskOperands{
// errOp, decompLevelCountOp, decompBaseLogOp,
// glweSizeOp, lweSizeOp, polySizeOp};
// auto allocateBskOp = rewriter.create<mlir::CallOp>(
// op.getLoc(), "allocate_lwe_bootstrap_key_u64",
// mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
// rewriter.getContext()),
// allocBskOperands);
// get bsk
mlir::SmallVector<mlir::Value> getBskOperands{};
auto getBskOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "getGlobalBootstrapKey",
mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
rewriter.getContext()),
allocBskOperands);
getBskOperands);
// bootstrap
mlir::SmallVector<mlir::Value> bootstrapOperands{
errOp, allocateBskOp.getResult(0), allocateLweCtOp.getResult(0),
errOp, getBskOp.getResult(0), allocateLweCtOp.getResult(0),
op->getOperand(0), op->getOperand(1)};
rewriter.create<mlir::CallOp>(op.getLoc(), "bootstrap_lwe_u64",
mlir::TypeRange({}), bootstrapOperands);
@@ -496,9 +516,6 @@ struct LowLFHEBootstrapLweOpPattern
};
};
// TODO:
// Parameterization
// Get concrete key
struct LowLFHEKeySwitchLweOpPattern
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::KeySwitchLweOp> {
LowLFHEKeySwitchLweOpPattern(mlir::MLIRContext *context,
@@ -509,28 +526,41 @@ struct LowLFHEKeySwitchLweOpPattern
mlir::LogicalResult
matchAndRewrite(mlir::zamalang::LowLFHE::KeySwitchLweOp op,
mlir::PatternRewriter &rewriter) const override {
auto errType =
mlir::MemRefType::get({}, mlir::IndexType::get(rewriter.getContext()));
auto errType = mlir::IndexType::get(rewriter.getContext());
auto lweOperandType = op->getOperandTypes().front();
// Insert forward declaration of the allocate_ksk_key function
// {
// auto funcType = mlir::FunctionType::get(
// rewriter.getContext(),
// {
// errType,
// // level
// mlir::IntegerType::get(rewriter.getContext(), 32),
// // baselog
// mlir::IntegerType::get(rewriter.getContext(), 32),
// // input lwe size
// mlir::IntegerType::get(rewriter.getContext(), 32),
// // output lwe size
// mlir::IntegerType::get(rewriter.getContext(), 32),
// },
// {mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
// rewriter.getContext())});
// if (insertForwardDeclaration(op, rewriter,
// "allocate_lwe_keyswitch_key_u64",
// funcType)
// .failed()) {
// return mlir::failure();
// }
// }
// Insert forward declaration of the getKsk function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{
errType,
// level
mlir::IntegerType::get(rewriter.getContext(), 32),
// baselog
mlir::IntegerType::get(rewriter.getContext(), 32),
// input lwe size
mlir::IntegerType::get(rewriter.getContext(), 32),
// output lwe size
mlir::IntegerType::get(rewriter.getContext(), 32),
},
rewriter.getContext(), {},
{mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
rewriter.getContext())});
if (insertForwardDeclaration(op, rewriter,
"allocate_lwe_keyswitch_key_u64", funcType)
if (insertForwardDeclaration(op, rewriter, "getGlobalKeyswitchKey",
funcType)
.failed()) {
return mlir::failure();
}
@@ -573,7 +603,8 @@ struct LowLFHEKeySwitchLweOpPattern
}
}
auto errOp = rewriter.create<mlir::memref::AllocaOp>(op.getLoc(), errType);
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
rewriter.getIndexAttr(0));
// allocate the result lwe ciphertext
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(),
@@ -584,37 +615,46 @@ struct LowLFHEKeySwitchLweOpPattern
auto allocateLweCtOp = rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, "allocate_lwe_ciphertext_u64", lweOutputType, allocLweCtOperands);
// allocate ksk
auto decompLevelCountOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(),
mlir::IntegerAttr::get(
mlir::IntegerType::get(rewriter.getContext(), 32),
op->getAttr("level").cast<mlir::IntegerAttr>().getInt()));
auto decompBaseLogOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(),
mlir::IntegerAttr::get(
mlir::IntegerType::get(rewriter.getContext(), 32),
op->getAttr("baseLog").cast<mlir::IntegerAttr>().getInt()));
auto inputLweSizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(),
mlir::IntegerAttr::get(
mlir::IntegerType::get(rewriter.getContext(), 32),
op->getAttr("inputLweSize").cast<mlir::IntegerAttr>().getInt()));
auto outputLweSizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(),
mlir::IntegerAttr::get(
mlir::IntegerType::get(rewriter.getContext(), 32),
op->getAttr("outputLweSize").cast<mlir::IntegerAttr>().getInt()));
mlir::SmallVector<mlir::Value> allockskOperands{
errOp, decompLevelCountOp, decompBaseLogOp, inputLweSizeOp,
outputLweSizeOp};
auto allocateKskOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "allocate_lwe_keyswitch_key_u64",
// auto decompLevelCountOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32),
// op->getAttr("level").cast<mlir::IntegerAttr>().getInt()));
// auto decompBaseLogOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32),
// op->getAttr("baseLog").cast<mlir::IntegerAttr>().getInt()));
// auto inputLweSizeOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32),
// op->getAttr("inputLweSize").cast<mlir::IntegerAttr>().getInt()));
// auto outputLweSizeOp = rewriter.create<mlir::ConstantOp>(
// op.getLoc(),
// mlir::IntegerAttr::get(
// mlir::IntegerType::get(rewriter.getContext(), 32),
// op->getAttr("outputLweSize").cast<mlir::IntegerAttr>().getInt()));
// mlir::SmallVector<mlir::Value> allockskOperands{
// errOp, decompLevelCountOp, decompBaseLogOp, inputLweSizeOp,
// outputLweSizeOp};
// auto allocateKskOp = rewriter.create<mlir::CallOp>(
// op.getLoc(), "allocate_lwe_keyswitch_key_u64",
// mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
// rewriter.getContext()),
// allockskOperands);
// get ksk
mlir::SmallVector<mlir::Value> getkskOperands{};
auto getKskOp = rewriter.create<mlir::CallOp>(
op.getLoc(), "getGlobalKeyswitchKey",
mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
rewriter.getContext()),
allockskOperands);
// bootstrap
getkskOperands);
// keyswitch
mlir::SmallVector<mlir::Value> keyswitchOperands{
errOp, allocateKskOp.getResult(0), allocateLweCtOp.getResult(0),
errOp, getKskOp.getResult(0), allocateLweCtOp.getResult(0),
op->getOperand(0)};
rewriter.create<mlir::CallOp>(op.getLoc(), "keyswitch_lwe_u64",
mlir::TypeRange({}), keyswitchOperands);