refactor: ins forward decl w generic ty @pass-init

Insert forward declarations with generic types at pass initialization.
More docs for all the pass for lowering LUT
This commit is contained in:
youben11
2021-09-08 15:04:10 +01:00
committed by Quentin Bourgerie
parent d97512f507
commit 746d991af6
13 changed files with 513 additions and 414 deletions

View File

@@ -169,6 +169,47 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter rewriter,
return op.getODSResults(0).front();
}
// This is the rewritting of the HLFHE::ApplyLookupTable operation, it will be
// rewritten as 3 new operations:
// - Create the required GLWE ciphertext out of the plain lookup table
// - Keyswitch the input ciphertext to match the input key of the bootstrapping
// - Bootstrap the keyswitched ciphertext with the constructed GLWE ciphertext
// Example:
// from:
// ```
// "%result = MidLFHE.apply_lookup_table"(% arg0, % tlu){
// k = 1 : i32,
// polynomialSize = 2048 : i32,
// levelKS = 3 : i32,
// baseLogKS = 2 : i32,
// levelBS = 5 : i32,
// baseLogBS = 4 : i32,
// outputSizeKS = 600 : i32
// } : (!MidLFHE.glwe<{2048, 1, 64} {4}>, tensor<16xi4>)
// ->(!MidLFHE.glwe<{2048, 1, 64} {4}>)
// ```
// to:
// ```
// % accumulator =
// "LowLFHE.glwe_from_table"(
// % [[TABLE]]){k = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32}
// : (tensor<16xi4>)
// ->!LowLFHE.glwe_ciphertext
// % keyswitched = "LowLFHE.keyswitch_lwe"(% arg0){
// baseLog = 2 : i32,
// inputLweSize = 1 : i32,
// level = 3 : i32,
// outputLweSize = 600 : i32
// } : (!LowLFHE.lwe_ciphertext<2048, 4>)
// ->!LowLFHE.lwe_ciphertext<600, 4>
// % result = "LowLFHE.bootstrap_lwe"(% keyswitched, % accumulator){
// baseLog = 4 : i32,
// k = 1 : i32,
// level = 5 : i32,
// polynomialSize = 2048 : i32
// } : (!LowLFHE.lwe_ciphertext<600, 4>, !LowLFHE.glwe_ciphertext)
// ->!LowLFHE.lwe_ciphertext<2048, 4>
// ```
mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc,
mlir::Value ct, mlir::Value table, mlir::IntegerAttr k,
mlir::IntegerAttr polynomialSize,
@@ -178,10 +219,9 @@ mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc,
// convert result type
GLWECipherTextType glwe_type = result.getType().cast<GLWECipherTextType>();
LweCiphertextType lwe_type =
convertTypeGLWEToLWE(rewriter.getContext(), glwe_type);
convertTypeToLWE(rewriter.getContext(), glwe_type);
// fill the the table in the GLWE accumulator
mlir::IntegerAttr precision = mlir::IntegerAttr::get(
mlir::IntegerType::get(rewriter.getContext(), 32), glwe_type.getP());
mlir::IntegerAttr precision = rewriter.getI32IntegerAttr(glwe_type.getP());
mlir::Value accumulator =
rewriter
.create<mlir::zamalang::LowLFHE::GlweFromTable>(
@@ -191,8 +231,8 @@ mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc,
// keyswitch
auto ct_type = ct.getType().cast<GLWECipherTextType>();
mlir::SmallVector<mlir::Value, 1> ksArgs{ct};
mlir::SmallVector<mlir::NamedAttribute, 6> ksAttrs{
mlir::SmallVector<mlir::Value> ksArgs{ct};
mlir::SmallVector<mlir::NamedAttribute> ksAttrs{
mlir::NamedAttribute(
mlir::Identifier::get("inputLweSize", rewriter.getContext()), k),
mlir::NamedAttribute(
@@ -203,16 +243,17 @@ mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc,
mlir::NamedAttribute(
mlir::Identifier::get("baseLog", rewriter.getContext()), baseLogKS),
};
auto ksOutType = LweCiphertextType::get(
rewriter.getContext(), outputSizeKS.getInt(), ct_type.getP());
mlir::Value keyswitched =
rewriter
.create<mlir::zamalang::LowLFHE::KeySwitchLweOp>(
loc, convertTypeGLWEToLWE(rewriter.getContext(), ct_type), ksArgs,
ksAttrs)
.create<mlir::zamalang::LowLFHE::KeySwitchLweOp>(loc, ksOutType,
ksArgs, ksAttrs)
.result();
// bootstrap operation
mlir::SmallVector<mlir::Value, 2> bsArgs{keyswitched, accumulator};
mlir::SmallVector<mlir::NamedAttribute, 6> bsAttrs{
mlir::SmallVector<mlir::Value> bsArgs{keyswitched, accumulator};
mlir::SmallVector<mlir::NamedAttribute> bsAttrs{
mlir::NamedAttribute(mlir::Identifier::get("k", rewriter.getContext()),
k),
mlir::NamedAttribute(

View File

@@ -41,7 +41,7 @@ public:
CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); }
CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); }
void generateRuntimeContext() {
void initGlobalRuntimeContext() {
auto ksk = std::get<1>(this->keyswitchKeys["ksk_v0"]);
auto bsk = std::get<1>(this->bootstrapKeys["bsk_v0"]);
setGlobalRuntimeContext(createRuntimeContext(ksk, bsk));