feat: parameterize KS operation

This commit is contained in:
youben11
2021-08-25 14:59:53 +01:00
committed by Quentin Bourgerie
parent 14f171bef9
commit 6e2ac3af4e
7 changed files with 46 additions and 24 deletions

View File

@@ -185,10 +185,23 @@ mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc,
// keyswitch
auto ct_type = ct.getType().cast<GLWECipherTextType>();
mlir::SmallVector<mlir::Value, 1> ksArgs{ct};
mlir::SmallVector<mlir::NamedAttribute, 6> ksAttrs{
mlir::NamedAttribute(
mlir::Identifier::get("inputLweSize", rewriter.getContext()), k),
// TODO: get the actual output size
mlir::NamedAttribute(
mlir::Identifier::get("outputLweSize", rewriter.getContext()), k),
mlir::NamedAttribute(
mlir::Identifier::get("level", rewriter.getContext()), levelKS),
mlir::NamedAttribute(
mlir::Identifier::get("baseLog", rewriter.getContext()), baseLogKS),
};
mlir::Value keyswitched =
rewriter
.create<mlir::zamalang::LowLFHE::KeySwitchLweOp>(
loc, convertTypeGLWEToLWE(rewriter.getContext(), ct_type), ct)
loc, convertTypeGLWEToLWE(rewriter.getContext(), ct_type), ksArgs,
ksAttrs)
.result();
// convert result type

View File

@@ -93,7 +93,11 @@ def SetPlaintextListElementOp : LowLFHE_Op<"set_plaintext_list_element">{
def KeySwitchLweOp : LowLFHE_Op<"keyswitch_lwe"> {
let arguments = (ins
// LweKeySwitchKeyType:$keyswitch_key,
LweCiphertextType:$ciphertext
LweCiphertextType:$ciphertext,
I32Attr:$inputLweSize,
I32Attr:$outputLweSize,
I32Attr:$level,
I32Attr:$baseLog
);
let results = (outs LweCiphertextType:$result);
}

View File

@@ -512,7 +512,7 @@ struct LowLFHEKeySwitchLweOpPattern
auto errType =
mlir::MemRefType::get({}, mlir::IndexType::get(rewriter.getContext()));
auto lweOperandType = op->getOperandTypes().front();
// Insert forward declaration of the allocate_bsk_key function
// Insert forward declaration of the allocate_ksk_key function
{
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
@@ -578,7 +578,8 @@ struct LowLFHEKeySwitchLweOpPattern
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(),
mlir::IntegerAttr::get(
mlir::IntegerType::get(rewriter.getContext(), 32), -1));
mlir::IntegerType::get(rewriter.getContext(), 32),
op->getAttr("outputLweSize").cast<mlir::IntegerAttr>().getInt()));
mlir::SmallVector<mlir::Value> allocLweCtOperands{errOp, lweSizeOp};
auto allocateLweCtOp = rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, "allocate_lwe_ciphertext_u64", lweOutputType, allocLweCtOperands);
@@ -586,19 +587,23 @@ struct LowLFHEKeySwitchLweOpPattern
auto decompLevelCountOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(),
mlir::IntegerAttr::get(
mlir::IntegerType::get(rewriter.getContext(), 32), -1));
mlir::IntegerType::get(rewriter.getContext(), 32),
op->getAttr("level").cast<mlir::IntegerAttr>().getInt()));
auto decompBaseLogOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(),
mlir::IntegerAttr::get(
mlir::IntegerType::get(rewriter.getContext(), 32), -1));
mlir::IntegerType::get(rewriter.getContext(), 32),
op->getAttr("baseLog").cast<mlir::IntegerAttr>().getInt()));
auto inputLweSizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(),
mlir::IntegerAttr::get(
mlir::IntegerType::get(rewriter.getContext(), 32), -1));
mlir::IntegerType::get(rewriter.getContext(), 32),
op->getAttr("inputLweSize").cast<mlir::IntegerAttr>().getInt()));
auto outputLweSizeOp = rewriter.create<mlir::ConstantOp>(
op.getLoc(),
mlir::IntegerAttr::get(
mlir::IntegerType::get(rewriter.getContext(), 32), -1));
mlir::IntegerType::get(rewriter.getContext(), 32),
op->getAttr("outputLweSize").cast<mlir::IntegerAttr>().getInt()));
mlir::SmallVector<mlir::Value> allockskOperands{
errOp, decompLevelCountOp, decompBaseLogOp, inputLweSizeOp,
outputLweSizeOp};
@@ -608,11 +613,11 @@ struct LowLFHEKeySwitchLweOpPattern
rewriter.getContext()),
allockskOperands);
// bootstrap
mlir::SmallVector<mlir::Value> bootstrapOperands{
mlir::SmallVector<mlir::Value> keyswitchOperands{
errOp, allocateKskOp.getResult(0), allocateLweCtOp.getResult(0),
op->getOperand(0)};
rewriter.create<mlir::CallOp>(op.getLoc(), "keyswitch_lwe_u64",
mlir::TypeRange({}), bootstrapOperands);
mlir::TypeRange({}), keyswitchOperands);
return mlir::success();
};

View File

@@ -9,13 +9,13 @@ func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe
// CHECK-NEXT: %[[V0:.*]] = memref.alloca() : memref<index>
// CHECK-NEXT: %[[C0:.*]] = constant 1 : i32
// CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[V0]], %[[C0]]) : (memref<index>, i32) -> !LowLFHE.lwe_ciphertext<1024,4>
// CHECK-NEXT: %[[C1:.*]] = constant -1 : i32
// CHECK-NEXT: %[[C2:.*]] = constant -1 : i32
// CHECK-NEXT: %[[C1:.*]] = constant 3 : i32
// CHECK-NEXT: %[[C2:.*]] = constant 2 : i32
// CHECK-NEXT: %[[C3:.*]] = constant -1 : i32
// CHECK-NEXT: %[[C4:.*]] = constant 1024 : i32
// CHECK-NEXT: %[[V2:.*]] = call @allocate_lwe_bootstrap_key_u64(%0, %[[C1]], %[[C2]], %[[C3]], %[[C0]], %[[C4]]) : (memref<index>, i32, i32, i32, i32, i32) -> !LowLFHE.lwe_bootstrap_key
// CHECK-NEXT: call @bootstrap_lwe_u64(%[[V0]], %[[V2]], %[[V1]], %arg0, %arg1) : (memref<index>, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> ()
// CHECK-NEXT: return %[[V1]] : !LowLFHE.lwe_ciphertext<1024,4>
%1 = "LowLFHE.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, k = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4>
%1 = "LowLFHE.bootstrap_lwe"(%arg0, %arg1) {baseLog = 2 : i32, k = 1 : i32, level = 3 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4>
return %1: !LowLFHE.lwe_ciphertext<1024,4>
}

View File

@@ -7,15 +7,15 @@
// CHECK-LABEL: func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4>
func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> {
// CHECK-NEXT: %[[V0:.*]] = memref.alloca() : memref<index>
// CHECK-NEXT: %[[C0:.*]] = constant -1 : i32
// CHECK-NEXT: %[[C0:.*]] = constant 1 : i32
// CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[V0]], %[[C0]]) : (memref<index>, i32) -> !LowLFHE.lwe_ciphertext<1024,4>
// CHECK-NEXT: %[[C1:.*]] = constant -1 : i32
// CHECK-NEXT: %[[C2:.*]] = constant -1 : i32
// CHECK-NEXT: %[[C3:.*]] = constant -1 : i32
// CHECK-NEXT: %[[C4:.*]] = constant -1 : i32
// CHECK-NEXT: %[[C1:.*]] = constant 3 : i32
// CHECK-NEXT: %[[C2:.*]] = constant 2 : i32
// CHECK-NEXT: %[[C3:.*]] = constant 1 : i32
// CHECK-NEXT: %[[C4:.*]] = constant 1 : i32
// CHECK-NEXT: %[[V2:.*]] = call @allocate_lwe_keyswitch_key_u64(%0, %[[C1]], %[[C2]], %[[C3]], %[[C4]]) : (memref<index>, i32, i32, i32, i32) -> !LowLFHE.lwe_key_switch_key
// CHECK-NEXT: call @keyswitch_lwe_u64(%[[V0]], %[[V2]], %[[V1]], %arg0) : (memref<index>, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>) -> ()
// CHECK-NEXT: return %[[V1]] : !LowLFHE.lwe_ciphertext<1024,4>
%1 = "LowLFHE.keyswitch_lwe"(%arg0) : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4>
%1 = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 1 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4>
return %1: !LowLFHE.lwe_ciphertext<1024,4>
}

View File

@@ -3,9 +3,9 @@
// CHECK-LABEL: func @apply_lookup_table(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: tensor<16xi4>) -> !LowLFHE.lwe_ciphertext<1024,4>
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi4>) -> !MidLFHE.glwe<{1024,1,64}{4}> {
// CHECK-NEXT: %[[V1:.*]] = "LowLFHE.glwe_from_table"(%arg1) {k = 1 : i32, polynomialSize = 1024 : i32} : (tensor<16xi4>) -> !LowLFHE.glwe_ciphertext
// CHECK-NEXT: %[[V2:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4>
// CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%1, %0) {baseLog = -1 : i32, k = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4>
// CHECK-NEXT: %[[V2:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 1 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4>
// CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, k = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4>
// CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext<1024,4>
%1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1){k=1:i32, polynomialSize=1024:i32, levelKS=-1:i32, baseLogKS=-1:i32, levelBS=-1:i32, baseLogBS=-1:i32}: (!MidLFHE.glwe<{1024,1,64}{4}>, tensor<16xi4>) -> (!MidLFHE.glwe<{1024,1,64}{4}>)
%1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1){k=1:i32, polynomialSize=1024:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32}: (!MidLFHE.glwe<{1024,1,64}{4}>, tensor<16xi4>) -> (!MidLFHE.glwe<{1024,1,64}{4}>)
return %1: !MidLFHE.glwe<{1024,1,64}{4}>
}

View File

@@ -102,10 +102,10 @@ func @set_plaintext_list_element(%arg0: !LowLFHE.plaintext_list, %arg1: index, %
// CHECK-LABEL: func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7>
func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> {
// CHECK-NEXT: %[[V1:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) : (!LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7>
// CHECK-NEXT: %[[V1:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 1 : i32} : (!LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7>
// CHECK-NEXT: return %[[V1]] : !LowLFHE.lwe_ciphertext<2048,7>
%1 = "LowLFHE.keyswitch_lwe"(%arg0): (!LowLFHE.lwe_ciphertext<2048,7>) -> (!LowLFHE.lwe_ciphertext<2048,7>)
%1 = "LowLFHE.keyswitch_lwe"(%arg0){baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 1 : i32}: (!LowLFHE.lwe_ciphertext<2048,7>) -> (!LowLFHE.lwe_ciphertext<2048,7>)
return %1: !LowLFHE.lwe_ciphertext<2048,7>
}