diff --git a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h index 86e7c3a96..9e58b4efe 100644 --- a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h +++ b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h @@ -185,10 +185,23 @@ mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc, // keyswitch auto ct_type = ct.getType().cast(); + mlir::SmallVector ksArgs{ct}; + mlir::SmallVector ksAttrs{ + mlir::NamedAttribute( + mlir::Identifier::get("inputLweSize", rewriter.getContext()), k), + // TODO: get the actual output size + mlir::NamedAttribute( + mlir::Identifier::get("outputLweSize", rewriter.getContext()), k), + mlir::NamedAttribute( + mlir::Identifier::get("level", rewriter.getContext()), levelKS), + mlir::NamedAttribute( + mlir::Identifier::get("baseLog", rewriter.getContext()), baseLogKS), + }; mlir::Value keyswitched = rewriter .create( - loc, convertTypeGLWEToLWE(rewriter.getContext(), ct_type), ct) + loc, convertTypeGLWEToLWE(rewriter.getContext(), ct_type), ksArgs, + ksAttrs) .result(); // convert result type diff --git a/compiler/include/zamalang/Dialect/LowLFHE/IR/LowLFHEOps.td b/compiler/include/zamalang/Dialect/LowLFHE/IR/LowLFHEOps.td index 7fb41332c..b808711b5 100644 --- a/compiler/include/zamalang/Dialect/LowLFHE/IR/LowLFHEOps.td +++ b/compiler/include/zamalang/Dialect/LowLFHE/IR/LowLFHEOps.td @@ -93,7 +93,11 @@ def SetPlaintextListElementOp : LowLFHE_Op<"set_plaintext_list_element">{ def KeySwitchLweOp : LowLFHE_Op<"keyswitch_lwe"> { let arguments = (ins // LweKeySwitchKeyType:$keyswitch_key, - LweCiphertextType:$ciphertext + LweCiphertextType:$ciphertext, + I32Attr:$inputLweSize, + I32Attr:$outputLweSize, + I32Attr:$level, + I32Attr:$baseLog ); let results = (outs LweCiphertextType:$result); } diff --git a/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp b/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp index 51b7efa6a..ae97f813a 100644 --- a/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp +++ b/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp @@ -512,7 +512,7 @@ struct LowLFHEKeySwitchLweOpPattern auto errType = mlir::MemRefType::get({}, mlir::IndexType::get(rewriter.getContext())); auto lweOperandType = op->getOperandTypes().front(); - // Insert forward declaration of the allocate_bsk_key function + // Insert forward declaration of the allocate_ksk_key function { auto funcType = mlir::FunctionType::get( rewriter.getContext(), @@ -578,7 +578,8 @@ struct LowLFHEKeySwitchLweOpPattern auto lweSizeOp = rewriter.create( op.getLoc(), mlir::IntegerAttr::get( - mlir::IntegerType::get(rewriter.getContext(), 32), -1)); + mlir::IntegerType::get(rewriter.getContext(), 32), + op->getAttr("outputLweSize").cast().getInt())); mlir::SmallVector allocLweCtOperands{errOp, lweSizeOp}; auto allocateLweCtOp = rewriter.replaceOpWithNewOp( op, "allocate_lwe_ciphertext_u64", lweOutputType, allocLweCtOperands); @@ -586,19 +587,23 @@ struct LowLFHEKeySwitchLweOpPattern auto decompLevelCountOp = rewriter.create( op.getLoc(), mlir::IntegerAttr::get( - mlir::IntegerType::get(rewriter.getContext(), 32), -1)); + mlir::IntegerType::get(rewriter.getContext(), 32), + op->getAttr("level").cast().getInt())); auto decompBaseLogOp = rewriter.create( op.getLoc(), mlir::IntegerAttr::get( - mlir::IntegerType::get(rewriter.getContext(), 32), -1)); + mlir::IntegerType::get(rewriter.getContext(), 32), + op->getAttr("baseLog").cast().getInt())); auto inputLweSizeOp = rewriter.create( op.getLoc(), mlir::IntegerAttr::get( - mlir::IntegerType::get(rewriter.getContext(), 32), -1)); + mlir::IntegerType::get(rewriter.getContext(), 32), + op->getAttr("inputLweSize").cast().getInt())); auto outputLweSizeOp = rewriter.create( op.getLoc(), mlir::IntegerAttr::get( - mlir::IntegerType::get(rewriter.getContext(), 32), -1)); + mlir::IntegerType::get(rewriter.getContext(), 32), + op->getAttr("outputLweSize").cast().getInt())); mlir::SmallVector allockskOperands{ errOp, decompLevelCountOp, decompBaseLogOp, inputLweSizeOp, outputLweSizeOp}; @@ -608,11 +613,11 @@ struct LowLFHEKeySwitchLweOpPattern rewriter.getContext()), allockskOperands); // bootstrap - mlir::SmallVector bootstrapOperands{ + mlir::SmallVector keyswitchOperands{ errOp, allocateKskOp.getResult(0), allocateLweCtOp.getResult(0), op->getOperand(0)}; rewriter.create(op.getLoc(), "keyswitch_lwe_u64", - mlir::TypeRange({}), bootstrapOperands); + mlir::TypeRange({}), keyswitchOperands); return mlir::success(); }; diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir index 68031d175..70c923083 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir @@ -9,13 +9,13 @@ func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe // CHECK-NEXT: %[[V0:.*]] = memref.alloca() : memref // CHECK-NEXT: %[[C0:.*]] = constant 1 : i32 // CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[V0]], %[[C0]]) : (memref, i32) -> !LowLFHE.lwe_ciphertext<1024,4> - // CHECK-NEXT: %[[C1:.*]] = constant -1 : i32 - // CHECK-NEXT: %[[C2:.*]] = constant -1 : i32 + // CHECK-NEXT: %[[C1:.*]] = constant 3 : i32 + // CHECK-NEXT: %[[C2:.*]] = constant 2 : i32 // CHECK-NEXT: %[[C3:.*]] = constant -1 : i32 // CHECK-NEXT: %[[C4:.*]] = constant 1024 : i32 // CHECK-NEXT: %[[V2:.*]] = call @allocate_lwe_bootstrap_key_u64(%0, %[[C1]], %[[C2]], %[[C3]], %[[C0]], %[[C4]]) : (memref, i32, i32, i32, i32, i32) -> !LowLFHE.lwe_bootstrap_key // CHECK-NEXT: call @bootstrap_lwe_u64(%[[V0]], %[[V2]], %[[V1]], %arg0, %arg1) : (memref, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> () // CHECK-NEXT: return %[[V1]] : !LowLFHE.lwe_ciphertext<1024,4> - %1 = "LowLFHE.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, k = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> + %1 = "LowLFHE.bootstrap_lwe"(%arg0, %arg1) {baseLog = 2 : i32, k = 1 : i32, level = 3 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> return %1: !LowLFHE.lwe_ciphertext<1024,4> } \ No newline at end of file diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir index f93d03c76..753ec67f6 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir @@ -7,15 +7,15 @@ // CHECK-LABEL: func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> { // CHECK-NEXT: %[[V0:.*]] = memref.alloca() : memref - // CHECK-NEXT: %[[C0:.*]] = constant -1 : i32 + // CHECK-NEXT: %[[C0:.*]] = constant 1 : i32 // CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[V0]], %[[C0]]) : (memref, i32) -> !LowLFHE.lwe_ciphertext<1024,4> - // CHECK-NEXT: %[[C1:.*]] = constant -1 : i32 - // CHECK-NEXT: %[[C2:.*]] = constant -1 : i32 - // CHECK-NEXT: %[[C3:.*]] = constant -1 : i32 - // CHECK-NEXT: %[[C4:.*]] = constant -1 : i32 + // CHECK-NEXT: %[[C1:.*]] = constant 3 : i32 + // CHECK-NEXT: %[[C2:.*]] = constant 2 : i32 + // CHECK-NEXT: %[[C3:.*]] = constant 1 : i32 + // CHECK-NEXT: %[[C4:.*]] = constant 1 : i32 // CHECK-NEXT: %[[V2:.*]] = call @allocate_lwe_keyswitch_key_u64(%0, %[[C1]], %[[C2]], %[[C3]], %[[C4]]) : (memref, i32, i32, i32, i32) -> !LowLFHE.lwe_key_switch_key // CHECK-NEXT: call @keyswitch_lwe_u64(%[[V0]], %[[V2]], %[[V1]], %arg0) : (memref, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>) -> () // CHECK-NEXT: return %[[V1]] : !LowLFHE.lwe_ciphertext<1024,4> - %1 = "LowLFHE.keyswitch_lwe"(%arg0) : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> + %1 = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 1 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> return %1: !LowLFHE.lwe_ciphertext<1024,4> } \ No newline at end of file diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir index 905098fa1..b222d8c1e 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir @@ -3,9 +3,9 @@ // CHECK-LABEL: func @apply_lookup_table(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: tensor<16xi4>) -> !LowLFHE.lwe_ciphertext<1024,4> func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi4>) -> !MidLFHE.glwe<{1024,1,64}{4}> { // CHECK-NEXT: %[[V1:.*]] = "LowLFHE.glwe_from_table"(%arg1) {k = 1 : i32, polynomialSize = 1024 : i32} : (tensor<16xi4>) -> !LowLFHE.glwe_ciphertext - // CHECK-NEXT: %[[V2:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> - // CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%1, %0) {baseLog = -1 : i32, k = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> + // CHECK-NEXT: %[[V2:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 1 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> + // CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, k = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> // CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext<1024,4> - %1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1){k=1:i32, polynomialSize=1024:i32, levelKS=-1:i32, baseLogKS=-1:i32, levelBS=-1:i32, baseLogBS=-1:i32}: (!MidLFHE.glwe<{1024,1,64}{4}>, tensor<16xi4>) -> (!MidLFHE.glwe<{1024,1,64}{4}>) + %1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1){k=1:i32, polynomialSize=1024:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32}: (!MidLFHE.glwe<{1024,1,64}{4}>, tensor<16xi4>) -> (!MidLFHE.glwe<{1024,1,64}{4}>) return %1: !MidLFHE.glwe<{1024,1,64}{4}> } \ No newline at end of file diff --git a/compiler/tests/Dialect/LowLFHE/ops.mlir b/compiler/tests/Dialect/LowLFHE/ops.mlir index 1a82ecca7..be4315c9d 100644 --- a/compiler/tests/Dialect/LowLFHE/ops.mlir +++ b/compiler/tests/Dialect/LowLFHE/ops.mlir @@ -102,10 +102,10 @@ func @set_plaintext_list_element(%arg0: !LowLFHE.plaintext_list, %arg1: index, % // CHECK-LABEL: func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> { - // CHECK-NEXT: %[[V1:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) : (!LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> + // CHECK-NEXT: %[[V1:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 1 : i32} : (!LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> // CHECK-NEXT: return %[[V1]] : !LowLFHE.lwe_ciphertext<2048,7> - %1 = "LowLFHE.keyswitch_lwe"(%arg0): (!LowLFHE.lwe_ciphertext<2048,7>) -> (!LowLFHE.lwe_ciphertext<2048,7>) + %1 = "LowLFHE.keyswitch_lwe"(%arg0){baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 1 : i32}: (!LowLFHE.lwe_ciphertext<2048,7>) -> (!LowLFHE.lwe_ciphertext<2048,7>) return %1: !LowLFHE.lwe_ciphertext<2048,7> }