From 3b5ae0657d14bd9e54b7b08c07395ba85a1b9e4c Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 19 Aug 2021 09:08:46 +0100 Subject: [PATCH] feat: MidToLowLFHE lowering of apply_lut --- .../Conversion/MidLFHEToLowLFHE/Patterns.h | 36 +++++++++++++++++++ .../Conversion/MidLFHEToLowLFHE/Patterns.td | 6 ++++ .../MidLFHEToLowLFHE/apply_lookup_table.mlir | 11 ++++++ 3 files changed, 53 insertions(+) create mode 100644 compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir diff --git a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h index 78e2b1356..4d4b83920 100644 --- a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h +++ b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h @@ -169,6 +169,42 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter rewriter, return op.getODSResults(0).front(); } +mlir::Value createPBS(mlir::PatternRewriter rewriter, mlir::Location loc, + mlir::Value ct, mlir::Value table, mlir::IntegerAttr k, + mlir::IntegerAttr polynomialSize, + mlir::IntegerAttr levelKS, mlir::IntegerAttr baseLogKS, + mlir::IntegerAttr levelBS, mlir::IntegerAttr baseLogBS, + mlir::OpResult result) { + // fill the the table in the GLWE accumulator + mlir::Value accumulator = + rewriter + .create( + loc, LowLFHE::GlweCiphertextType::get(rewriter.getContext()), + table, polynomialSize, k) + .result(); + + // keyswitch + auto ct_type = ct.getType().cast(); + mlir::Value keyswitched = + rewriter + .create( + loc, convertTypeGLWEToLWE(rewriter.getContext(), ct_type), ct) + .result(); + + // convert result type + GLWECipherTextType glwe_type = result.getType().cast(); + LweCiphertextType lwe_type = + convertTypeGLWEToLWE(rewriter.getContext(), glwe_type); + // bootstrap operation + mlir::Value bootstrapped = + rewriter + .create( + loc, lwe_type, keyswitched, accumulator) + .result(); + + return bootstrapped; +} + } // namespace zamalang } // namespace mlir diff --git a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.td b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.td index fafc648a2..7ed5a2700 100644 --- a/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.td +++ b/compiler/include/zamalang/Conversion/MidLFHEToLowLFHE/Patterns.td @@ -35,4 +35,10 @@ def SubIntGLWEPattern : Pat< (SubIntGLWEOp:$result $arg0, $arg1), (createSubIntLweOp $arg0, $arg1, $result)>; +def createPBS : NativeCodeCall<"mlir::zamalang::createPBS($_builder, $_loc, $0, $1, $2, $3, $4, $5, $6, $7, $8)">; + +def ApplyLookupTableGLWEPattern : Pat< + (ApplyLookupTable:$result $ct, $table, $k, $polynomialSize, $levelKS, $baseLogKS, $levelBS, $baseLogBS), + (createPBS $ct, $table, $k, $polynomialSize, $levelKS, $baseLogKS, $levelBS, $baseLogBS, $result)>; + #endif diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir new file mode 100644 index 000000000..30edc1a2c --- /dev/null +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir @@ -0,0 +1,11 @@ +// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @apply_lookup_table(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: tensor<16xi4>) -> !LowLFHE.lwe_ciphertext<1024,4> +func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi4>) -> !MidLFHE.glwe<{1024,1,64}{4}> { + // CHECK-NEXT: %[[V1:.*]] = "LowLFHE.glwe_from_table"(%arg1) {k = 1 : i32, polynomialSize = 1024 : i32} : (tensor<16xi4>) -> !LowLFHE.glwe_ciphertext + // CHECK-NEXT: %[[V2:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> + // CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%[[V2]], %[[V1]]) : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> + // CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext<1024,4> + %1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1){k=1:i32, polynomialSize=1024:i32, levelKS=-1:i32, baseLogKS=-1:i32, levelBS=-1:i32, baseLogBS=-1:i32}: (!MidLFHE.glwe<{1024,1,64}{4}>, tensor<16xi4>) -> (!MidLFHE.glwe<{1024,1,64}{4}>) + return %1: !MidLFHE.glwe<{1024,1,64}{4}> +} \ No newline at end of file