cleanup(python): Cleanup python bindings after refactoring of CompilerEngine

This commit is contained in:
Quentin Bourgerie
2022-03-16 14:39:31 +01:00
parent 8867d313ee
commit 1b984f5119
2 changed files with 43 additions and 130 deletions

View File

@@ -3,7 +3,7 @@ import tempfile
import pytest
import numpy as np
from concrete.compiler import CompilerEngine, library
from concrete.compiler import CompilerEngine
from lib.Bindings.Python.concrete.compiler import JITCompilerSupport, LibraryCompilerSupport
from lib.Bindings.Python.concrete.compiler import ClientSupport
from lib.Bindings.Python.concrete.compiler import KeySetCache
@@ -313,7 +313,7 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args):
engine.run(*args)
@pytest.mark.parametrize(
@ pytest.mark.parametrize(
"mlir_input, args, expected_result, tab_size",
[
pytest.param(
@@ -338,7 +338,7 @@ def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
assert abs(engine.run(*args) - expected_result) / tab_size < 0.1
@pytest.mark.parametrize(
@ pytest.mark.parametrize(
"mlir_input",
[
pytest.param(
@@ -359,55 +359,3 @@ def test_compile_invalid(mlir_input):
with pytest.raises(RuntimeError, match=r"cannot find the function for generate client parameters"):
engine.compile_fhe(
mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
MODULE_1 = """
func @test1()
{
return
}
"""
MODULE_2 = """
func @test2()
{
return
}
"""
LIB_PATH = './test_library_generation.so'
@pytest.mark.parametrize(
'mlir_modules',
[
pytest.param(MODULE_1, id='1 module'),
pytest.param([MODULE_1, MODULE_2], id='2 modules'),
pytest.param(iter([MODULE_1, MODULE_2]), id='iterable'),
],
)
def test_library_generation(mlir_modules):
library_path = library(LIB_PATH, mlir_modules)
assert os.path.exists(library_path)
@pytest.mark.parametrize(
'mlir_modules',
[
pytest.param(bytes(MODULE_1, encoding='utf-8'), id='bytes vs str'),
pytest.param(None, id='not iterable'),
pytest.param([None], id='not str'),
],
)
def test_library_generation_type_error(mlir_modules):
with pytest.raises(TypeError):
library(LIB_PATH, mlir_modules)
def test_library_call():
module = """
func @test(%a: i8) -> i8
{
return %a : i8
}
"""
from ctypes import CDLL
lib = CDLL(library(LIB_PATH, module))
assert lib.test(13) == 13