mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
cleanup(python): Cleanup python bindings after refactoring of CompilerEngine
This commit is contained in:
@@ -3,7 +3,7 @@ import tempfile
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from concrete.compiler import CompilerEngine, library
|
||||
from concrete.compiler import CompilerEngine
|
||||
from lib.Bindings.Python.concrete.compiler import JITCompilerSupport, LibraryCompilerSupport
|
||||
from lib.Bindings.Python.concrete.compiler import ClientSupport
|
||||
from lib.Bindings.Python.concrete.compiler import KeySetCache
|
||||
@@ -313,7 +313,7 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args):
|
||||
engine.run(*args)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result, tab_size",
|
||||
[
|
||||
pytest.param(
|
||||
@@ -338,7 +338,7 @@ def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
|
||||
assert abs(engine.run(*args) - expected_result) / tab_size < 0.1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ pytest.mark.parametrize(
|
||||
"mlir_input",
|
||||
[
|
||||
pytest.param(
|
||||
@@ -359,55 +359,3 @@ def test_compile_invalid(mlir_input):
|
||||
with pytest.raises(RuntimeError, match=r"cannot find the function for generate client parameters"):
|
||||
engine.compile_fhe(
|
||||
mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH)
|
||||
|
||||
|
||||
MODULE_1 = """
|
||||
func @test1()
|
||||
{
|
||||
return
|
||||
}
|
||||
"""
|
||||
MODULE_2 = """
|
||||
func @test2()
|
||||
{
|
||||
return
|
||||
}
|
||||
"""
|
||||
LIB_PATH = './test_library_generation.so'
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'mlir_modules',
|
||||
[
|
||||
pytest.param(MODULE_1, id='1 module'),
|
||||
pytest.param([MODULE_1, MODULE_2], id='2 modules'),
|
||||
pytest.param(iter([MODULE_1, MODULE_2]), id='iterable'),
|
||||
],
|
||||
)
|
||||
def test_library_generation(mlir_modules):
|
||||
library_path = library(LIB_PATH, mlir_modules)
|
||||
assert os.path.exists(library_path)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'mlir_modules',
|
||||
[
|
||||
pytest.param(bytes(MODULE_1, encoding='utf-8'), id='bytes vs str'),
|
||||
pytest.param(None, id='not iterable'),
|
||||
pytest.param([None], id='not str'),
|
||||
],
|
||||
)
|
||||
def test_library_generation_type_error(mlir_modules):
|
||||
with pytest.raises(TypeError):
|
||||
library(LIB_PATH, mlir_modules)
|
||||
|
||||
|
||||
def test_library_call():
|
||||
module = """
|
||||
func @test(%a: i8) -> i8
|
||||
{
|
||||
return %a : i8
|
||||
}
|
||||
"""
|
||||
from ctypes import CDLL
|
||||
lib = CDLL(library(LIB_PATH, module))
|
||||
assert lib.test(13) == 13
|
||||
|
||||
Reference in New Issue
Block a user