mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
refactor(compiler): Refactor CompilerEngine and related classes
This commit contains several incremental improvements towards a clear
interface for lambdas:
- Unification of static and JIT compilation by using the static
compilation path of `CompilerEngine` within a new subclass
`JitCompilerEngine`.
- Clear ownership for compilation artefacts through
`CompilationContext`, making it impossible to destroy objects used
directly or indirectly before destruction of their users.
- Clear interface for lambdas generated by the compiler through
`JitCompilerEngine::Lambda` with a templated call operator,
encapsulating otherwise manual orchestration of `CompilerEngine`,
`JITLambda`, and `CompilerEngine::Argument`.
- Improved error handling through `llvm::Expected<T>` and proper
error checking following the conventions for `llvm::Expected<T>`
and `llvm::Error`.
Co-authored-by: youben11 <ayoub.benaissa@zama.ai>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @add_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>, %arg1: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
|
||||
func @add_eint(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @add_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
|
||||
func @add_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{_,_,_}{2}>, %arg1: tensor<4xi64>) -> !MidLFHE.glwe<{_,_,_}{2}>
|
||||
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<4xi64>) -> !HLFHE.eint<2> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
|
||||
func @apply_lookup_table_cst(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler %s --action=dump-midlfhe --assume-max-manp=10 --assume-max-eint-precision=2 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: #map0 = affine_map<(d0) -> (d0)>
|
||||
// CHECK-NEXT: #map1 = affine_map<(d0) -> (0)>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @mul_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
|
||||
func @mul_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @sub_int_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
|
||||
func @sub_int_eint(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=dump-std %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: module
|
||||
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=dump-std %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: module
|
||||
// CHECK-NEXT: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=dump-std %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: module
|
||||
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @add_glwe(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7>
|
||||
func @add_glwe(%arg0: !MidLFHE.glwe<{2048,1,64}{7}>, %arg1: !MidLFHE.glwe<{2048,1,64}{7}>) -> !MidLFHE.glwe<{2048,1,64}{7}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @add_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7>
|
||||
func @add_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @apply_lookup_table(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi64>) -> !MidLFHE.glwe<{1024,1,64}{4}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !LowLFHE.lwe_ciphertext<2048,4>) -> !LowLFHE.lwe_ciphertext<2048,4>
|
||||
func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{2048,1,64}{4}>) -> !MidLFHE.glwe<{2048,1,64}{4}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @mul_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7>
|
||||
func @mul_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @sub_const_int_glwe(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7>
|
||||
func @sub_const_int_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --split-input-file --entry-dialect=hlfhe --action=dump-hlfhe-manp %s 2>&1 | FileCheck %s
|
||||
// RUN: zamacompiler --split-input-file --action=dump-hlfhe-manp %s 2>&1 | FileCheck %s
|
||||
|
||||
func @single_zero() -> !HLFHE.eint<2>
|
||||
{
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=hlfhe --action=roundtrip %s
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
|
||||
|
||||
// Incompatible shapes
|
||||
func @dot_incompatible_shapes(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: eint support only precision in ]0;7]
|
||||
func @test(%arg0: !HLFHE.eint<8>) {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: eint support only precision in ]0;7]
|
||||
func @test(%arg0: !HLFHE.eint<0>) {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: 'HLFHE.add_eint' op should have the width of encrypted inputs equals
|
||||
func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<2> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: 'HLFHE.add_eint' op should have the width of encrypted inputs and result equals
|
||||
func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<2>) -> !HLFHE.eint<3> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: 'HLFHE.add_eint_int' op should have the width of plain input equals to width of encrypted input + 1
|
||||
func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: 'HLFHE.add_eint_int' op should have the width of encrypted inputs and result equals
|
||||
func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: 'HLFHE.apply_lookup_table' op should have as `l_cst` argument a shape of one dimension equals to 2^p, where p is the width of the `ct` argument.
|
||||
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<8xi3>) -> !HLFHE.eint<2> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: 'HLFHE.mul_eint_int' op should have the width of plain input equals to width of encrypted input + 1
|
||||
func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: 'HLFHE.mul_eint_int' op should have the width of encrypted inputs and result equals
|
||||
func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: 'HLFHE.sub_int_eint' op should have the width of plain input equals to width of encrypted input + 1
|
||||
func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: 'HLFHE.sub_int_eint' op should have the width of encrypted inputs and result equals
|
||||
func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @zero() -> !HLFHE.eint<2>
|
||||
func @zero() -> !HLFHE.eint<2> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1 | FileCheck %s
|
||||
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1 | FileCheck %s
|
||||
|
||||
//CHECK: #map0 = affine_map<(d0) -> (d0)>
|
||||
//CHECK-NEXT: #map1 = affine_map<(d0) -> (0)>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>>
|
||||
func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>>) {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=lowlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @add_lwe_ciphertexts(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7>
|
||||
func @add_lwe_ciphertexts(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=lowlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @type_enc_rand_gen(%arg0: !LowLFHE.enc_rand_gen) -> !LowLFHE.enc_rand_gen
|
||||
func @type_enc_rand_gen(%arg0: !LowLFHE.enc_rand_gen) -> !LowLFHE.enc_rand_gen {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
|
||||
|
||||
// GLWE p parameter result
|
||||
func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
|
||||
func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
|
||||
|
||||
// GLWE p parameter
|
||||
func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
|
||||
func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
|
||||
|
||||
// Bad dimension of the lookup table
|
||||
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<4xi2>) -> !MidLFHE.glwe<{512,10,64}{2}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<128xi64>) -> !MidLFHE.glwe<{512,10,64}{2}>
|
||||
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<128xi64>) -> !MidLFHE.glwe<{512,10,64}{2}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
|
||||
|
||||
// GLWE p parameter
|
||||
func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
|
||||
func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
|
||||
|
||||
// GLWE p parameter
|
||||
func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
|
||||
func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler %s --entry-dialect=midlfhe --action=roundtrip 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler %s --action=roundtrip 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
|
||||
func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {
|
||||
|
||||
@@ -56,7 +56,7 @@ def test_compile_and_run(mlir_input, args, expected_result):
|
||||
def test_compile_and_run_invalid_arg_number(mlir_input, args):
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(mlir_input)
|
||||
with pytest.raises(RuntimeError, match=r"failed pushing integer argument"):
|
||||
with pytest.raises(ValueError, match=r"wrong number of arguments"):
|
||||
engine.run(*args)
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
#include <cstdint>
|
||||
#include <gtest/gtest.h>
|
||||
#include <type_traits>
|
||||
|
||||
#include "zamalang/Support/CompilerEngine.h"
|
||||
#include "zamalang/Support/JitCompilerEngine.h"
|
||||
|
||||
mlir::zamalang::V0FHEConstraint defaultV0Constraints = {.norm2 = 10, .p = 7};
|
||||
mlir::zamalang::V0FHEConstraint defaultV0Constraints = {10, 7};
|
||||
|
||||
#define ASSERT_LLVM_ERROR(err) \
|
||||
if (err) { \
|
||||
@@ -10,384 +13,405 @@ mlir::zamalang::V0FHEConstraint defaultV0Constraints = {.norm2 = 10, .p = 7};
|
||||
ASSERT_TRUE(false); \
|
||||
}
|
||||
|
||||
// Checks that the value `val` is not in an error state. Returns
|
||||
// `true` if the test passes, otherwise `false`.
|
||||
template <typename T>
|
||||
static bool assert_expected_success(llvm::Expected<T> &val) {
|
||||
if (!((bool)val)) {
|
||||
llvm::errs() << llvm::toString(std::move(val.takeError()));
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Checks that the value `val` is not in an error state. Returns
|
||||
// `true` if the test passes, otherwise `false`.
|
||||
template <typename T>
|
||||
static bool assert_expected_success(llvm::Expected<T> &&val) {
|
||||
return assert_expected_success(val);
|
||||
}
|
||||
|
||||
// Checks that the value `val` of type `llvm::Expected<T>` is not in
|
||||
// an error state.
|
||||
#define ASSERT_EXPECTED_SUCCESS(val) \
|
||||
do { \
|
||||
if (!assert_expected_success(val)) \
|
||||
GTEST_FATAL_FAILURE_("Expected<T> contained in error state"); \
|
||||
} while (0)
|
||||
|
||||
// Checks that the value `val` is not in an error state and is equal
|
||||
// to the value given in `exp`. Returns `true` if the test passes,
|
||||
// otherwise `false`.
|
||||
template <typename T, typename V>
|
||||
static bool assert_expected_value(llvm::Expected<T> &val, const V &exp) {
|
||||
if (!assert_expected_success(val))
|
||||
return false;
|
||||
|
||||
if (!(val.get() == static_cast<T>(exp))) {
|
||||
llvm::errs() << "Expected value " << exp << ", but got " << val.get()
|
||||
<< "\n";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Checks that the value `val` is not in an error state and is equal
|
||||
// to the value given in `exp`. Returns `true` if the test passes,
|
||||
// otherwise `false`.
|
||||
template <typename T, typename V>
|
||||
static bool assert_expected_value(llvm::Expected<T> &&val, const V &exp) {
|
||||
return assert_expected_value(val, exp);
|
||||
}
|
||||
|
||||
// Checks that the value `val` of type `llvm::Expected<T>` is not in
|
||||
// an error state and is equal to the value of type `T` given in
|
||||
// `exp`.
|
||||
#define ASSERT_EXPECTED_VALUE(val, exp) \
|
||||
do { \
|
||||
if (!assert_expected_value(val, exp)) { \
|
||||
GTEST_FATAL_FAILURE_("Expected<T> with wrong value"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// Jit-compiles the function specified by `func` from `src` and
|
||||
// returns the corresponding lambda. Any compilation errors are caught
|
||||
// and reult in abnormal termination.
|
||||
template <typename F>
|
||||
mlir::zamalang::JitCompilerEngine::Lambda
|
||||
internalCheckedJit(F checkfunc, llvm::StringRef src,
|
||||
llvm::StringRef func = "main",
|
||||
bool useDefaultFHEConstraints = false) {
|
||||
mlir::zamalang::JitCompilerEngine engine;
|
||||
|
||||
if (useDefaultFHEConstraints)
|
||||
engine.setFHEConstraints(defaultV0Constraints);
|
||||
|
||||
llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
|
||||
engine.buildLambda(src, func);
|
||||
|
||||
checkfunc(lambdaOrErr);
|
||||
|
||||
return std::move(*lambdaOrErr);
|
||||
}
|
||||
|
||||
// Shorthands to create integer literals of a specific type
|
||||
uint8_t operator"" _u8(unsigned long long int v) { return v; }
|
||||
uint16_t operator"" _u16(unsigned long long int v) { return v; }
|
||||
uint32_t operator"" _u32(unsigned long long int v) { return v; }
|
||||
uint64_t operator"" _u64(unsigned long long int v) { return v; }
|
||||
|
||||
// Evaluates to the number of elements of a statically initialized
|
||||
// array
|
||||
#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof(arr[0]))
|
||||
|
||||
// Wrapper around `internalCheckedJit` that causes
|
||||
// `ASSERT_EXPECTED_SUCCESS` to use the file and line number of the
|
||||
// caller instead of `internalCheckedJit`.
|
||||
#define checkedJit(...) \
|
||||
internalCheckedJit( \
|
||||
[](llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> &lambda) { \
|
||||
ASSERT_EXPECTED_SUCCESS(lambda); \
|
||||
}, \
|
||||
__VA_ARGS__)
|
||||
|
||||
TEST(CompileAndRunHLFHE, add_eint) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
return %1: !HLFHE.eint<7>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_FALSE(engine.compile(mlirStr));
|
||||
auto maybeResult = engine.run({1, 2});
|
||||
ASSERT_TRUE((bool)maybeResult);
|
||||
uint64_t result = maybeResult.get();
|
||||
ASSERT_EQ(result, 3);
|
||||
)XXX");
|
||||
|
||||
ASSERT_EXPECTED_VALUE(lambda(1_u64, 2_u64), 3);
|
||||
ASSERT_EXPECTED_VALUE(lambda(4_u64, 5_u64), 9);
|
||||
ASSERT_EXPECTED_VALUE(lambda(1_u64, 1_u64), 2);
|
||||
}
|
||||
|
||||
// Same as CompileAndRunHLFHE::add_eint above, but using
|
||||
// `LambdaArgument` instances
|
||||
TEST(CompileAndRunHLFHE, add_eint_lambda_argument) {
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
return %1: !HLFHE.eint<7>
|
||||
}
|
||||
)XXX");
|
||||
|
||||
mlir::zamalang::IntLambdaArgument<> ila1(1);
|
||||
mlir::zamalang::IntLambdaArgument<> ila2(2);
|
||||
mlir::zamalang::IntLambdaArgument<> ila7(7);
|
||||
mlir::zamalang::IntLambdaArgument<> ila9(9);
|
||||
|
||||
ASSERT_EXPECTED_VALUE(lambda({&ila1, &ila2}), 3);
|
||||
ASSERT_EXPECTED_VALUE(lambda({&ila7, &ila9}), 16);
|
||||
ASSERT_EXPECTED_VALUE(lambda({&ila1, &ila7}), 8);
|
||||
ASSERT_EXPECTED_VALUE(lambda({&ila1, &ila9}), 10);
|
||||
ASSERT_EXPECTED_VALUE(lambda({&ila2, &ila7}), 9);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunHLFHE, add_u64) {
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%arg0: i64, %arg1: i64) -> i64 {
|
||||
%1 = addi %arg0, %arg1 : i64
|
||||
return %1: i64
|
||||
}
|
||||
)XXX",
|
||||
"main", true);
|
||||
|
||||
ASSERT_EXPECTED_VALUE(lambda(1_u64, 2_u64), (uint64_t)3);
|
||||
ASSERT_EXPECTED_VALUE(lambda(4_u64, 5_u64), (uint64_t)9);
|
||||
ASSERT_EXPECTED_VALUE(lambda(1_u64, 1_u64), (uint64_t)2);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorStd, extract_64) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t: tensor<10xi64>, %i: index) -> i64{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi64>
|
||||
return %c : i64
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
|
||||
const size_t size = 10;
|
||||
uint64_t t_arg[size]{0xFFFFFFFFFFFFFFFF,
|
||||
0,
|
||||
8978,
|
||||
2587490,
|
||||
90,
|
||||
197864,
|
||||
698735,
|
||||
72132,
|
||||
87474,
|
||||
42};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
}
|
||||
)XXX",
|
||||
"main", "true");
|
||||
|
||||
static uint64_t t_arg[] = {0xFFFFFFFFFFFFFFFF,
|
||||
0,
|
||||
8978,
|
||||
2587490,
|
||||
90,
|
||||
197864,
|
||||
698735,
|
||||
72132,
|
||||
87474,
|
||||
42};
|
||||
|
||||
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
|
||||
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorStd, extract_32) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t: tensor<10xi32>, %i: index) -> i32{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi32>
|
||||
return %c : i32
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
|
||||
const size_t size = 10;
|
||||
uint32_t t_arg[size]{0xFFFFFFFF, 0, 8978, 2587490, 90,
|
||||
197864, 698735, 72132, 87474, 42};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
)XXX",
|
||||
"main", "true");
|
||||
static uint32_t t_arg[] = {0xFFFFFFFF, 0, 8978, 2587490, 90,
|
||||
197864, 698735, 72132, 87474, 42};
|
||||
|
||||
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
|
||||
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
|
||||
}
|
||||
|
||||
// Same as `CompileAndRunTensorStd::extract_32` above, but using
|
||||
// `LambdaArgument` instances
|
||||
TEST(CompileAndRunTensorStd, extract_32_lambda_argument) {
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t: tensor<10xi32>, %i: index) -> i32{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi32>
|
||||
return %c : i32
|
||||
}
|
||||
)XXX",
|
||||
"main", "true");
|
||||
static std::vector<uint32_t> t_arg{0xFFFFFFFF, 0, 8978, 2587490, 90,
|
||||
197864, 698735, 72132, 87474, 42};
|
||||
|
||||
mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<uint32_t>>
|
||||
tla(t_arg);
|
||||
|
||||
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++) {
|
||||
mlir::zamalang::IntLambdaArgument<size_t> idx(i);
|
||||
ASSERT_EXPECTED_VALUE(lambda({&tla, &idx}), t_arg[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorStd, extract_16) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t: tensor<10xi16>, %i: index) -> i16{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi16>
|
||||
return %c : i16
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
|
||||
const size_t size = 10;
|
||||
uint16_t t_arg[size]{0xFFFF, 0, 59589, 47826, 16227,
|
||||
63269, 36435, 52380, 7401, 13313};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
}
|
||||
)XXX",
|
||||
"main", "true");
|
||||
|
||||
uint16_t t_arg[] = {0xFFFF, 0, 59589, 47826, 16227,
|
||||
63269, 36435, 52380, 7401, 13313};
|
||||
|
||||
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
|
||||
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorStd, extract_8) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t: tensor<10xi8>, %i: index) -> i8{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi8>
|
||||
return %c : i8
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
|
||||
const size_t size = 10;
|
||||
uint8_t t_arg[size]{0xFF, 0, 120, 225, 14, 177, 131, 84, 174, 93};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
}
|
||||
)XXX",
|
||||
"main", "true");
|
||||
|
||||
static uint8_t t_arg[] = {0xFF, 0, 120, 225, 14, 177, 131, 84, 174, 93};
|
||||
|
||||
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
|
||||
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorStd, extract_5) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t: tensor<10xi5>, %i: index) -> i5{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi5>
|
||||
return %c : i5
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
|
||||
const size_t size = 10;
|
||||
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
}
|
||||
)XXX",
|
||||
"main", "true");
|
||||
|
||||
static uint8_t t_arg[] = {32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
|
||||
|
||||
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
|
||||
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorStd, extract_1) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t: tensor<10xi1>, %i: index) -> i1{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi1>
|
||||
return %c : i1
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
|
||||
const size_t size = 10;
|
||||
uint8_t t_arg[size]{0, 0, 1, 0, 1, 1, 0, 1, 1, 0};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
}
|
||||
)XXX",
|
||||
"main", "true");
|
||||
|
||||
static uint8_t t_arg[] = {0, 0, 1, 0, 1, 1, 0, 1, 1, 0};
|
||||
|
||||
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
|
||||
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorEncrypted, extract_5) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index) -> !HLFHE.eint<5>{
|
||||
%c = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>>
|
||||
return %c : !HLFHE.eint<5>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
|
||||
const size_t size = 10;
|
||||
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
}
|
||||
)XXX");
|
||||
|
||||
static uint8_t t_arg[] = {32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
|
||||
|
||||
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
|
||||
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorEncrypted, extract_twice_and_add_5) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index, %j: index) -> !HLFHE.eint<5>{
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index, %j: index) ->
|
||||
!HLFHE.eint<5>{
|
||||
%ti = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>>
|
||||
%tj = tensor.extract %t[%j] : tensor<10x!HLFHE.eint<5>>
|
||||
%c = "HLFHE.add_eint"(%ti, %tj) : (!HLFHE.eint<5>, !HLFHE.eint<5>) -> !HLFHE.eint<5>
|
||||
return %c : !HLFHE.eint<5>
|
||||
%c = "HLFHE.add_eint"(%ti, %tj) : (!HLFHE.eint<5>, !HLFHE.eint<5>) ->
|
||||
!HLFHE.eint<5> return %c : !HLFHE.eint<5>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
|
||||
const size_t size = 10;
|
||||
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
for (size_t j = 0; j < size; j++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Set the %j argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(2, j));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i] + t_arg[j]);
|
||||
}
|
||||
}
|
||||
)XXX");
|
||||
|
||||
static uint8_t t_arg[] = {3, 0, 7, 12, 14, 6, 5, 4, 1, 2};
|
||||
|
||||
for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
|
||||
for (size_t j = 0; j < ARRAY_SIZE(t_arg); j++)
|
||||
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i, j),
|
||||
t_arg[i] + t_arg[j]);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorEncrypted, dim_5) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t: tensor<10x!HLFHE.eint<5>>) -> index{
|
||||
%c0 = constant 0 : index
|
||||
%c = tensor.dim %t, %c0 : tensor<10x!HLFHE.eint<5>>
|
||||
return %c : index
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
|
||||
const size_t size = 10;
|
||||
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, size);
|
||||
)XXX");
|
||||
|
||||
static uint8_t t_arg[] = {32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
|
||||
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg)), ARRAY_SIZE(t_arg));
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorEncrypted, from_elements_5) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%0: !HLFHE.eint<5>) -> tensor<1x!HLFHE.eint<5>> {
|
||||
%t = tensor.from_elements %0 : tensor<1x!HLFHE.eint<5>>
|
||||
return %t: tensor<1x!HLFHE.eint<5>>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, 10));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
size_t size_res = 1;
|
||||
uint64_t t_res[size_res];
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, t_res, size_res));
|
||||
ASSERT_EQ(t_res[0], 10);
|
||||
)XXX");
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>(10_u64);
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
ASSERT_EQ(res->size(), (size_t)1);
|
||||
ASSERT_EQ(res->at(0), 10_u64);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorEncrypted, in_out_tensor_with_op_5) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%in: tensor<2x!HLFHE.eint<5>>) -> tensor<3x!HLFHE.eint<5>> {
|
||||
%c_0 = constant 0 : index
|
||||
%c_1 = constant 1 : index
|
||||
%a = tensor.extract %in[%c_0] : tensor<2x!HLFHE.eint<5>>
|
||||
%b = tensor.extract %in[%c_1] : tensor<2x!HLFHE.eint<5>>
|
||||
%aplusa = "HLFHE.add_eint"(%a, %a): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>)
|
||||
%aplusb = "HLFHE.add_eint"(%a, %b): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>)
|
||||
%bplusb = "HLFHE.add_eint"(%b, %b): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>)
|
||||
%out = tensor.from_elements %aplusa, %aplusb, %bplusb : tensor<3x!HLFHE.eint<5>>
|
||||
%aplusa = "HLFHE.add_eint"(%a, %a): (!HLFHE.eint<5>, !HLFHE.eint<5>) ->
|
||||
(!HLFHE.eint<5>) %aplusb = "HLFHE.add_eint"(%a, %b): (!HLFHE.eint<5>,
|
||||
!HLFHE.eint<5>) -> (!HLFHE.eint<5>) %bplusb = "HLFHE.add_eint"(%b, %b):
|
||||
(!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>) %out =
|
||||
tensor.from_elements %aplusa, %aplusb, %bplusb : tensor<3x!HLFHE.eint<5>>
|
||||
return %out: tensor<3x!HLFHE.eint<5>>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the argument
|
||||
const size_t in_size = 2;
|
||||
uint8_t in[in_size] = {2, 16};
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, in, in_size));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
const size_t size_res = 3;
|
||||
uint64_t t_res[size_res];
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, t_res, size_res));
|
||||
ASSERT_EQ(t_res[0], in[0] + in[0]);
|
||||
ASSERT_EQ(t_res[1], in[0] + in[1]);
|
||||
ASSERT_EQ(t_res[2], in[1] + in[1]);
|
||||
)XXX");
|
||||
|
||||
static uint8_t in[] = {2, 16};
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>(in, ARRAY_SIZE(in));
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (size_t)3);
|
||||
ASSERT_EQ(res->at(0), (uint64_t)(in[0] + in[0]));
|
||||
ASSERT_EQ(res->at(1), (uint64_t)(in[0] + in[1]));
|
||||
ASSERT_EQ(res->at(2), (uint64_t)(in[1] + in[1]));
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorEncrypted, linalg_generic) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
#map0 = affine_map<(d0) -> (d0)>
|
||||
#map1 = affine_map<(d0) -> (0)>
|
||||
func @main(%arg0: tensor<2x!HLFHE.eint<7>>, %arg1: tensor<2xi8>, %acc: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
func @main(%arg0: tensor<2x!HLFHE.eint<7>>, %arg1: tensor<2xi8>, %acc:
|
||||
!HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
%tacc = tensor.from_elements %acc : tensor<1x!HLFHE.eint<7>>
|
||||
%2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<7>>, tensor<2xi8>) outs(%tacc : tensor<1x!HLFHE.eint<7>>) {
|
||||
^bb0(%arg2: !HLFHE.eint<7>, %arg3: i8, %arg4: !HLFHE.eint<7>): // no predecessors
|
||||
%4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<7>, i8) -> !HLFHE.eint<7>
|
||||
%5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<7>, !HLFHE.eint<7>) -> !HLFHE.eint<7>
|
||||
linalg.yield %5 : !HLFHE.eint<7>
|
||||
%2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types
|
||||
= ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<7>>, tensor<2xi8>)
|
||||
outs(%tacc : tensor<1x!HLFHE.eint<7>>) { ^bb0(%arg2: !HLFHE.eint<7>, %arg3:
|
||||
i8, %arg4: !HLFHE.eint<7>): // no predecessors
|
||||
%4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<7>, i8) ->
|
||||
!HLFHE.eint<7> %5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<7>,
|
||||
!HLFHE.eint<7>) -> !HLFHE.eint<7> linalg.yield %5 : !HLFHE.eint<7>
|
||||
} -> tensor<1x!HLFHE.eint<7>>
|
||||
%c0 = constant 0 : index
|
||||
%ret = tensor.extract %2[%c0] : tensor<1x!HLFHE.eint<7>>
|
||||
return %ret : !HLFHE.eint<7>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set arg0, arg1, acc
|
||||
const size_t in_size = 2;
|
||||
uint8_t arg0[in_size] = {2, 8};
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, arg0, in_size));
|
||||
uint8_t arg1[in_size] = {6, 8};
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, arg1, in_size));
|
||||
ASSERT_LLVM_ERROR(argument->setArg(2, 0));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, 76);
|
||||
)XXX",
|
||||
"main", "true");
|
||||
|
||||
static uint8_t arg0[] = {2, 8};
|
||||
static uint8_t arg1[] = {6, 8};
|
||||
|
||||
llvm::Expected<uint64_t> res =
|
||||
lambda(arg0, ARRAY_SIZE(arg0), arg1, ARRAY_SIZE(arg1), 0_u64);
|
||||
|
||||
ASSERT_EXPECTED_VALUE(res, 76);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorEncrypted, dot_eint_int_7) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%arg0: tensor<4x!HLFHE.eint<7>>,
|
||||
%arg1: tensor<4xi8>) -> !HLFHE.eint<7>
|
||||
{
|
||||
@@ -395,77 +419,70 @@ func @main(%arg0: tensor<4x!HLFHE.eint<7>>,
|
||||
(tensor<4x!HLFHE.eint<7>>, tensor<4xi8>) -> !HLFHE.eint<7>
|
||||
return %ret : !HLFHE.eint<7>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set arg0, arg1, acc
|
||||
const size_t in_size = 4;
|
||||
uint8_t arg0[in_size] = {0, 1, 2, 3};
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, arg0, in_size));
|
||||
uint8_t arg1[in_size] = {0, 1, 2, 3};
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, arg1, in_size));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, 14);
|
||||
)XXX");
|
||||
static uint8_t arg0[] = {0, 1, 2, 3};
|
||||
static uint8_t arg1[] = {0, 1, 2, 3};
|
||||
|
||||
llvm::Expected<uint64_t> res =
|
||||
lambda(arg0, ARRAY_SIZE(arg0), arg1, ARRAY_SIZE(arg1));
|
||||
|
||||
ASSERT_EXPECTED_VALUE(res, 14);
|
||||
}
|
||||
|
||||
class CompileAndRunWithPrecision : public ::testing::TestWithParam<int> {
|
||||
protected:
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
void compile(std::string mlirStr) { ASSERT_FALSE(engine.compile(mlirStr)); }
|
||||
void run(std::vector<uint64_t> args, uint64_t expected) {
|
||||
auto maybeResult = engine.run(args);
|
||||
ASSERT_TRUE((bool)maybeResult);
|
||||
uint64_t result = maybeResult.get();
|
||||
if (result == expected) {
|
||||
ASSERT_TRUE(true);
|
||||
} else {
|
||||
// TODO: Better way to test the probability of exactness
|
||||
llvm::errs() << "one fail retry\n";
|
||||
maybeResult = engine.run(args);
|
||||
ASSERT_TRUE((bool)maybeResult);
|
||||
result = maybeResult.get();
|
||||
ASSERT_EQ(result, expected);
|
||||
}
|
||||
}
|
||||
};
|
||||
class CompileAndRunWithPrecision : public ::testing::TestWithParam<int> {};
|
||||
|
||||
TEST_P(CompileAndRunWithPrecision, identity_func) {
|
||||
int precision = GetParam();
|
||||
uint64_t precision = GetParam();
|
||||
std::ostringstream mlirProgram;
|
||||
auto sizeOfTLU = 1 << precision;
|
||||
mlirProgram << "func @main(%arg0: !HLFHE.eint<" << precision
|
||||
<< ">) -> !HLFHE.eint<" << precision << "> { \n";
|
||||
mlirProgram << " %tlu = std.constant dense<[0";
|
||||
for (auto i = 1; i < sizeOfTLU; i++) {
|
||||
mlirProgram << ", " << i;
|
||||
}
|
||||
mlirProgram << "]> : tensor<" << sizeOfTLU << "xi64>\n";
|
||||
mlirProgram << " %1 = \"HLFHE.apply_lookup_table\"(%arg0, %tlu): "
|
||||
"(!HLFHE.eint<"
|
||||
<< precision << ">, tensor<" << sizeOfTLU
|
||||
<< "xi64>) -> (!HLFHE.eint<" << precision << ">)\n ";
|
||||
mlirProgram << "return %1: !HLFHE.eint<" << precision << ">\n";
|
||||
uint64_t sizeOfTLU = 1 << precision;
|
||||
|
||||
mlirProgram << "}\n";
|
||||
llvm::errs() << mlirProgram.str();
|
||||
compile(mlirProgram.str());
|
||||
for (auto i = 0; i < sizeOfTLU; i++) {
|
||||
run({(uint64_t)i}, i);
|
||||
mlirProgram << "func @main(%arg0: !HLFHE.eint<" << precision
|
||||
<< ">) -> !HLFHE.eint<" << precision << "> { \n"
|
||||
<< " %tlu = std.constant dense<[0";
|
||||
|
||||
for (uint64_t i = 1; i < sizeOfTLU; i++)
|
||||
mlirProgram << ", " << i;
|
||||
|
||||
mlirProgram << "]> : tensor<" << sizeOfTLU << "xi64>\n"
|
||||
<< " %1 = \"HLFHE.apply_lookup_table\"(%arg0, %tlu): "
|
||||
<< "(!HLFHE.eint<" << precision << ">, tensor<" << sizeOfTLU
|
||||
<< "xi64>) -> (!HLFHE.eint<" << precision << ">)\n "
|
||||
<< "return %1: !HLFHE.eint<" << precision << ">\n"
|
||||
<< "}\n";
|
||||
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda =
|
||||
checkedJit(mlirProgram.str());
|
||||
|
||||
if (precision == 7) {
|
||||
// Test fails with a probability of 5% for a precision of 7. The
|
||||
// probability of the test failing 5 times in a row is .05^5,
|
||||
// which is less than 1:10,000 and comparable to the probability
|
||||
// of failure for the other values.
|
||||
static const int max_tries = 3;
|
||||
|
||||
for (uint64_t i = 0; i < sizeOfTLU; i++) {
|
||||
for (int retry = 0; retry <= max_tries; retry++) {
|
||||
if (retry == max_tries)
|
||||
GTEST_FATAL_FAILURE_("Maximum number of tries exceeded");
|
||||
|
||||
llvm::Expected<uint64_t> val = lambda(i);
|
||||
ASSERT_EXPECTED_SUCCESS(val);
|
||||
|
||||
if (*val == i)
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint64_t i = 0; i < sizeOfTLU; i++)
|
||||
ASSERT_EXPECTED_VALUE(lambda(i), i);
|
||||
}
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(TestHLFHEApplyLookupTable, CompileAndRunWithPrecision,
|
||||
::testing::Values(1, 2, 3, 4, 5, 6, 7));
|
||||
INSTANTIATE_TEST_SUITE_P(TestHLFHEApplyLookupTable, CompileAndRunWithPrecision,
|
||||
::testing::Values(1, 2, 3, 4, 5, 6, 7));
|
||||
|
||||
TEST(TestHLFHEApplyLookupTable, multiple_precision) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%arg0: !HLFHE.eint<6>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<6> {
|
||||
%tlu_7 = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]> : tensor<64xi64>
|
||||
%tlu_3 = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>
|
||||
@@ -474,45 +491,22 @@ func @main(%arg0: !HLFHE.eint<6>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<6> {
|
||||
%a_plus_b = "HLFHE.add_eint"(%a, %b): (!HLFHE.eint<6>, !HLFHE.eint<6>) -> (!HLFHE.eint<6>)
|
||||
return %a_plus_b: !HLFHE.eint<6>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_FALSE(engine.compile(mlirStr));
|
||||
uint64_t arg0 = 23;
|
||||
uint64_t arg1 = 7;
|
||||
uint64_t expected = 30;
|
||||
auto maybeResult = engine.run({arg0, arg1});
|
||||
ASSERT_TRUE((bool)maybeResult);
|
||||
uint64_t result = maybeResult.get();
|
||||
ASSERT_EQ(result, expected);
|
||||
)XXX");
|
||||
|
||||
ASSERT_EXPECTED_VALUE(lambda(23_u64, 7_u64), 30);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTLU, random_func) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%arg0: !HLFHE.eint<6>) -> !HLFHE.eint<6> {
|
||||
%tlu = std.constant dense<[16, 91, 16, 83, 80, 74, 21, 96, 1, 63, 49, 122, 76, 89, 74, 55, 109, 110, 103, 54, 105, 14, 66, 47, 52, 89, 7, 10, 73, 44, 119, 92, 25, 104, 123, 100, 108, 86, 29, 121, 118, 52, 107, 48, 34, 37, 13, 122, 107, 48, 74, 59, 96, 36, 50, 55, 120, 72, 27, 45, 12, 5, 96, 12]> : tensor<64xi64>
|
||||
%1 = "HLFHE.apply_lookup_table"(%arg0, %tlu): (!HLFHE.eint<6>, tensor<64xi64>) -> (!HLFHE.eint<6>)
|
||||
return %1: !HLFHE.eint<6>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_FALSE(engine.compile(mlirStr));
|
||||
// first value
|
||||
auto maybeResult = engine.run({5});
|
||||
ASSERT_TRUE((bool)maybeResult);
|
||||
uint64_t result = maybeResult.get();
|
||||
ASSERT_EQ(result, 74);
|
||||
// second value
|
||||
maybeResult = engine.run({62});
|
||||
ASSERT_TRUE((bool)maybeResult);
|
||||
result = maybeResult.get();
|
||||
ASSERT_EQ(result, 96);
|
||||
// edge value low
|
||||
maybeResult = engine.run({0});
|
||||
ASSERT_TRUE((bool)maybeResult);
|
||||
result = maybeResult.get();
|
||||
ASSERT_EQ(result, 16);
|
||||
// edge value high
|
||||
maybeResult = engine.run({63});
|
||||
ASSERT_TRUE((bool)maybeResult);
|
||||
result = maybeResult.get();
|
||||
ASSERT_EQ(result, 12);
|
||||
)XXX");
|
||||
|
||||
ASSERT_EXPECTED_VALUE(lambda(5_u64), 74);
|
||||
ASSERT_EXPECTED_VALUE(lambda(62_u64), 96);
|
||||
ASSERT_EXPECTED_VALUE(lambda(0_u64), 16);
|
||||
ASSERT_EXPECTED_VALUE(lambda(63_u64), 12);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user