mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(LowToCAPI): replace key alloc w getters from RT
This commit is contained in:
committed by
Quentin Bourgerie
parent
c6b1480cc6
commit
32d67726e2
@@ -226,8 +226,7 @@ struct GlweFromTableOpPattern
|
||||
matchAndRewrite(mlir::zamalang::LowLFHE::GlweFromTable op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
LowLFHEToConcreteCAPITypeConverter typeConverter;
|
||||
auto errType =
|
||||
mlir::MemRefType::get({}, mlir::IndexType::get(rewriter.getContext()));
|
||||
auto errType = mlir::IndexType::get(rewriter.getContext());
|
||||
// Insert forward declaration of the alloc_glwe function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
@@ -270,8 +269,8 @@ struct GlweFromTableOpPattern
|
||||
mlir::IntegerType::get(rewriter.getContext(), 64)},
|
||||
{mlir::zamalang::LowLFHE::ForeignPlaintextListType::get(
|
||||
rewriter.getContext())});
|
||||
if (insertForwardDeclaration(op, rewriter, "foreign_plaintext_list_u64",
|
||||
funcType)
|
||||
if (insertForwardDeclaration(
|
||||
op, rewriter, "runtime_foreign_plaintext_list_u64", funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
@@ -312,7 +311,8 @@ struct GlweFromTableOpPattern
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
auto errOp = rewriter.create<mlir::memref::AllocaOp>(op.getLoc(), errType);
|
||||
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
|
||||
rewriter.getIndexAttr(0));
|
||||
// allocate two glwe to build accumulator
|
||||
auto glweSizeOp =
|
||||
rewriter.create<mlir::ConstantOp>(op.getLoc(), op->getAttr("k"));
|
||||
@@ -351,7 +351,7 @@ struct GlweFromTableOpPattern
|
||||
mlir::SmallVector<mlir::Value> ForeignPlaintextListOperands{
|
||||
errOp, op->getOperand(0), sizeOp};
|
||||
auto foreignPlaintextListOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "foreign_plaintext_list_u64",
|
||||
op.getLoc(), "runtime_foreign_plaintext_list_u64",
|
||||
mlir::zamalang::LowLFHE::ForeignPlaintextListType::get(
|
||||
rewriter.getContext()),
|
||||
ForeignPlaintextListOperands);
|
||||
@@ -373,8 +373,6 @@ struct GlweFromTableOpPattern
|
||||
};
|
||||
};
|
||||
|
||||
// TODO:
|
||||
// Get concrete key
|
||||
struct LowLFHEBootstrapLweOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::BootstrapLweOp> {
|
||||
LowLFHEBootstrapLweOpPattern(mlir::MLIRContext *context,
|
||||
@@ -385,30 +383,43 @@ struct LowLFHEBootstrapLweOpPattern
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::zamalang::LowLFHE::BootstrapLweOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto errType =
|
||||
mlir::MemRefType::get({}, mlir::IndexType::get(rewriter.getContext()));
|
||||
auto errType = mlir::IndexType::get(rewriter.getContext());
|
||||
auto lweOperandType = op->getOperandTypes().front();
|
||||
// Insert forward declaration of the allocate_bsk_key function
|
||||
// {
|
||||
// auto funcType = mlir::FunctionType::get(
|
||||
// rewriter.getContext(),
|
||||
// {
|
||||
// errType,
|
||||
// // level
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// // baselog
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// // glwe size
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// // lwe size
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// // polynomial size
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// },
|
||||
// {mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
|
||||
// rewriter.getContext())});
|
||||
// if (insertForwardDeclaration(op, rewriter,
|
||||
// "allocate_lwe_bootstrap_key_u64",
|
||||
// funcType)
|
||||
// .failed()) {
|
||||
// return mlir::failure();
|
||||
// }
|
||||
// }
|
||||
|
||||
// Insert forward declaration of the getBsk function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{
|
||||
errType,
|
||||
// level
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// baselog
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// glwe size
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// lwe size
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// polynomial size
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
},
|
||||
rewriter.getContext(), {},
|
||||
{mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
|
||||
rewriter.getContext())});
|
||||
if (insertForwardDeclaration(op, rewriter,
|
||||
"allocate_lwe_bootstrap_key_u64", funcType)
|
||||
if (insertForwardDeclaration(op, rewriter, "getGlobalBootstrapKey",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
@@ -448,7 +459,8 @@ struct LowLFHEBootstrapLweOpPattern
|
||||
}
|
||||
}
|
||||
|
||||
auto errOp = rewriter.create<mlir::memref::AllocaOp>(op.getLoc(), errType);
|
||||
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
|
||||
rewriter.getIndexAttr(0));
|
||||
// allocate the result lwe ciphertext
|
||||
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(), mlir::IntegerAttr::get(
|
||||
@@ -458,36 +470,44 @@ struct LowLFHEBootstrapLweOpPattern
|
||||
auto allocateLweCtOp = rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
||||
op, "allocate_lwe_ciphertext_u64", lweOperandType, allocLweCtOperands);
|
||||
// allocate bsk
|
||||
auto decompLevelCountOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(),
|
||||
mlir::IntegerAttr::get(
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
op->getAttr("level").cast<mlir::IntegerAttr>().getInt()));
|
||||
auto decompBaseLogOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(),
|
||||
mlir::IntegerAttr::get(
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
op->getAttr("baseLog").cast<mlir::IntegerAttr>().getInt()));
|
||||
auto glweSizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(),
|
||||
mlir::IntegerAttr::get(
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32), -1));
|
||||
auto polySizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(),
|
||||
mlir::IntegerAttr::get(
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
op->getAttr("polynomialSize").cast<mlir::IntegerAttr>().getInt()));
|
||||
mlir::SmallVector<mlir::Value> allocBskOperands{
|
||||
errOp, decompLevelCountOp, decompBaseLogOp,
|
||||
glweSizeOp, lweSizeOp, polySizeOp};
|
||||
auto allocateBskOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "allocate_lwe_bootstrap_key_u64",
|
||||
// auto decompLevelCountOp = rewriter.create<mlir::ConstantOp>(
|
||||
// op.getLoc(),
|
||||
// mlir::IntegerAttr::get(
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// op->getAttr("level").cast<mlir::IntegerAttr>().getInt()));
|
||||
// auto decompBaseLogOp = rewriter.create<mlir::ConstantOp>(
|
||||
// op.getLoc(),
|
||||
// mlir::IntegerAttr::get(
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// op->getAttr("baseLog").cast<mlir::IntegerAttr>().getInt()));
|
||||
// auto glweSizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
// op.getLoc(),
|
||||
// mlir::IntegerAttr::get(
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32), -1));
|
||||
// auto polySizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
// op.getLoc(),
|
||||
// mlir::IntegerAttr::get(
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// op->getAttr("polynomialSize").cast<mlir::IntegerAttr>().getInt()));
|
||||
// mlir::SmallVector<mlir::Value> allocBskOperands{
|
||||
// errOp, decompLevelCountOp, decompBaseLogOp,
|
||||
// glweSizeOp, lweSizeOp, polySizeOp};
|
||||
// auto allocateBskOp = rewriter.create<mlir::CallOp>(
|
||||
// op.getLoc(), "allocate_lwe_bootstrap_key_u64",
|
||||
// mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
|
||||
// rewriter.getContext()),
|
||||
// allocBskOperands);
|
||||
|
||||
// get bsk
|
||||
mlir::SmallVector<mlir::Value> getBskOperands{};
|
||||
auto getBskOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "getGlobalBootstrapKey",
|
||||
mlir::zamalang::LowLFHE::LweBootstrapKeyType::get(
|
||||
rewriter.getContext()),
|
||||
allocBskOperands);
|
||||
getBskOperands);
|
||||
// bootstrap
|
||||
mlir::SmallVector<mlir::Value> bootstrapOperands{
|
||||
errOp, allocateBskOp.getResult(0), allocateLweCtOp.getResult(0),
|
||||
errOp, getBskOp.getResult(0), allocateLweCtOp.getResult(0),
|
||||
op->getOperand(0), op->getOperand(1)};
|
||||
rewriter.create<mlir::CallOp>(op.getLoc(), "bootstrap_lwe_u64",
|
||||
mlir::TypeRange({}), bootstrapOperands);
|
||||
@@ -496,9 +516,6 @@ struct LowLFHEBootstrapLweOpPattern
|
||||
};
|
||||
};
|
||||
|
||||
// TODO:
|
||||
// Parameterization
|
||||
// Get concrete key
|
||||
struct LowLFHEKeySwitchLweOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::zamalang::LowLFHE::KeySwitchLweOp> {
|
||||
LowLFHEKeySwitchLweOpPattern(mlir::MLIRContext *context,
|
||||
@@ -509,28 +526,41 @@ struct LowLFHEKeySwitchLweOpPattern
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::zamalang::LowLFHE::KeySwitchLweOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto errType =
|
||||
mlir::MemRefType::get({}, mlir::IndexType::get(rewriter.getContext()));
|
||||
auto errType = mlir::IndexType::get(rewriter.getContext());
|
||||
auto lweOperandType = op->getOperandTypes().front();
|
||||
// Insert forward declaration of the allocate_ksk_key function
|
||||
// {
|
||||
// auto funcType = mlir::FunctionType::get(
|
||||
// rewriter.getContext(),
|
||||
// {
|
||||
// errType,
|
||||
// // level
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// // baselog
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// // input lwe size
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// // output lwe size
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// },
|
||||
// {mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
|
||||
// rewriter.getContext())});
|
||||
// if (insertForwardDeclaration(op, rewriter,
|
||||
// "allocate_lwe_keyswitch_key_u64",
|
||||
// funcType)
|
||||
// .failed()) {
|
||||
// return mlir::failure();
|
||||
// }
|
||||
// }
|
||||
|
||||
// Insert forward declaration of the getKsk function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{
|
||||
errType,
|
||||
// level
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// baselog
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// input lwe size
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// output lwe size
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
},
|
||||
rewriter.getContext(), {},
|
||||
{mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
|
||||
rewriter.getContext())});
|
||||
if (insertForwardDeclaration(op, rewriter,
|
||||
"allocate_lwe_keyswitch_key_u64", funcType)
|
||||
if (insertForwardDeclaration(op, rewriter, "getGlobalKeyswitchKey",
|
||||
funcType)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
@@ -573,7 +603,8 @@ struct LowLFHEKeySwitchLweOpPattern
|
||||
}
|
||||
}
|
||||
|
||||
auto errOp = rewriter.create<mlir::memref::AllocaOp>(op.getLoc(), errType);
|
||||
auto errOp = rewriter.create<mlir::ConstantOp>(op.getLoc(),
|
||||
rewriter.getIndexAttr(0));
|
||||
// allocate the result lwe ciphertext
|
||||
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(),
|
||||
@@ -584,37 +615,46 @@ struct LowLFHEKeySwitchLweOpPattern
|
||||
auto allocateLweCtOp = rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
||||
op, "allocate_lwe_ciphertext_u64", lweOutputType, allocLweCtOperands);
|
||||
// allocate ksk
|
||||
auto decompLevelCountOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(),
|
||||
mlir::IntegerAttr::get(
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
op->getAttr("level").cast<mlir::IntegerAttr>().getInt()));
|
||||
auto decompBaseLogOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(),
|
||||
mlir::IntegerAttr::get(
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
op->getAttr("baseLog").cast<mlir::IntegerAttr>().getInt()));
|
||||
auto inputLweSizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(),
|
||||
mlir::IntegerAttr::get(
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
op->getAttr("inputLweSize").cast<mlir::IntegerAttr>().getInt()));
|
||||
auto outputLweSizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(),
|
||||
mlir::IntegerAttr::get(
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
op->getAttr("outputLweSize").cast<mlir::IntegerAttr>().getInt()));
|
||||
mlir::SmallVector<mlir::Value> allockskOperands{
|
||||
errOp, decompLevelCountOp, decompBaseLogOp, inputLweSizeOp,
|
||||
outputLweSizeOp};
|
||||
auto allocateKskOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "allocate_lwe_keyswitch_key_u64",
|
||||
// auto decompLevelCountOp = rewriter.create<mlir::ConstantOp>(
|
||||
// op.getLoc(),
|
||||
// mlir::IntegerAttr::get(
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// op->getAttr("level").cast<mlir::IntegerAttr>().getInt()));
|
||||
// auto decompBaseLogOp = rewriter.create<mlir::ConstantOp>(
|
||||
// op.getLoc(),
|
||||
// mlir::IntegerAttr::get(
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// op->getAttr("baseLog").cast<mlir::IntegerAttr>().getInt()));
|
||||
// auto inputLweSizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
// op.getLoc(),
|
||||
// mlir::IntegerAttr::get(
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// op->getAttr("inputLweSize").cast<mlir::IntegerAttr>().getInt()));
|
||||
// auto outputLweSizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
// op.getLoc(),
|
||||
// mlir::IntegerAttr::get(
|
||||
// mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
// op->getAttr("outputLweSize").cast<mlir::IntegerAttr>().getInt()));
|
||||
// mlir::SmallVector<mlir::Value> allockskOperands{
|
||||
// errOp, decompLevelCountOp, decompBaseLogOp, inputLweSizeOp,
|
||||
// outputLweSizeOp};
|
||||
// auto allocateKskOp = rewriter.create<mlir::CallOp>(
|
||||
// op.getLoc(), "allocate_lwe_keyswitch_key_u64",
|
||||
// mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
|
||||
// rewriter.getContext()),
|
||||
// allockskOperands);
|
||||
|
||||
// get ksk
|
||||
mlir::SmallVector<mlir::Value> getkskOperands{};
|
||||
auto getKskOp = rewriter.create<mlir::CallOp>(
|
||||
op.getLoc(), "getGlobalKeyswitchKey",
|
||||
mlir::zamalang::LowLFHE::LweKeySwitchKeyType::get(
|
||||
rewriter.getContext()),
|
||||
allockskOperands);
|
||||
// bootstrap
|
||||
getkskOperands);
|
||||
|
||||
// keyswitch
|
||||
mlir::SmallVector<mlir::Value> keyswitchOperands{
|
||||
errOp, allocateKskOp.getResult(0), allocateLweCtOp.getResult(0),
|
||||
errOp, getKskOp.getResult(0), allocateLweCtOp.getResult(0),
|
||||
op->getOperand(0)};
|
||||
rewriter.create<mlir::CallOp>(op.getLoc(), "keyswitch_lwe_u64",
|
||||
mlir::TypeRange({}), keyswitchOperands);
|
||||
|
||||
Reference in New Issue
Block a user