feat(python): get path to diff artifacts

- path to client parameters file
- path to shared library
This commit is contained in:
youben11
2022-04-29 15:09:33 +01:00
committed by Ayoub Benaissa
parent 211241fcb2
commit 5f1a539505
7 changed files with 79 additions and 4 deletions

View File

@@ -1,4 +1,6 @@
import pytest
import os.path
import shutil
import numpy as np
from concrete.compiler import (
JITSupport,
@@ -167,19 +169,40 @@ def test_jit_compile_and_run(mlir_input, args, expected_result, keyset_cache):
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_lib_compile_and_run(mlir_input, args, expected_result, keyset_cache):
engine = LibrarySupport.new("./py_test_lib_compile_and_run")
artifact_dir = "./py_test_lib_compile_and_run"
engine = LibrarySupport.new(artifact_dir)
compile_run_assert(engine, mlir_input, args, expected_result, keyset_cache)
shutil.rmtree(artifact_dir)
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_lib_compile_reload_and_run(mlir_input, args, expected_result, keyset_cache):
engine = LibrarySupport.new("./test_lib_compile_reload_and_run")
artifact_dir = "./test_lib_compile_reload_and_run"
engine = LibrarySupport.new(artifact_dir)
# Here don't save compilation result, reload
engine.compile(mlir_input)
compilation_result = engine.reload()
result = run(engine, args, compilation_result, keyset_cache)
# Check result
assert_result(result, expected_result)
shutil.rmtree(artifact_dir)
def test_lib_compilation_artifacts():
mlir_str = """
func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!FHE.eint<6>> {
%res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4xi7>) -> tensor<4x!FHE.eint<6>>
return %res : tensor<4x!FHE.eint<6>>
}
"""
artifact_dir = "./test_artifacts"
engine = LibrarySupport.new(artifact_dir)
engine.compile(mlir_str)
assert os.path.exists(engine.get_client_parameters_path())
assert os.path.exists(engine.get_shared_lib_path())
shutil.rmtree(artifact_dir)
assert not os.path.exists(engine.get_client_parameters_path())
assert not os.path.exists(engine.get_shared_lib_path())
def test_lib_compile_and_run_p_error(keyset_cache):

View File

@@ -1,4 +1,5 @@
import pytest
import shutil
import numpy as np
from concrete.compiler import (
JITSupport,
@@ -143,7 +144,9 @@ def test_jit_compile_and_run_with_serialization(
def test_lib_compile_and_run_with_serialization(
mlir_input, args, expected_result, keyset_cache
):
engine = LibrarySupport.new("./py_test_lib_compile_and_run")
artifact_dir = "./py_test_lib_compile_and_run"
engine = LibrarySupport.new(artifact_dir)
compile_run_assert_with_serialization(
engine, mlir_input, args, expected_result, keyset_cache
)
shutil.rmtree(artifact_dir)