refactor(compiler): Refactor CompilerEngine and related classes

This commit contains several incremental improvements towards a clear
interface for lambdas:

  - Unification of static and JIT compilation by using the static
    compilation path of `CompilerEngine` within a new subclass
    `JitCompilerEngine`.

  - Clear ownership for compilation artefacts through
    `CompilationContext`, making it impossible to destroy objects used
    directly or indirectly before destruction of their users.

  - Clear interface for lambdas generated by the compiler through
    `JitCompilerEngine::Lambda` with a templated call operator,
    encapsulating otherwise manual orchestration of `CompilerEngine`,
    `JITLambda`, and `CompilerEngine::Argument`.

  - Improved error handling through `llvm::Expected<T>` and proper
    error checking following the conventions for `llvm::Expected<T>`
    and `llvm::Error`.

Co-authored-by: youben11 <ayoub.benaissa@zama.ai>
This commit is contained in:
Andi Drebes
2021-10-18 15:38:12 +02:00
parent d738104c4b
commit 1187cfbd62
61 changed files with 1690 additions and 997 deletions

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @add_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>, %arg1: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @add_eint(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @add_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @add_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{_,_,_}{2}>, %arg1: tensor<4xi64>) -> !MidLFHE.glwe<{_,_,_}{2}>
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<4xi64>) -> !HLFHE.eint<2> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @apply_lookup_table_cst(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --action=dump-midlfhe --assume-max-manp=10 --assume-max-eint-precision=2 2>&1| FileCheck %s
// CHECK: #map0 = affine_map<(d0) -> (d0)>
// CHECK-NEXT: #map1 = affine_map<(d0) -> (0)>

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @mul_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @mul_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @sub_int_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @sub_int_eint(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=dump-std %s 2>&1| FileCheck %s
// CHECK-LABEL: module
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=dump-std %s 2>&1| FileCheck %s
// CHECK-LABEL: module
// CHECK-NEXT: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=dump-std %s 2>&1| FileCheck %s
// CHECK-LABEL: module
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_glwe(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7>
func @add_glwe(%arg0: !MidLFHE.glwe<{2048,1,64}{7}>, %arg1: !MidLFHE.glwe<{2048,1,64}{7}>) -> !MidLFHE.glwe<{2048,1,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7>
func @add_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !LowLFHE.lwe_ciphertext<1024,4>
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi64>) -> !MidLFHE.glwe<{1024,1,64}{4}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !LowLFHE.lwe_ciphertext<2048,4>) -> !LowLFHE.lwe_ciphertext<2048,4>
func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{2048,1,64}{4}>) -> !MidLFHE.glwe<{2048,1,64}{4}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// CHECK-LABEL: func @mul_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7>
func @mul_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// CHECK-LABEL: func @sub_const_int_glwe(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7>
func @sub_const_int_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --split-input-file --entry-dialect=hlfhe --action=dump-hlfhe-manp %s 2>&1 | FileCheck %s
// RUN: zamacompiler --split-input-file --action=dump-hlfhe-manp %s 2>&1 | FileCheck %s
func @single_zero() -> !HLFHE.eint<2>
{

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=hlfhe --action=roundtrip %s
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
// Incompatible shapes
func @dot_incompatible_shapes(

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: eint support only precision in ]0;7]
func @test(%arg0: !HLFHE.eint<8>) {

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: eint support only precision in ]0;7]
func @test(%arg0: !HLFHE.eint<0>) {

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.add_eint' op should have the width of encrypted inputs equals
func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<2> {

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.add_eint' op should have the width of encrypted inputs and result equals
func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<2>) -> !HLFHE.eint<3> {

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.add_eint_int' op should have the width of plain input equals to width of encrypted input + 1
func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.add_eint_int' op should have the width of encrypted inputs and result equals
func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> {

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.apply_lookup_table' op should have as `l_cst` argument a shape of one dimension equals to 2^p, where p is the width of the `ct` argument.
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<8xi3>) -> !HLFHE.eint<2> {

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.mul_eint_int' op should have the width of plain input equals to width of encrypted input + 1
func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.mul_eint_int' op should have the width of encrypted inputs and result equals
func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> {

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.sub_int_eint' op should have the width of plain input equals to width of encrypted input + 1
func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {

View File

@@ -1,4 +1,4 @@
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.sub_int_eint' op should have the width of encrypted inputs and result equals
func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @zero() -> !HLFHE.eint<2>
func @zero() -> !HLFHE.eint<2> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1 | FileCheck %s
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1 | FileCheck %s
//CHECK: #map0 = affine_map<(d0) -> (d0)>
//CHECK-NEXT: #map1 = affine_map<(d0) -> (0)>

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>>
func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>>) {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=lowlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_lwe_ciphertexts(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7>
func @add_lwe_ciphertexts(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=lowlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @type_enc_rand_gen(%arg0: !LowLFHE.enc_rand_gen) -> !LowLFHE.enc_rand_gen
func @type_enc_rand_gen(%arg0: !LowLFHE.enc_rand_gen) -> !LowLFHE.enc_rand_gen {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
// GLWE p parameter result
func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
// GLWE p parameter
func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
// Bad dimension of the lookup table
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<4xi2>) -> !MidLFHE.glwe<{512,10,64}{2}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<128xi64>) -> !MidLFHE.glwe<{512,10,64}{2}>
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<128xi64>) -> !MidLFHE.glwe<{512,10,64}{2}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
// GLWE p parameter
func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
// GLWE p parameter
func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --entry-dialect=midlfhe --action=roundtrip 2>&1| FileCheck %s
// RUN: zamacompiler %s --action=roundtrip 2>&1| FileCheck %s
// CHECK-LABEL: func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {

View File

@@ -56,7 +56,7 @@ def test_compile_and_run(mlir_input, args, expected_result):
def test_compile_and_run_invalid_arg_number(mlir_input, args):
engine = CompilerEngine()
engine.compile_fhe(mlir_input)
with pytest.raises(RuntimeError, match=r"failed pushing integer argument"):
with pytest.raises(ValueError, match=r"wrong number of arguments"):
engine.run(*args)

View File

@@ -1,8 +1,11 @@
#include <cstdint>
#include <gtest/gtest.h>
#include <type_traits>
#include "zamalang/Support/CompilerEngine.h"
#include "zamalang/Support/JitCompilerEngine.h"
mlir::zamalang::V0FHEConstraint defaultV0Constraints = {.norm2 = 10, .p = 7};
mlir::zamalang::V0FHEConstraint defaultV0Constraints = {10, 7};
#define ASSERT_LLVM_ERROR(err) \
if (err) { \
@@ -10,384 +13,405 @@ mlir::zamalang::V0FHEConstraint defaultV0Constraints = {.norm2 = 10, .p = 7};
ASSERT_TRUE(false); \
}
// Checks that the value `val` is not in an error state. Returns
// `true` if the test passes, otherwise `false`.
template <typename T>
static bool assert_expected_success(llvm::Expected<T> &val) {
if (!((bool)val)) {
llvm::errs() << llvm::toString(std::move(val.takeError()));
return false;
}
return true;
}
// Checks that the value `val` is not in an error state. Returns
// `true` if the test passes, otherwise `false`.
template <typename T>
static bool assert_expected_success(llvm::Expected<T> &&val) {
return assert_expected_success(val);
}
// Checks that the value `val` of type `llvm::Expected<T>` is not in
// an error state.
#define ASSERT_EXPECTED_SUCCESS(val) \
do { \
if (!assert_expected_success(val)) \
GTEST_FATAL_FAILURE_("Expected<T> contained in error state"); \
} while (0)
// Checks that the value `val` is not in an error state and is equal
// to the value given in `exp`. Returns `true` if the test passes,
// otherwise `false`.
template <typename T, typename V>
static bool assert_expected_value(llvm::Expected<T> &val, const V &exp) {
if (!assert_expected_success(val))
return false;
if (!(val.get() == static_cast<T>(exp))) {
llvm::errs() << "Expected value " << exp << ", but got " << val.get()
<< "\n";
return false;
}
return true;
}
// Checks that the value `val` is not in an error state and is equal
// to the value given in `exp`. Returns `true` if the test passes,
// otherwise `false`.
template <typename T, typename V>
static bool assert_expected_value(llvm::Expected<T> &&val, const V &exp) {
return assert_expected_value(val, exp);
}
// Checks that the value `val` of type `llvm::Expected<T>` is not in
// an error state and is equal to the value of type `T` given in
// `exp`.
#define ASSERT_EXPECTED_VALUE(val, exp) \
do { \
if (!assert_expected_value(val, exp)) { \
GTEST_FATAL_FAILURE_("Expected<T> with wrong value"); \
} \
} while (0)
// Jit-compiles the function specified by `func` from `src` and
// returns the corresponding lambda. Any compilation errors are caught
// and reult in abnormal termination.
template <typename F>
mlir::zamalang::JitCompilerEngine::Lambda
internalCheckedJit(F checkfunc, llvm::StringRef src,
llvm::StringRef func = "main",
bool useDefaultFHEConstraints = false) {
mlir::zamalang::JitCompilerEngine engine;
if (useDefaultFHEConstraints)
engine.setFHEConstraints(defaultV0Constraints);
llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
engine.buildLambda(src, func);
checkfunc(lambdaOrErr);
return std::move(*lambdaOrErr);
}
// Shorthands to create integer literals of a specific type
uint8_t operator"" _u8(unsigned long long int v) { return v; }
uint16_t operator"" _u16(unsigned long long int v) { return v; }
uint32_t operator"" _u32(unsigned long long int v) { return v; }
uint64_t operator"" _u64(unsigned long long int v) { return v; }
// Evaluates to the number of elements of a statically initialized
// array
#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof(arr[0]))
// Wrapper around `internalCheckedJit` that causes
// `ASSERT_EXPECTED_SUCCESS` to use the file and line number of the
// caller instead of `internalCheckedJit`.
#define checkedJit(...) \
internalCheckedJit( \
[](llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> &lambda) { \
ASSERT_EXPECTED_SUCCESS(lambda); \
}, \
__VA_ARGS__)
TEST(CompileAndRunHLFHE, add_eint) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
return %1: !HLFHE.eint<7>
}
)XXX";
ASSERT_FALSE(engine.compile(mlirStr));
auto maybeResult = engine.run({1, 2});
ASSERT_TRUE((bool)maybeResult);
uint64_t result = maybeResult.get();
ASSERT_EQ(result, 3);
)XXX");
ASSERT_EXPECTED_VALUE(lambda(1_u64, 2_u64), 3);
ASSERT_EXPECTED_VALUE(lambda(4_u64, 5_u64), 9);
ASSERT_EXPECTED_VALUE(lambda(1_u64, 1_u64), 2);
}
// Same as CompileAndRunHLFHE::add_eint above, but using
// `LambdaArgument` instances
TEST(CompileAndRunHLFHE, add_eint_lambda_argument) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
return %1: !HLFHE.eint<7>
}
)XXX");
mlir::zamalang::IntLambdaArgument<> ila1(1);
mlir::zamalang::IntLambdaArgument<> ila2(2);
mlir::zamalang::IntLambdaArgument<> ila7(7);
mlir::zamalang::IntLambdaArgument<> ila9(9);
ASSERT_EXPECTED_VALUE(lambda({&ila1, &ila2}), 3);
ASSERT_EXPECTED_VALUE(lambda({&ila7, &ila9}), 16);
ASSERT_EXPECTED_VALUE(lambda({&ila1, &ila7}), 8);
ASSERT_EXPECTED_VALUE(lambda({&ila1, &ila9}), 10);
ASSERT_EXPECTED_VALUE(lambda({&ila2, &ila7}), 9);
}
TEST(CompileAndRunHLFHE, add_u64) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: i64, %arg1: i64) -> i64 {
%1 = addi %arg0, %arg1 : i64
return %1: i64
}
)XXX",
"main", true);
ASSERT_EXPECTED_VALUE(lambda(1_u64, 2_u64), (uint64_t)3);
ASSERT_EXPECTED_VALUE(lambda(4_u64, 5_u64), (uint64_t)9);
ASSERT_EXPECTED_VALUE(lambda(1_u64, 1_u64), (uint64_t)2);
}
TEST(CompileAndRunTensorStd, extract_64) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi64>, %i: index) -> i64{
%c = tensor.extract %t[%i] : tensor<10xi64>
return %c : i64
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint64_t t_arg[size]{0xFFFFFFFFFFFFFFFF,
0,
8978,
2587490,
90,
197864,
698735,
72132,
87474,
42};
for (size_t i = 0; i < size; i++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i]);
}
)XXX",
"main", "true");
static uint64_t t_arg[] = {0xFFFFFFFFFFFFFFFF,
0,
8978,
2587490,
90,
197864,
698735,
72132,
87474,
42};
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
}
TEST(CompileAndRunTensorStd, extract_32) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi32>, %i: index) -> i32{
%c = tensor.extract %t[%i] : tensor<10xi32>
return %c : i32
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint32_t t_arg[size]{0xFFFFFFFF, 0, 8978, 2587490, 90,
197864, 698735, 72132, 87474, 42};
for (size_t i = 0; i < size; i++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i]);
)XXX",
"main", "true");
static uint32_t t_arg[] = {0xFFFFFFFF, 0, 8978, 2587490, 90,
197864, 698735, 72132, 87474, 42};
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
}
// Same as `CompileAndRunTensorStd::extract_32` above, but using
// `LambdaArgument` instances
TEST(CompileAndRunTensorStd, extract_32_lambda_argument) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi32>, %i: index) -> i32{
%c = tensor.extract %t[%i] : tensor<10xi32>
return %c : i32
}
)XXX",
"main", "true");
static std::vector<uint32_t> t_arg{0xFFFFFFFF, 0, 8978, 2587490, 90,
197864, 698735, 72132, 87474, 42};
mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint32_t>>
tla(t_arg);
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++) {
mlir::zamalang::IntLambdaArgument<size_t> idx(i);
ASSERT_EXPECTED_VALUE(lambda({&tla, &idx}), t_arg[i]);
}
}
TEST(CompileAndRunTensorStd, extract_16) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi16>, %i: index) -> i16{
%c = tensor.extract %t[%i] : tensor<10xi16>
return %c : i16
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint16_t t_arg[size]{0xFFFF, 0, 59589, 47826, 16227,
63269, 36435, 52380, 7401, 13313};
for (size_t i = 0; i < size; i++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i]);
}
)XXX",
"main", "true");
uint16_t t_arg[] = {0xFFFF, 0, 59589, 47826, 16227,
63269, 36435, 52380, 7401, 13313};
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
}
TEST(CompileAndRunTensorStd, extract_8) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi8>, %i: index) -> i8{
%c = tensor.extract %t[%i] : tensor<10xi8>
return %c : i8
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint8_t t_arg[size]{0xFF, 0, 120, 225, 14, 177, 131, 84, 174, 93};
for (size_t i = 0; i < size; i++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i]);
}
)XXX",
"main", "true");
static uint8_t t_arg[] = {0xFF, 0, 120, 225, 14, 177, 131, 84, 174, 93};
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
}
TEST(CompileAndRunTensorStd, extract_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi5>, %i: index) -> i5{
%c = tensor.extract %t[%i] : tensor<10xi5>
return %c : i5
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
for (size_t i = 0; i < size; i++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i]);
}
)XXX",
"main", "true");
static uint8_t t_arg[] = {32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
}
TEST(CompileAndRunTensorStd, extract_1) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi1>, %i: index) -> i1{
%c = tensor.extract %t[%i] : tensor<10xi1>
return %c : i1
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint8_t t_arg[size]{0, 0, 1, 0, 1, 1, 0, 1, 1, 0};
for (size_t i = 0; i < size; i++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i]);
}
)XXX",
"main", "true");
static uint8_t t_arg[] = {0, 0, 1, 0, 1, 1, 0, 1, 1, 0};
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
}
TEST(CompileAndRunTensorEncrypted, extract_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index) -> !HLFHE.eint<5>{
%c = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>>
return %c : !HLFHE.eint<5>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
for (size_t i = 0; i < size; i++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i]);
}
)XXX");
static uint8_t t_arg[] = {32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
}
TEST(CompileAndRunTensorEncrypted, extract_twice_and_add_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index, %j: index) -> !HLFHE.eint<5>{
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index, %j: index) ->
!HLFHE.eint<5>{
%ti = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>>
%tj = tensor.extract %t[%j] : tensor<10x!HLFHE.eint<5>>
%c = "HLFHE.add_eint"(%ti, %tj) : (!HLFHE.eint<5>, !HLFHE.eint<5>) -> !HLFHE.eint<5>
return %c : !HLFHE.eint<5>
%c = "HLFHE.add_eint"(%ti, %tj) : (!HLFHE.eint<5>, !HLFHE.eint<5>) ->
!HLFHE.eint<5> return %c : !HLFHE.eint<5>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
for (size_t i = 0; i < size; i++) {
for (size_t j = 0; j < size; j++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Set the %j argument
ASSERT_LLVM_ERROR(argument->setArg(2, j));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i] + t_arg[j]);
}
}
)XXX");
static uint8_t t_arg[] = {3, 0, 7, 12, 14, 6, 5, 4, 1, 2};
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
for (size_t j = 0; j < ARRAY_SIZE(t_arg); j++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i, j),
t_arg[i] + t_arg[j]);
}
TEST(CompileAndRunTensorEncrypted, dim_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10x!HLFHE.eint<5>>) -> index{
%c0 = constant 0 : index
%c = tensor.dim %t, %c0 : tensor<10x!HLFHE.eint<5>>
return %c : index
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, size);
)XXX");
static uint8_t t_arg[] = {32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg)), ARRAY_SIZE(t_arg));
}
TEST(CompileAndRunTensorEncrypted, from_elements_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%0: !HLFHE.eint<5>) -> tensor<1x!HLFHE.eint<5>> {
%t = tensor.from_elements %0 : tensor<1x!HLFHE.eint<5>>
return %t: tensor<1x!HLFHE.eint<5>>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, 10));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
size_t size_res = 1;
uint64_t t_res[size_res];
ASSERT_LLVM_ERROR(argument->getResult(0, t_res, size_res));
ASSERT_EQ(t_res[0], 10);
)XXX");
llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>(10_u64);
ASSERT_EXPECTED_SUCCESS(res);
ASSERT_EQ(res->size(), (size_t)1);
ASSERT_EQ(res->at(0), 10_u64);
}
TEST(CompileAndRunTensorEncrypted, in_out_tensor_with_op_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%in: tensor<2x!HLFHE.eint<5>>) -> tensor<3x!HLFHE.eint<5>> {
%c_0 = constant 0 : index
%c_1 = constant 1 : index
%a = tensor.extract %in[%c_0] : tensor<2x!HLFHE.eint<5>>
%b = tensor.extract %in[%c_1] : tensor<2x!HLFHE.eint<5>>
%aplusa = "HLFHE.add_eint"(%a, %a): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>)
%aplusb = "HLFHE.add_eint"(%a, %b): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>)
%bplusb = "HLFHE.add_eint"(%b, %b): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>)
%out = tensor.from_elements %aplusa, %aplusb, %bplusb : tensor<3x!HLFHE.eint<5>>
%aplusa = "HLFHE.add_eint"(%a, %a): (!HLFHE.eint<5>, !HLFHE.eint<5>) ->
(!HLFHE.eint<5>) %aplusb = "HLFHE.add_eint"(%a, %b): (!HLFHE.eint<5>,
!HLFHE.eint<5>) -> (!HLFHE.eint<5>) %bplusb = "HLFHE.add_eint"(%b, %b):
(!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>) %out =
tensor.from_elements %aplusa, %aplusb, %bplusb : tensor<3x!HLFHE.eint<5>>
return %out: tensor<3x!HLFHE.eint<5>>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the argument
const size_t in_size = 2;
uint8_t in[in_size] = {2, 16};
ASSERT_LLVM_ERROR(argument->setArg(0, in, in_size));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
const size_t size_res = 3;
uint64_t t_res[size_res];
ASSERT_LLVM_ERROR(argument->getResult(0, t_res, size_res));
ASSERT_EQ(t_res[0], in[0] + in[0]);
ASSERT_EQ(t_res[1], in[0] + in[1]);
ASSERT_EQ(t_res[2], in[1] + in[1]);
)XXX");
static uint8_t in[] = {2, 16};
llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>(in, ARRAY_SIZE(in));
ASSERT_EXPECTED_SUCCESS(res);
ASSERT_EQ(res->size(), (size_t)3);
ASSERT_EQ(res->at(0), (uint64_t)(in[0] + in[0]));
ASSERT_EQ(res->at(1), (uint64_t)(in[0] + in[1]));
ASSERT_EQ(res->at(2), (uint64_t)(in[1] + in[1]));
}
TEST(CompileAndRunTensorEncrypted, linalg_generic) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
#map0 = affine_map<(d0) -> (d0)>
#map1 = affine_map<(d0) -> (0)>
func @main(%arg0: tensor<2x!HLFHE.eint<7>>, %arg1: tensor<2xi8>, %acc: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
func @main(%arg0: tensor<2x!HLFHE.eint<7>>, %arg1: tensor<2xi8>, %acc:
!HLFHE.eint<7>) -> !HLFHE.eint<7> {
%tacc = tensor.from_elements %acc : tensor<1x!HLFHE.eint<7>>
%2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<7>>, tensor<2xi8>) outs(%tacc : tensor<1x!HLFHE.eint<7>>) {
^bb0(%arg2: !HLFHE.eint<7>, %arg3: i8, %arg4: !HLFHE.eint<7>): // no predecessors
%4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<7>, i8) -> !HLFHE.eint<7>
%5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<7>, !HLFHE.eint<7>) -> !HLFHE.eint<7>
linalg.yield %5 : !HLFHE.eint<7>
%2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types
= ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<7>>, tensor<2xi8>)
outs(%tacc : tensor<1x!HLFHE.eint<7>>) { ^bb0(%arg2: !HLFHE.eint<7>, %arg3:
i8, %arg4: !HLFHE.eint<7>): // no predecessors
%4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<7>, i8) ->
!HLFHE.eint<7> %5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<7>,
!HLFHE.eint<7>) -> !HLFHE.eint<7> linalg.yield %5 : !HLFHE.eint<7>
} -> tensor<1x!HLFHE.eint<7>>
%c0 = constant 0 : index
%ret = tensor.extract %2[%c0] : tensor<1x!HLFHE.eint<7>>
return %ret : !HLFHE.eint<7>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set arg0, arg1, acc
const size_t in_size = 2;
uint8_t arg0[in_size] = {2, 8};
ASSERT_LLVM_ERROR(argument->setArg(0, arg0, in_size));
uint8_t arg1[in_size] = {6, 8};
ASSERT_LLVM_ERROR(argument->setArg(1, arg1, in_size));
ASSERT_LLVM_ERROR(argument->setArg(2, 0));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, 76);
)XXX",
"main", "true");
static uint8_t arg0[] = {2, 8};
static uint8_t arg1[] = {6, 8};
llvm::Expected<uint64_t> res =
lambda(arg0, ARRAY_SIZE(arg0), arg1, ARRAY_SIZE(arg1), 0_u64);
ASSERT_EXPECTED_VALUE(res, 76);
}
TEST(CompileAndRunTensorEncrypted, dot_eint_int_7) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: tensor<4x!HLFHE.eint<7>>,
%arg1: tensor<4xi8>) -> !HLFHE.eint<7>
{
@@ -395,77 +419,70 @@ func @main(%arg0: tensor<4x!HLFHE.eint<7>>,
(tensor<4x!HLFHE.eint<7>>, tensor<4xi8>) -> !HLFHE.eint<7>
return %ret : !HLFHE.eint<7>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set arg0, arg1, acc
const size_t in_size = 4;
uint8_t arg0[in_size] = {0, 1, 2, 3};
ASSERT_LLVM_ERROR(argument->setArg(0, arg0, in_size));
uint8_t arg1[in_size] = {0, 1, 2, 3};
ASSERT_LLVM_ERROR(argument->setArg(1, arg1, in_size));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, 14);
)XXX");
static uint8_t arg0[] = {0, 1, 2, 3};
static uint8_t arg1[] = {0, 1, 2, 3};
llvm::Expected<uint64_t> res =
lambda(arg0, ARRAY_SIZE(arg0), arg1, ARRAY_SIZE(arg1));
ASSERT_EXPECTED_VALUE(res, 14);
}
class CompileAndRunWithPrecision : public ::testing::TestWithParam<int> {
protected:
mlir::zamalang::CompilerEngine engine;
void compile(std::string mlirStr) { ASSERT_FALSE(engine.compile(mlirStr)); }
void run(std::vector<uint64_t> args, uint64_t expected) {
auto maybeResult = engine.run(args);
ASSERT_TRUE((bool)maybeResult);
uint64_t result = maybeResult.get();
if (result == expected) {
ASSERT_TRUE(true);
} else {
// TODO: Better way to test the probability of exactness
llvm::errs() << "one fail retry\n";
maybeResult = engine.run(args);
ASSERT_TRUE((bool)maybeResult);
result = maybeResult.get();
ASSERT_EQ(result, expected);
}
}
};
class CompileAndRunWithPrecision : public ::testing::TestWithParam<int> {};
TEST_P(CompileAndRunWithPrecision, identity_func) {
int precision = GetParam();
uint64_t precision = GetParam();
std::ostringstream mlirProgram;
auto sizeOfTLU = 1 << precision;
mlirProgram << "func @main(%arg0: !HLFHE.eint<" << precision
<< ">) -> !HLFHE.eint<" << precision << "> { \n";
mlirProgram << " %tlu = std.constant dense<[0";
for (auto i = 1; i < sizeOfTLU; i++) {
mlirProgram << ", " << i;
}
mlirProgram << "]> : tensor<" << sizeOfTLU << "xi64>\n";
mlirProgram << " %1 = \"HLFHE.apply_lookup_table\"(%arg0, %tlu): "
"(!HLFHE.eint<"
<< precision << ">, tensor<" << sizeOfTLU
<< "xi64>) -> (!HLFHE.eint<" << precision << ">)\n ";
mlirProgram << "return %1: !HLFHE.eint<" << precision << ">\n";
uint64_t sizeOfTLU = 1 << precision;
mlirProgram << "}\n";
llvm::errs() << mlirProgram.str();
compile(mlirProgram.str());
for (auto i = 0; i < sizeOfTLU; i++) {
run({(uint64_t)i}, i);
mlirProgram << "func @main(%arg0: !HLFHE.eint<" << precision
<< ">) -> !HLFHE.eint<" << precision << "> { \n"
<< " %tlu = std.constant dense<[0";
for (uint64_t i = 1; i < sizeOfTLU; i++)
mlirProgram << ", " << i;
mlirProgram << "]> : tensor<" << sizeOfTLU << "xi64>\n"
<< " %1 = \"HLFHE.apply_lookup_table\"(%arg0, %tlu): "
<< "(!HLFHE.eint<" << precision << ">, tensor<" << sizeOfTLU
<< "xi64>) -> (!HLFHE.eint<" << precision << ">)\n "
<< "return %1: !HLFHE.eint<" << precision << ">\n"
<< "}\n";
mlir::zamalang::JitCompilerEngine::Lambda lambda =
checkedJit(mlirProgram.str());
if (precision == 7) {
// Test fails with a probability of 5% for a precision of 7. The
// probability of the test failing 5 times in a row is .05^5,
// which is less than 1:10,000 and comparable to the probability
// of failure for the other values.
static const int max_tries = 3;
for (uint64_t i = 0; i < sizeOfTLU; i++) {
for (int retry = 0; retry <= max_tries; retry++) {
if (retry == max_tries)
GTEST_FATAL_FAILURE_("Maximum number of tries exceeded");
llvm::Expected<uint64_t> val = lambda(i);
ASSERT_EXPECTED_SUCCESS(val);
if (*val == i)
break;
}
}
} else {
for (uint64_t i = 0; i < sizeOfTLU; i++)
ASSERT_EXPECTED_VALUE(lambda(i), i);
}
}
INSTANTIATE_TEST_CASE_P(TestHLFHEApplyLookupTable, CompileAndRunWithPrecision,
::testing::Values(1, 2, 3, 4, 5, 6, 7));
INSTANTIATE_TEST_SUITE_P(TestHLFHEApplyLookupTable, CompileAndRunWithPrecision,
::testing::Values(1, 2, 3, 4, 5, 6, 7));
TEST(TestHLFHEApplyLookupTable, multiple_precision) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: !HLFHE.eint<6>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<6> {
%tlu_7 = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]> : tensor<64xi64>
%tlu_3 = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>
@@ -474,45 +491,22 @@ func @main(%arg0: !HLFHE.eint<6>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<6> {
%a_plus_b = "HLFHE.add_eint"(%a, %b): (!HLFHE.eint<6>, !HLFHE.eint<6>) -> (!HLFHE.eint<6>)
return %a_plus_b: !HLFHE.eint<6>
}
)XXX";
ASSERT_FALSE(engine.compile(mlirStr));
uint64_t arg0 = 23;
uint64_t arg1 = 7;
uint64_t expected = 30;
auto maybeResult = engine.run({arg0, arg1});
ASSERT_TRUE((bool)maybeResult);
uint64_t result = maybeResult.get();
ASSERT_EQ(result, expected);
)XXX");
ASSERT_EXPECTED_VALUE(lambda(23_u64, 7_u64), 30);
}
TEST(CompileAndRunTLU, random_func) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: !HLFHE.eint<6>) -> !HLFHE.eint<6> {
%tlu = std.constant dense<[16, 91, 16, 83, 80, 74, 21, 96, 1, 63, 49, 122, 76, 89, 74, 55, 109, 110, 103, 54, 105, 14, 66, 47, 52, 89, 7, 10, 73, 44, 119, 92, 25, 104, 123, 100, 108, 86, 29, 121, 118, 52, 107, 48, 34, 37, 13, 122, 107, 48, 74, 59, 96, 36, 50, 55, 120, 72, 27, 45, 12, 5, 96, 12]> : tensor<64xi64>
%1 = "HLFHE.apply_lookup_table"(%arg0, %tlu): (!HLFHE.eint<6>, tensor<64xi64>) -> (!HLFHE.eint<6>)
return %1: !HLFHE.eint<6>
}
)XXX";
ASSERT_FALSE(engine.compile(mlirStr));
// first value
auto maybeResult = engine.run({5});
ASSERT_TRUE((bool)maybeResult);
uint64_t result = maybeResult.get();
ASSERT_EQ(result, 74);
// second value
maybeResult = engine.run({62});
ASSERT_TRUE((bool)maybeResult);
result = maybeResult.get();
ASSERT_EQ(result, 96);
// edge value low
maybeResult = engine.run({0});
ASSERT_TRUE((bool)maybeResult);
result = maybeResult.get();
ASSERT_EQ(result, 16);
// edge value high
maybeResult = engine.run({63});
ASSERT_TRUE((bool)maybeResult);
result = maybeResult.get();
ASSERT_EQ(result, 12);
)XXX");
ASSERT_EXPECTED_VALUE(lambda(5_u64), 74);
ASSERT_EXPECTED_VALUE(lambda(62_u64), 96);
ASSERT_EXPECTED_VALUE(lambda(0_u64), 16);
ASSERT_EXPECTED_VALUE(lambda(63_u64), 12);
}