mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
refactor: ins forward decl w generic ty @pass-init
Insert forward declarations with generic types at pass initialization. More docs for all the pass for lowering LUT
This commit is contained in:
committed by
Quentin Bourgerie
parent
d97512f507
commit
746d991af6
@@ -169,6 +169,47 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter rewriter,
|
||||
return op.getODSResults(0).front();
|
||||
}
|
||||
|
||||
// This is the rewritting of the HLFHE::ApplyLookupTable operation, it will be
|
||||
// rewritten as 3 new operations:
|
||||
// - Create the required GLWE ciphertext out of the plain lookup table
|
||||
// - Keyswitch the input ciphertext to match the input key of the bootstrapping
|
||||
// - Bootstrap the keyswitched ciphertext with the constructed GLWE ciphertext
|
||||
// Example:
|
||||
// from:
|
||||
// ```
|
||||
// "%result = MidLFHE.apply_lookup_table"(% arg0, % tlu){
|
||||
// k = 1 : i32,
|
||||
// polynomialSize = 2048 : i32,
|
||||
// levelKS = 3 : i32,
|
||||
// baseLogKS = 2 : i32,
|
||||
// levelBS = 5 : i32,
|
||||
// baseLogBS = 4 : i32,
|
||||
// outputSizeKS = 600 : i32
|
||||
// } : (!MidLFHE.glwe<{2048, 1, 64} {4}>, tensor<16xi4>)
|
||||
// ->(!MidLFHE.glwe<{2048, 1, 64} {4}>)
|
||||
// ```
|
||||
// to:
|
||||
// ```
|
||||
// % accumulator =
|
||||
// "LowLFHE.glwe_from_table"(
|
||||
// % [[TABLE]]){k = 1 : i32, p = 4 : i32, polynomialSize = 2048 : i32}
|
||||
// : (tensor<16xi4>)
|
||||
// ->!LowLFHE.glwe_ciphertext
|
||||
// % keyswitched = "LowLFHE.keyswitch_lwe"(% arg0){
|
||||
// baseLog = 2 : i32,
|
||||
// inputLweSize = 1 : i32,
|
||||
// level = 3 : i32,
|
||||
// outputLweSize = 600 : i32
|
||||
// } : (!LowLFHE.lwe_ciphertext<2048, 4>)
|
||||
// ->!LowLFHE.lwe_ciphertext<600, 4>
|
||||
// % result = "LowLFHE.bootstrap_lwe"(% keyswitched, % accumulator){
|
||||
// baseLog = 4 : i32,
|
||||
// k = 1 : i32,
|
||||
// level = 5 : i32,
|
||||
// polynomialSize = 2048 : i32
|
||||
// } : (!LowLFHE.lwe_ciphertext<600, 4>, !LowLFHE.glwe_ciphertext)
|
||||
// ->!LowLFHE.lwe_ciphertext<2048, 4>
|
||||
// ```
|
||||
mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc,
|
||||
mlir::Value ct, mlir::Value table, mlir::IntegerAttr k,
|
||||
mlir::IntegerAttr polynomialSize,
|
||||
@@ -178,10 +219,9 @@ mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc,
|
||||
// convert result type
|
||||
GLWECipherTextType glwe_type = result.getType().cast<GLWECipherTextType>();
|
||||
LweCiphertextType lwe_type =
|
||||
convertTypeGLWEToLWE(rewriter.getContext(), glwe_type);
|
||||
convertTypeToLWE(rewriter.getContext(), glwe_type);
|
||||
// fill the the table in the GLWE accumulator
|
||||
mlir::IntegerAttr precision = mlir::IntegerAttr::get(
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32), glwe_type.getP());
|
||||
mlir::IntegerAttr precision = rewriter.getI32IntegerAttr(glwe_type.getP());
|
||||
mlir::Value accumulator =
|
||||
rewriter
|
||||
.create<mlir::zamalang::LowLFHE::GlweFromTable>(
|
||||
@@ -191,8 +231,8 @@ mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc,
|
||||
|
||||
// keyswitch
|
||||
auto ct_type = ct.getType().cast<GLWECipherTextType>();
|
||||
mlir::SmallVector<mlir::Value, 1> ksArgs{ct};
|
||||
mlir::SmallVector<mlir::NamedAttribute, 6> ksAttrs{
|
||||
mlir::SmallVector<mlir::Value> ksArgs{ct};
|
||||
mlir::SmallVector<mlir::NamedAttribute> ksAttrs{
|
||||
mlir::NamedAttribute(
|
||||
mlir::Identifier::get("inputLweSize", rewriter.getContext()), k),
|
||||
mlir::NamedAttribute(
|
||||
@@ -203,16 +243,17 @@ mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc,
|
||||
mlir::NamedAttribute(
|
||||
mlir::Identifier::get("baseLog", rewriter.getContext()), baseLogKS),
|
||||
};
|
||||
auto ksOutType = LweCiphertextType::get(
|
||||
rewriter.getContext(), outputSizeKS.getInt(), ct_type.getP());
|
||||
mlir::Value keyswitched =
|
||||
rewriter
|
||||
.create<mlir::zamalang::LowLFHE::KeySwitchLweOp>(
|
||||
loc, convertTypeGLWEToLWE(rewriter.getContext(), ct_type), ksArgs,
|
||||
ksAttrs)
|
||||
.create<mlir::zamalang::LowLFHE::KeySwitchLweOp>(loc, ksOutType,
|
||||
ksArgs, ksAttrs)
|
||||
.result();
|
||||
|
||||
// bootstrap operation
|
||||
mlir::SmallVector<mlir::Value, 2> bsArgs{keyswitched, accumulator};
|
||||
mlir::SmallVector<mlir::NamedAttribute, 6> bsAttrs{
|
||||
mlir::SmallVector<mlir::Value> bsArgs{keyswitched, accumulator};
|
||||
mlir::SmallVector<mlir::NamedAttribute> bsAttrs{
|
||||
mlir::NamedAttribute(mlir::Identifier::get("k", rewriter.getContext()),
|
||||
k),
|
||||
mlir::NamedAttribute(
|
||||
|
||||
@@ -41,7 +41,7 @@ public:
|
||||
CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); }
|
||||
CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); }
|
||||
|
||||
void generateRuntimeContext() {
|
||||
void initGlobalRuntimeContext() {
|
||||
auto ksk = std::get<1>(this->keyswitchKeys["ksk_v0"]);
|
||||
auto bsk = std::get<1>(this->bootstrapKeys["bsk_v0"]);
|
||||
setGlobalRuntimeContext(createRuntimeContext(ksk, bsk));
|
||||
|
||||
Reference in New Issue
Block a user