diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir new file mode 100644 index 000000000..3790ea57b --- /dev/null +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir @@ -0,0 +1,11 @@ +// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s + +// CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> +func @apply_lookup_table_cst(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { + // CHECK-NEXT: %[[TABLE:.*]] = constant dense<"0x00000000000000000000000000008000000000000000000100000000000080010000000000000002000000000000800200000000000000030000000000008003000000000000000400000000000080040000000000000005000000000000800500000000000000060000000000008006000000000000000700000000000080070000000000000008000000000000800800000000000000090000000000008009000000000000000A000000000000800A000000000000000B000000000000800B000000000000000C000000000000800C000000000000000D000000000000800D000000000000000E000000000000800E000000000000000F000000000000800F00000000000000100000000000008010000000000000001100000000000080110000000000000012000000000000801200000000000000130000000000008013000000000000001400000000000080140000000000000015000000000000801500000000000000160000000000008016000000000000001700000000000080170000000000000018000000000000801800000000000000190000000000008019000000000000001A000000000000801A000000000000001B000000000000801B000000000000001C000000000000801C000000000000001D000000000000801D000000000000001E000000000000801E000000000000001F000000000000801F00000000000000200000000000008020000000000000002100000000000080210000000000000022000000000000802200000000000000230000000000008023000000000000002400000000000080240000000000000025000000000000802500000000000000260000000000008026000000000000002700000000000080270000000000000028000000000000802800000000000000290000000000008029000000000000002A000000000000802A000000000000002B000000000000802B000000000000002C000000000000802C000000000000002D000000000000802D000000000000002E000000000000802E000000000000002F000000000000802F00000000000000300000000000008030000000000000003100000000000080310000000000000032000000000000803200000000000000330000000000008033000000000000003400000000000080340000000000000035000000000000803500000000000000360000000000008036000000000000003700000000000080370000000000000038000000000000803800000000000000390000000000008039000000000000003A000000000000803A000000000000003B000000000000803B000000000000003C000000000000803C000000000000003D000000000000803D000000000000003E000000000000803E000000000000003F000000000000803F"> : tensor<128xi64> + // CHECK-NEXT: %[[V0:.*]] = "MidLFHE.apply_lookup_table"(%arg0, %[[TABLE]]) {baseLogBS = -1 : i32, baseLogKS = -1 : i32, k = -1 : i32, levelBS = -1 : i32, levelKS = -1 : i32, polynomialSize = -1 : i32} : (!MidLFHE.glwe<{_,_,_}{7}>, tensor<128xi64>) -> !MidLFHE.glwe<{_,_,_}{7}> + // CHECK-NEXT: return %[[V0]] : !MidLFHE.glwe<{_,_,_}{7}> + %tlu = std.constant dense<[0, 36028797018963968, 72057594037927936, 108086391056891904, 144115188075855872, 180143985094819840, 216172782113783808, 252201579132747776, 288230376151711744, 324259173170675712, 360287970189639680, 396316767208603648, 432345564227567616, 468374361246531584, 504403158265495552, 540431955284459520, 576460752303423488, 612489549322387456, 648518346341351424, 684547143360315392, 720575940379279360, 756604737398243328, 792633534417207296, 828662331436171264, 864691128455135232, 900719925474099200, 936748722493063168, 972777519512027136, 1008806316530991104, 1044835113549955072, 1080863910568919040, 1116892707587883008, 1152921504606846976, 1188950301625810944, 1224979098644774912, 1261007895663738880, 1297036692682702848, 1333065489701666816, 1369094286720630784, 1405123083739594752, 1441151880758558720, 1477180677777522688, 1513209474796486656, 1549238271815450624, 1585267068834414592, 1621295865853378560, 1657324662872342528, 1693353459891306496, 1729382256910270464, 1765411053929234432, 1801439850948198400, 1837468647967162368, 1873497444986126336, 1909526242005090304, 1945555039024054272, 1981583836043018240, 2017612633061982208, 2053641430080946176, 2089670227099910144, 2125699024118874112, 2161727821137838080, 2197756618156802048, 2233785415175766016, 2269814212194729984, 2305843009213693952, 2341871806232657920, 2377900603251621888, 2413929400270585856, 2449958197289549824, 2485986994308513792, 2522015791327477760, 2558044588346441728, 2594073385365405696, 2630102182384369664, 2666130979403333632, 2702159776422297600, 2738188573441261568, 2774217370460225536, 2810246167479189504, 2846274964498153472, 2882303761517117440, 2918332558536081408, 2954361355555045376, 2990390152574009344, 3026418949592973312, 3062447746611937280, 3098476543630901248, 3134505340649865216, 3170534137668829184, 3206562934687793152, 3242591731706757120, 3278620528725721088, 3314649325744685056, 3350678122763649024, 3386706919782612992, 3422735716801576960, 3458764513820540928, 3494793310839504896, 3530822107858468864, 3566850904877432832, 3602879701896396800, 3638908498915360768, 3674937295934324736, 3710966092953288704, 3746994889972252672, 3783023686991216640, 3819052484010180608, 3855081281029144576, 3891110078048108544, 3927138875067072512, 3963167672086036480, 3999196469105000448, 4035225266123964416, 4071254063142928384, 4107282860161892352, 4143311657180856320, 4179340454199820288, 4215369251218784256, 4251398048237748224, 4287426845256712192, 4323455642275676160, 4359484439294640128, 4395513236313604096, 4431542033332568064, 4467570830351532032, 4503599627370496000, 4539628424389459968, 4575657221408423936]> : tensor<128xi64> + %1 = "HLFHE.apply_lookup_table"(%arg0, %tlu): (!HLFHE.eint<7>, tensor<128xi64>) -> (!HLFHE.eint<7>) + return %1: !HLFHE.eint<7> +} \ No newline at end of file diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir index 70c923083..0aad472ba 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir @@ -1,20 +1,16 @@ // RUN: zamacompiler --passes lowlfhe-to-concrete-c-api %s 2>&1| FileCheck %s // CHECK-LABEL: module -// CHECK-NEXT: func private @bootstrap_lwe_u64(memref, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(memref, i32) -> !LowLFHE.lwe_ciphertext<1024,4> -// CHECK-NEXT: func private @allocate_lwe_bootstrap_key_u64(memref, i32, i32, i32, i32, i32) -> !LowLFHE.lwe_bootstrap_key +// CHECK-NEXT: func private @bootstrap_lwe_u64(index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) +// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, i32) -> !LowLFHE.lwe_ciphertext<1024,4> +// CHECK-NEXT: func private @getGlobalBootstrapKey() -> !LowLFHE.lwe_bootstrap_key // CHECK-LABEL: func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> { - // CHECK-NEXT: %[[V0:.*]] = memref.alloca() : memref + // CHECK-NEXT: %[[ERR:.*]] = constant 0 : index // CHECK-NEXT: %[[C0:.*]] = constant 1 : i32 - // CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[V0]], %[[C0]]) : (memref, i32) -> !LowLFHE.lwe_ciphertext<1024,4> - // CHECK-NEXT: %[[C1:.*]] = constant 3 : i32 - // CHECK-NEXT: %[[C2:.*]] = constant 2 : i32 - // CHECK-NEXT: %[[C3:.*]] = constant -1 : i32 - // CHECK-NEXT: %[[C4:.*]] = constant 1024 : i32 - // CHECK-NEXT: %[[V2:.*]] = call @allocate_lwe_bootstrap_key_u64(%0, %[[C1]], %[[C2]], %[[C3]], %[[C0]], %[[C4]]) : (memref, i32, i32, i32, i32, i32) -> !LowLFHE.lwe_bootstrap_key - // CHECK-NEXT: call @bootstrap_lwe_u64(%[[V0]], %[[V2]], %[[V1]], %arg0, %arg1) : (memref, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> () + // CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[ERR]], %[[C0]]) : (index, i32) -> !LowLFHE.lwe_ciphertext<1024,4> + // CHECK-NEXT: %[[V2:.*]] = call @getGlobalBootstrapKey() : () -> !LowLFHE.lwe_bootstrap_key + // CHECK-NEXT: call @bootstrap_lwe_u64(%[[ERR]], %[[V2]], %[[V1]], %arg0, %arg1) : (index, !LowLFHE.lwe_bootstrap_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> () // CHECK-NEXT: return %[[V1]] : !LowLFHE.lwe_ciphertext<1024,4> %1 = "LowLFHE.bootstrap_lwe"(%arg0, %arg1) {baseLog = 2 : i32, k = 1 : i32, level = 3 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> return %1: !LowLFHE.lwe_ciphertext<1024,4> diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir index 2b8f7b353..bb7b187b9 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir @@ -1,23 +1,23 @@ // RUN: zamacompiler --passes lowlfhe-to-concrete-c-api %s 2>&1| FileCheck %s // CHECK-LABEL: module -// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(memref, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) -// CHECK-NEXT: func private @fill_plaintext_list_with_expansion_u64(memref, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list) -// CHECK-NEXT: func private @foreign_plaintext_list_u64(memref, tensor<16xi4>, i64) -> !LowLFHE.foreign_plaintext_list -// CHECK-NEXT: func private @allocate_plaintext_list_u64(memref, i32) -> !LowLFHE.plaintext_list -// CHECK-NEXT: func private @allocate_glwe_ciphertext_u64(memref, i32, i32) -> !LowLFHE.glwe_ciphertext +// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) +// CHECK-NEXT: func private @fill_plaintext_list_with_expansion_u64(index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list) +// CHECK-NEXT: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi4>, i64) -> !LowLFHE.foreign_plaintext_list +// CHECK-NEXT: func private @allocate_plaintext_list_u64(index, i32) -> !LowLFHE.plaintext_list +// CHECK-NEXT: func private @allocate_glwe_ciphertext_u64(index, i32, i32) -> !LowLFHE.glwe_ciphertext // CHECK-LABEL: func @glwe_from_table(%arg0: tensor<16xi4>) -> !LowLFHE.glwe_ciphertext func @glwe_from_table(%arg0: tensor<16xi4>) -> !LowLFHE.glwe_ciphertext { - // CHECK-NEXT: %[[V0:.*]] = memref.alloca() : memref + // CHECK-NEXT: %[[V0:.*]] = constant 0 : index // CHECK-NEXT: %[[C0:.*]] = constant 1 : i32 // CHECK-NEXT: %[[C1:.*]] = constant 1024 : i32 - // CHECK-NEXT: %[[V1:.*]] = call @allocate_glwe_ciphertext_u64(%[[V0]], %[[C0]], %[[C1]]) : (memref, i32, i32) -> !LowLFHE.glwe_ciphertext - // CHECK-NEXT: %[[V2:.*]] = call @allocate_glwe_ciphertext_u64(%[[V0]], %[[C0]], %[[C1]]) : (memref, i32, i32) -> !LowLFHE.glwe_ciphertext - // CHECK-NEXT: %[[V3:.*]] = call @allocate_plaintext_list_u64(%[[V0]], %[[C1]]) : (memref, i32) -> !LowLFHE.plaintext_list + // CHECK-NEXT: %[[V1:.*]] = call @allocate_glwe_ciphertext_u64(%[[V0]], %[[C0]], %[[C1]]) : (index, i32, i32) -> !LowLFHE.glwe_ciphertext + // CHECK-NEXT: %[[V2:.*]] = call @allocate_glwe_ciphertext_u64(%[[V0]], %[[C0]], %[[C1]]) : (index, i32, i32) -> !LowLFHE.glwe_ciphertext + // CHECK-NEXT: %[[V3:.*]] = call @allocate_plaintext_list_u64(%[[V0]], %[[C1]]) : (index, i32) -> !LowLFHE.plaintext_list // CHECK-NEXT: %[[C2:.*]] = constant 16 : i64 - // CHECK-NEXT: %[[V4:.*]] = call @foreign_plaintext_list_u64(%[[V0]], %arg0, %[[C2]]) : (memref, tensor<16xi4>, i64) -> !LowLFHE.foreign_plaintext_list - // CHECK-NEXT: call @fill_plaintext_list_with_expansion_u64(%[[V0]], %[[V3]], %[[V4]]) : (memref, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list) -> () - // CHECK-NEXT: call @add_plaintext_list_glwe_ciphertext_u64(%[[V0]], %[[V1]], %[[V2]], %[[V3]]) : (memref, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) -> () + // CHECK-NEXT: %[[V4:.*]] = call @runtime_foreign_plaintext_list_u64(%[[V0]], %arg0, %[[C2]]) : (index, tensor<16xi4>, i64) -> !LowLFHE.foreign_plaintext_list + // CHECK-NEXT: call @fill_plaintext_list_with_expansion_u64(%[[V0]], %[[V3]], %[[V4]]) : (index, !LowLFHE.plaintext_list, !LowLFHE.foreign_plaintext_list) -> () + // CHECK-NEXT: call @add_plaintext_list_glwe_ciphertext_u64(%[[V0]], %[[V1]], %[[V2]], %[[V3]]) : (index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) -> () // CHECK-NEXT: return %[[V1]] : !LowLFHE.glwe_ciphertext %1 = "LowLFHE.glwe_from_table"(%arg0) {k = 1 : i32, polynomialSize = 1024 : i32} : (tensor<16xi4>) -> !LowLFHE.glwe_ciphertext return %1: !LowLFHE.glwe_ciphertext diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir index 753ec67f6..a9e684177 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir @@ -1,20 +1,15 @@ // RUN: zamacompiler --passes lowlfhe-to-concrete-c-api %s 2>&1| FileCheck %s // CHECK-LABEL: module -// CHECK-NEXT: func private @keyswitch_lwe_u64(memref, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>) -// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(memref, i32) -> !LowLFHE.lwe_ciphertext<1024,4> -// CHECK-NEXT: func private @allocate_lwe_keyswitch_key_u64(memref, i32, i32, i32, i32) -> !LowLFHE.lwe_key_switch_key +// CHECK-NEXT: func private @keyswitch_lwe_u64(index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>) +// CHECK-NEXT: func private @allocate_lwe_ciphertext_u64(index, i32) -> !LowLFHE.lwe_ciphertext<1024,4> // CHECK-LABEL: func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> { - // CHECK-NEXT: %[[V0:.*]] = memref.alloca() : memref + // CHECK-NEXT: %[[ERR:.*]] = constant 0 : index // CHECK-NEXT: %[[C0:.*]] = constant 1 : i32 - // CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[V0]], %[[C0]]) : (memref, i32) -> !LowLFHE.lwe_ciphertext<1024,4> - // CHECK-NEXT: %[[C1:.*]] = constant 3 : i32 - // CHECK-NEXT: %[[C2:.*]] = constant 2 : i32 - // CHECK-NEXT: %[[C3:.*]] = constant 1 : i32 - // CHECK-NEXT: %[[C4:.*]] = constant 1 : i32 - // CHECK-NEXT: %[[V2:.*]] = call @allocate_lwe_keyswitch_key_u64(%0, %[[C1]], %[[C2]], %[[C3]], %[[C4]]) : (memref, i32, i32, i32, i32) -> !LowLFHE.lwe_key_switch_key - // CHECK-NEXT: call @keyswitch_lwe_u64(%[[V0]], %[[V2]], %[[V1]], %arg0) : (memref, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>) -> () + // CHECK-NEXT: %[[V1:.*]] = call @allocate_lwe_ciphertext_u64(%[[ERR]], %[[C0]]) : (index, i32) -> !LowLFHE.lwe_ciphertext<1024,4> + // CHECK-NEXT: %[[V2:.*]] = call @getGlobalKeyswitchKey() : () -> !LowLFHE.lwe_key_switch_key + // CHECK-NEXT: call @keyswitch_lwe_u64(%[[ERR]], %[[V2]], %[[V1]], %arg0) : (index, !LowLFHE.lwe_key_switch_key, !LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.lwe_ciphertext<1024,4>) -> () // CHECK-NEXT: return %[[V1]] : !LowLFHE.lwe_ciphertext<1024,4> %1 = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 1 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> return %1: !LowLFHE.lwe_ciphertext<1024,4> diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir index b222d8c1e..ab757025f 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir @@ -4,7 +4,7 @@ func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi4>) -> !MidLFHE.glwe<{1024,1,64}{4}> { // CHECK-NEXT: %[[V1:.*]] = "LowLFHE.glwe_from_table"(%arg1) {k = 1 : i32, polynomialSize = 1024 : i32} : (tensor<16xi4>) -> !LowLFHE.glwe_ciphertext // CHECK-NEXT: %[[V2:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 1 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> - // CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%1, %0) {baseLog = 4 : i32, k = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> + // CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, k = 1 : i32, level = 5 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> // CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext<1024,4> %1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1){k=1:i32, polynomialSize=1024:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32}: (!MidLFHE.glwe<{1024,1,64}{4}>, tensor<16xi4>) -> (!MidLFHE.glwe<{1024,1,64}{4}>) return %1: !MidLFHE.glwe<{1024,1,64}{4}> diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir new file mode 100644 index 000000000..943d9fe1f --- /dev/null +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir @@ -0,0 +1,13 @@ +// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !LowLFHE.lwe_ciphertext<2048,4>) -> !LowLFHE.lwe_ciphertext<2048,4> +func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{2048,1,64}{4}>) -> !MidLFHE.glwe<{2048,1,64}{4}> { + // CHECK-NEXT: %[[TABLE:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1]> : tensor<16xi4> + // CHECK-NEXT: %[[V1:.*]] = "LowLFHE.glwe_from_table"(%[[TABLE]]) {k = 1 : i32, polynomialSize = 2048 : i32} : (tensor<16xi4>) -> !LowLFHE.glwe_ciphertext + // CHECK-NEXT: %[[V2:.*]] = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 1 : i32} : (!LowLFHE.lwe_ciphertext<2048,4>) -> !LowLFHE.lwe_ciphertext<2048,4> + // CHECK-NEXT: %[[V3:.*]] = "LowLFHE.bootstrap_lwe"(%[[V2]], %[[V1]]) {baseLog = 4 : i32, k = 1 : i32, level = 5 : i32, polynomialSize = 2048 : i32} : (!LowLFHE.lwe_ciphertext<2048,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<2048,4> + // CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext<2048,4> + %tlu = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi4> + %1 = "MidLFHE.apply_lookup_table"(%arg0, %tlu){k=1:i32, polynomialSize=2048:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32}: (!MidLFHE.glwe<{2048,1,64}{4}>, tensor<16xi4>) -> (!MidLFHE.glwe<{2048,1,64}{4}>) + return %1: !MidLFHE.glwe<{2048,1,64}{4}> +} \ No newline at end of file diff --git a/compiler/tests/python/test_compiler_engine.py b/compiler/tests/python/test_compiler_engine.py index 14787edb4..86170ad6d 100644 --- a/compiler/tests/python/test_compiler_engine.py +++ b/compiler/tests/python/test_compiler_engine.py @@ -15,6 +15,18 @@ from zamalang import CompilerEngine (5, 7), 12, ), + ( + """ + func @main(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { + // 0..128 shifted << 55 + %tlu = std.constant dense<[0, 36028797018963968, 72057594037927936, 108086391056891904, 144115188075855872, 180143985094819840, 216172782113783808, 252201579132747776, 288230376151711744, 324259173170675712, 360287970189639680, 396316767208603648, 432345564227567616, 468374361246531584, 504403158265495552, 540431955284459520, 576460752303423488, 612489549322387456, 648518346341351424, 684547143360315392, 720575940379279360, 756604737398243328, 792633534417207296, 828662331436171264, 864691128455135232, 900719925474099200, 936748722493063168, 972777519512027136, 1008806316530991104, 1044835113549955072, 1080863910568919040, 1116892707587883008, 1152921504606846976, 1188950301625810944, 1224979098644774912, 1261007895663738880, 1297036692682702848, 1333065489701666816, 1369094286720630784, 1405123083739594752, 1441151880758558720, 1477180677777522688, 1513209474796486656, 1549238271815450624, 1585267068834414592, 1621295865853378560, 1657324662872342528, 1693353459891306496, 1729382256910270464, 1765411053929234432, 1801439850948198400, 1837468647967162368, 1873497444986126336, 1909526242005090304, 1945555039024054272, 1981583836043018240, 2017612633061982208, 2053641430080946176, 2089670227099910144, 2125699024118874112, 2161727821137838080, 2197756618156802048, 2233785415175766016, 2269814212194729984, 2305843009213693952, 2341871806232657920, 2377900603251621888, 2413929400270585856, 2449958197289549824, 2485986994308513792, 2522015791327477760, 2558044588346441728, 2594073385365405696, 2630102182384369664, 2666130979403333632, 2702159776422297600, 2738188573441261568, 2774217370460225536, 2810246167479189504, 2846274964498153472, 2882303761517117440, 2918332558536081408, 2954361355555045376, 2990390152574009344, 3026418949592973312, 3062447746611937280, 3098476543630901248, 3134505340649865216, 3170534137668829184, 3206562934687793152, 3242591731706757120, 3278620528725721088, 3314649325744685056, 3350678122763649024, 3386706919782612992, 3422735716801576960, 3458764513820540928, 3494793310839504896, 3530822107858468864, 3566850904877432832, 3602879701896396800, 3638908498915360768, 3674937295934324736, 3710966092953288704, 3746994889972252672, 3783023686991216640, 3819052484010180608, 3855081281029144576, 3891110078048108544, 3927138875067072512, 3963167672086036480, 3999196469105000448, 4035225266123964416, 4071254063142928384, 4107282860161892352, 4143311657180856320, 4179340454199820288, 4215369251218784256, 4251398048237748224, 4287426845256712192, 4323455642275676160, 4359484439294640128, 4395513236313604096, 4431542033332568064, 4467570830351532032, 4503599627370496000, 4539628424389459968, 4575657221408423936]> : tensor<128xi64> + %1 = "HLFHE.apply_lookup_table"(%arg0, %tlu): (!HLFHE.eint<7>, tensor<128xi64>) -> (!HLFHE.eint<7>) + return %1: !HLFHE.eint<7> + } + """, + (5,), + 5, + ), ], ) def test_compile_and_run(mlir_input, args, expected_result): diff --git a/compiler/tests/python/test_round_trip.py b/compiler/tests/python/test_round_trip.py index 4f35c48e9..32baf32fa 100644 --- a/compiler/tests/python/test_round_trip.py +++ b/compiler/tests/python/test_round_trip.py @@ -25,6 +25,13 @@ VALID_INPUTS = [ return %1 : !HLFHE.eint<2> } """, + """ + func @main(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { + %tlu = std.constant dense<[0, 36028797018963968, 72057594037927936, 108086391056891904, 144115188075855872, 180143985094819840, 216172782113783808, 252201579132747776, 288230376151711744, 324259173170675712, 360287970189639680, 396316767208603648, 432345564227567616, 468374361246531584, 504403158265495552, 540431955284459520, 576460752303423488, 612489549322387456, 648518346341351424, 684547143360315392, 720575940379279360, 756604737398243328, 792633534417207296, 828662331436171264, 864691128455135232, 900719925474099200, 936748722493063168, 972777519512027136, 1008806316530991104, 1044835113549955072, 1080863910568919040, 1116892707587883008, 1152921504606846976, 1188950301625810944, 1224979098644774912, 1261007895663738880, 1297036692682702848, 1333065489701666816, 1369094286720630784, 1405123083739594752, 1441151880758558720, 1477180677777522688, 1513209474796486656, 1549238271815450624, 1585267068834414592, 1621295865853378560, 1657324662872342528, 1693353459891306496, 1729382256910270464, 1765411053929234432, 1801439850948198400, 1837468647967162368, 1873497444986126336, 1909526242005090304, 1945555039024054272, 1981583836043018240, 2017612633061982208, 2053641430080946176, 2089670227099910144, 2125699024118874112, 2161727821137838080, 2197756618156802048, 2233785415175766016, 2269814212194729984, 2305843009213693952, 2341871806232657920, 2377900603251621888, 2413929400270585856, 2449958197289549824, 2485986994308513792, 2522015791327477760, 2558044588346441728, 2594073385365405696, 2630102182384369664, 2666130979403333632, 2702159776422297600, 2738188573441261568, 2774217370460225536, 2810246167479189504, 2846274964498153472, 2882303761517117440, 2918332558536081408, 2954361355555045376, 2990390152574009344, 3026418949592973312, 3062447746611937280, 3098476543630901248, 3134505340649865216, 3170534137668829184, 3206562934687793152, 3242591731706757120, 3278620528725721088, 3314649325744685056, 3350678122763649024, 3386706919782612992, 3422735716801576960, 3458764513820540928, 3494793310839504896, 3530822107858468864, 3566850904877432832, 3602879701896396800, 3638908498915360768, 3674937295934324736, 3710966092953288704, 3746994889972252672, 3783023686991216640, 3819052484010180608, 3855081281029144576, 3891110078048108544, 3927138875067072512, 3963167672086036480, 3999196469105000448, 4035225266123964416, 4071254063142928384, 4107282860161892352, 4143311657180856320, 4179340454199820288, 4215369251218784256, 4251398048237748224, 4287426845256712192, 4323455642275676160, 4359484439294640128, 4395513236313604096, 4431542033332568064, 4467570830351532032, 4503599627370496000, 4539628424389459968, 4575657221408423936]> : tensor<128xi64> + %1 = "HLFHE.apply_lookup_table"(%arg0, %tlu): (!HLFHE.eint<7>, tensor<128xi64>) -> (!HLFHE.eint<7>) + return %1: !HLFHE.eint<7> + } + """, ] INVALID_INPUTS = [ diff --git a/compiler/tests/unittest/end_to_end_jit_test.cc b/compiler/tests/unittest/end_to_end_jit_test.cc index f46936daa..7c9dbf171 100644 --- a/compiler/tests/unittest/end_to_end_jit_test.cc +++ b/compiler/tests/unittest/end_to_end_jit_test.cc @@ -410,4 +410,20 @@ func @main(%arg0: tensor<4x!HLFHE.eint<7>>, uint64_t res; ASSERT_LLVM_ERROR(argument->getResult(0, res)); ASSERT_EQ(res, 14); +} + +TEST(CompileAndRunTLU, tlu) { + mlir::zamalang::CompilerEngine engine; + auto mlirStr = R"XXX( +func @main(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { + %tlu = std.constant dense<[0, 36028797018963968, 72057594037927936, 108086391056891904, 144115188075855872, 180143985094819840, 216172782113783808, 252201579132747776, 288230376151711744, 324259173170675712, 360287970189639680, 396316767208603648, 432345564227567616, 468374361246531584, 504403158265495552, 540431955284459520, 576460752303423488, 612489549322387456, 648518346341351424, 684547143360315392, 720575940379279360, 756604737398243328, 792633534417207296, 828662331436171264, 864691128455135232, 900719925474099200, 936748722493063168, 972777519512027136, 1008806316530991104, 1044835113549955072, 1080863910568919040, 1116892707587883008, 1152921504606846976, 1188950301625810944, 1224979098644774912, 1261007895663738880, 1297036692682702848, 1333065489701666816, 1369094286720630784, 1405123083739594752, 1441151880758558720, 1477180677777522688, 1513209474796486656, 1549238271815450624, 1585267068834414592, 1621295865853378560, 1657324662872342528, 1693353459891306496, 1729382256910270464, 1765411053929234432, 1801439850948198400, 1837468647967162368, 1873497444986126336, 1909526242005090304, 1945555039024054272, 1981583836043018240, 2017612633061982208, 2053641430080946176, 2089670227099910144, 2125699024118874112, 2161727821137838080, 2197756618156802048, 2233785415175766016, 2269814212194729984, 2305843009213693952, 2341871806232657920, 2377900603251621888, 2413929400270585856, 2449958197289549824, 2485986994308513792, 2522015791327477760, 2558044588346441728, 2594073385365405696, 2630102182384369664, 2666130979403333632, 2702159776422297600, 2738188573441261568, 2774217370460225536, 2810246167479189504, 2846274964498153472, 2882303761517117440, 2918332558536081408, 2954361355555045376, 2990390152574009344, 3026418949592973312, 3062447746611937280, 3098476543630901248, 3134505340649865216, 3170534137668829184, 3206562934687793152, 3242591731706757120, 3278620528725721088, 3314649325744685056, 3350678122763649024, 3386706919782612992, 3422735716801576960, 3458764513820540928, 3494793310839504896, 3530822107858468864, 3566850904877432832, 3602879701896396800, 3638908498915360768, 3674937295934324736, 3710966092953288704, 3746994889972252672, 3783023686991216640, 3819052484010180608, 3855081281029144576, 3891110078048108544, 3927138875067072512, 3963167672086036480, 3999196469105000448, 4035225266123964416, 4071254063142928384, 4107282860161892352, 4143311657180856320, 4179340454199820288, 4215369251218784256, 4251398048237748224, 4287426845256712192, 4323455642275676160, 4359484439294640128, 4395513236313604096, 4431542033332568064, 4467570830351532032, 4503599627370496000, 4539628424389459968, 4575657221408423936]> : tensor<128xi64> + %1 = "HLFHE.apply_lookup_table"(%arg0, %tlu): (!HLFHE.eint<7>, tensor<128xi64>) -> (!HLFHE.eint<7>) + return %1: !HLFHE.eint<7> +} +)XXX"; + ASSERT_FALSE(engine.compile(mlirStr)); + auto maybeResult = engine.run({5}); + ASSERT_TRUE((bool)maybeResult); + uint64_t result = maybeResult.get(); + ASSERT_EQ(result, 5); } \ No newline at end of file