mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat: parameterize KS operation
This commit is contained in:
committed by
Quentin Bourgerie
parent
14f171bef9
commit
6e2ac3af4e
@@ -185,10 +185,23 @@ mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc,
|
||||
|
||||
// keyswitch
|
||||
auto ct_type = ct.getType().cast<GLWECipherTextType>();
|
||||
mlir::SmallVector<mlir::Value, 1> ksArgs{ct};
|
||||
mlir::SmallVector<mlir::NamedAttribute, 6> ksAttrs{
|
||||
mlir::NamedAttribute(
|
||||
mlir::Identifier::get("inputLweSize", rewriter.getContext()), k),
|
||||
// TODO: get the actual output size
|
||||
mlir::NamedAttribute(
|
||||
mlir::Identifier::get("outputLweSize", rewriter.getContext()), k),
|
||||
mlir::NamedAttribute(
|
||||
mlir::Identifier::get("level", rewriter.getContext()), levelKS),
|
||||
mlir::NamedAttribute(
|
||||
mlir::Identifier::get("baseLog", rewriter.getContext()), baseLogKS),
|
||||
};
|
||||
mlir::Value keyswitched =
|
||||
rewriter
|
||||
.create<mlir::zamalang::LowLFHE::KeySwitchLweOp>(
|
||||
loc, convertTypeGLWEToLWE(rewriter.getContext(), ct_type), ct)
|
||||
loc, convertTypeGLWEToLWE(rewriter.getContext(), ct_type), ksArgs,
|
||||
ksAttrs)
|
||||
.result();
|
||||
|
||||
// convert result type
|
||||
|
||||
@@ -93,7 +93,11 @@ def SetPlaintextListElementOp : LowLFHE_Op<"set_plaintext_list_element">{
|
||||
def KeySwitchLweOp : LowLFHE_Op<"keyswitch_lwe"> {
|
||||
let arguments = (ins
|
||||
// LweKeySwitchKeyType:$keyswitch_key,
|
||||
LweCiphertextType:$ciphertext
|
||||
LweCiphertextType:$ciphertext,
|
||||
I32Attr:$inputLweSize,
|
||||
I32Attr:$outputLweSize,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog
|
||||
);
|
||||
let results = (outs LweCiphertextType:$result);
|
||||
}
|
||||
|
||||
@@ -512,7 +512,7 @@ struct LowLFHEKeySwitchLweOpPattern
|
||||
auto errType =
|
||||
mlir::MemRefType::get({}, mlir::IndexType::get(rewriter.getContext()));
|
||||
auto lweOperandType = op->getOperandTypes().front();
|
||||
// Insert forward declaration of the allocate_bsk_key function
|
||||
// Insert forward declaration of the allocate_ksk_key function
|
||||
{
|
||||
auto funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
@@ -578,7 +578,8 @@ struct LowLFHEKeySwitchLweOpPattern
|
||||
auto lweSizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(),
|
||||
mlir::IntegerAttr::get(
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32), -1));
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
op->getAttr("outputLweSize").cast<mlir::IntegerAttr>().getInt()));
|
||||
mlir::SmallVector<mlir::Value> allocLweCtOperands{errOp, lweSizeOp};
|
||||
auto allocateLweCtOp = rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
||||
op, "allocate_lwe_ciphertext_u64", lweOutputType, allocLweCtOperands);
|
||||
@@ -586,19 +587,23 @@ struct LowLFHEKeySwitchLweOpPattern
|
||||
auto decompLevelCountOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(),
|
||||
mlir::IntegerAttr::get(
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32), -1));
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
op->getAttr("level").cast<mlir::IntegerAttr>().getInt()));
|
||||
auto decompBaseLogOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(),
|
||||
mlir::IntegerAttr::get(
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32), -1));
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
op->getAttr("baseLog").cast<mlir::IntegerAttr>().getInt()));
|
||||
auto inputLweSizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(),
|
||||
mlir::IntegerAttr::get(
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32), -1));
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
op->getAttr("inputLweSize").cast<mlir::IntegerAttr>().getInt()));
|
||||
auto outputLweSizeOp = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(),
|
||||
mlir::IntegerAttr::get(
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32), -1));
|
||||
mlir::IntegerType::get(rewriter.getContext(), 32),
|
||||
op->getAttr("outputLweSize").cast<mlir::IntegerAttr>().getInt()));
|
||||
mlir::SmallVector<mlir::Value> allockskOperands{
|
||||
errOp, decompLevelCountOp, decompBaseLogOp, inputLweSizeOp,
|
||||
outputLweSizeOp};
|
||||
@@ -608,11 +613,11 @@ struct LowLFHEKeySwitchLweOpPattern
|
||||
rewriter.getContext()),
|
||||
allockskOperands);
|
||||
// bootstrap
|
||||
mlir::SmallVector<mlir::Value> bootstrapOperands{
|
||||
mlir::SmallVector<mlir::Value> keyswitchOperands{
|
||||
errOp, allocateKskOp.getResult(0), allocateLweCtOp.getResult(0),
|
||||
op->getOperand(0)};
|
||||
rewriter.create<mlir::CallOp>(op.getLoc(), "keyswitch_lwe_u64",
|
||||
mlir::TypeRange({}), bootstrapOperands);
|
||||
mlir::TypeRange({}), keyswitchOperands);
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
|
||||
@@ -9,13 +9,13 @@ func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe
|
||||
// CHECK-NEXT: %[[V0:.*]] = memref.alloca() : memref<index>
|
||||
// CHECK-NEXT: %[[C0:.*]] = constant 1 : i32
|
||||
// CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[V0]], %[[C0]]) : (memref<index>, i32) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: %[[C1:.*]] = constant -1 : i32
|
||||
// CHECK-NEXT: %[[C2:.*]] = constant -1 : i32
|
||||
// CHECK-NEXT: %[[C1:.*]] = constant 3 : i32
|
||||
// CHECK-NEXT: %[[C2:.*]] = constant 2 : i32
|
||||
// CHECK-NEXT: %[[C3:.*]] = constant -1 : i32
|
||||
// CHECK-NEXT: %[[C4:.*]] = constant 1024 : i32
|
||||
// CHECK-NEXT: %[[V2:.*]] = call @allocate_lwe_bootstrap_key_u64(%0, %[[C1]], %[[C2]], %[[C3]], %[[C0]], %[[C4]]) : (memref<index>, i32, i32, i32, i32, i32) -> !LowLFHE.lwe_bootstrap_key
|
||||
// CHECK-NEXT: call @bootstrap_lwe_u64(%[[V0]], %[[V2]], %[[V1]], %arg0, %arg1) : (memref<index>, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> ()
|
||||
// CHECK-NEXT: return %[[V1]] : !LowLFHE.lwe_ciphertext<1024,4>
|
||||
%1 = "LowLFHE.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, k = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
%1 = "LowLFHE.bootstrap_lwe"(%arg0, %arg1) {baseLog = 2 : i32, k = 1 : i32, level = 3 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
return %1: !LowLFHE.lwe_ciphertext<1024,4>
|
||||
}
|
||||
@@ -7,15 +7,15 @@
|
||||
// CHECK-LABEL: func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = memref.alloca() : memref<index>
|
||||
// CHECK-NEXT: %[[C0:.*]] = constant -1 : i32
|
||||
// CHECK-NEXT: %[[C0:.*]] = constant 1 : i32
|
||||
// CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[V0]], %[[C0]]) : (memref<index>, i32) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: %[[C1:.*]] = constant -1 : i32
|
||||
// CHECK-NEXT: %[[C2:.*]] = constant -1 : i32
|
||||
// CHECK-NEXT: %[[C3:.*]] = constant -1 : i32
|
||||
// CHECK-NEXT: %[[C4:.*]] = constant -1 : i32
|
||||
// CHECK-NEXT: %[[C1:.*]] = constant 3 : i32
|
||||
// CHECK-NEXT: %[[C2:.*]] = constant 2 : i32
|
||||
// CHECK-NEXT: %[[C3:.*]] = constant 1 : i32
|
||||
// CHECK-NEXT: %[[C4:.*]] = constant 1 : i32
|
||||
// CHECK-NEXT: %[[V2:.*]] = call @allocate_lwe_keyswitch_key_u64(%0, %[[C1]], %[[C2]], %[[C3]], %[[C4]]) : (memref<index>, i32, i32, i32, i32) -> !LowLFHE.lwe_key_switch_key
|
||||
// CHECK-NEXT: call @keyswitch_lwe_u64(%[[V0]], %[[V2]], %[[V1]], %arg0) : (memref<index>, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>) -> ()
|
||||
// CHECK-NEXT: return %[[V1]] : !LowLFHE.lwe_ciphertext<1024,4>
|
||||
%1 = "LowLFHE.keyswitch_lwe"(%arg0) : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
%1 = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 1 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
return %1: !LowLFHE.lwe_ciphertext<1024,4>
|
||||
}
|
||||
@@ -3,9 +3,9 @@
|
||||
// CHECK-LABEL: func @apply_lookup_table(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: tensor<16xi4>) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi4>) -> !MidLFHE.glwe<{1024,1,64}{4}> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "LowLFHE.glwe_from_table"(%arg1) {k = 1 : i32, polynomialSize = 1024 : i32} : (tensor<16xi4>) -> !LowLFHE.glwe_ciphertext
|
||||
// CHECK-NEXT: %[[V2:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%1, %0) {baseLog = -1 : i32, k = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 1 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, k = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext<1024,4>
|
||||
%1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1){k=1:i32, polynomialSize=1024:i32, levelKS=-1:i32, baseLogKS=-1:i32, levelBS=-1:i32, baseLogBS=-1:i32}: (!MidLFHE.glwe<{1024,1,64}{4}>, tensor<16xi4>) -> (!MidLFHE.glwe<{1024,1,64}{4}>)
|
||||
%1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1){k=1:i32, polynomialSize=1024:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32}: (!MidLFHE.glwe<{1024,1,64}{4}>, tensor<16xi4>) -> (!MidLFHE.glwe<{1024,1,64}{4}>)
|
||||
return %1: !MidLFHE.glwe<{1024,1,64}{4}>
|
||||
}
|
||||
@@ -102,10 +102,10 @@ func @set_plaintext_list_element(%arg0: !LowLFHE.plaintext_list, %arg1: index, %
|
||||
|
||||
// CHECK-LABEL: func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7>
|
||||
func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) : (!LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7>
|
||||
// CHECK-NEXT: %[[V1:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 1 : i32} : (!LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7>
|
||||
// CHECK-NEXT: return %[[V1]] : !LowLFHE.lwe_ciphertext<2048,7>
|
||||
|
||||
%1 = "LowLFHE.keyswitch_lwe"(%arg0): (!LowLFHE.lwe_ciphertext<2048,7>) -> (!LowLFHE.lwe_ciphertext<2048,7>)
|
||||
%1 = "LowLFHE.keyswitch_lwe"(%arg0){baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 1 : i32}: (!LowLFHE.lwe_ciphertext<2048,7>) -> (!LowLFHE.lwe_ciphertext<2048,7>)
|
||||
return %1: !LowLFHE.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user