refactor: rename [jit|library]lambdasupport to [jit|library]support

This commit is contained in:
youben11
2022-03-31 20:17:10 +01:00
committed by Ayoub Benaissa
parent 583d5edf00
commit 78def04fe5
19 changed files with 117 additions and 134 deletions

View File

@@ -3,9 +3,7 @@ import tempfile
import pytest
import numpy as np
from concrete.compiler import JITLambdaSupport, LibraryLambdaSupport
from concrete.compiler import ClientSupport
from concrete.compiler import KeySetCache
from concrete.compiler import JITSupport, LibrarySupport, ClientSupport, KeySetCache
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache")
@@ -244,19 +242,19 @@ end_to_end_fixture = [
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_jit_compile_and_run(mlir_input, args, expected_result):
engine = JITLambdaSupport.new()
engine = JITSupport.new()
compile_and_run(engine, mlir_input, args, expected_result)
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_lib_compile_and_run(mlir_input, args, expected_result):
engine = LibraryLambdaSupport.new("py_test_lib_compile_and_run")
engine = LibrarySupport.new("py_test_lib_compile_and_run")
compile_and_run(engine, mlir_input, args, expected_result)
@pytest.mark.parametrize("mlir_input, args, expected_result", end_to_end_fixture)
def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
engine = LibraryLambdaSupport.new("test_lib_compile_reload_and_run")
engine = LibrarySupport.new("test_lib_compile_reload_and_run")
# Here don't save compilation result, reload
engine.compile(mlir_input)
compilation_result = engine.reload()
@@ -293,7 +291,7 @@ def test_lib_compile_reload_and_run(mlir_input, args, expected_result):
],
)
def test_compile_and_run_invalid_arg_number(mlir_input, args):
engine = JITLambdaSupport.new()
engine = JITSupport.new()
with pytest.raises(
RuntimeError, match=r"function has arity 2 but is applied to too many arguments"
):
@@ -318,7 +316,7 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args):
],
)
def test_compile_and_run_tlu(mlir_input, args, expected_result):
engine = JITLambdaSupport.new()
engine = JITSupport.new()
compile_and_run(engine, mlir_input, args, expected_result)
@@ -339,7 +337,7 @@ def test_compile_and_run_tlu(mlir_input, args, expected_result):
],
)
def test_compile_invalid(mlir_input):
engine = JITLambdaSupport.new()
engine = JITSupport.new()
with pytest.raises(
RuntimeError, match=r"cannot find the function for generate client parameters"
):

View File

@@ -4,10 +4,7 @@ import tempfile
import pytest
import numpy as np
from concrete.compiler.client_support import ClientSupport
from concrete.compiler.compilation_options import CompilationOptions
from concrete.compiler.jit_lambda_support import JITLambdaSupport
from concrete.compiler.key_set_cache import KeySetCache
from concrete.compiler import ClientSupport, CompilationOptions, JITSupport, KeySetCache
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), "KeySetCache")
@@ -69,5 +66,5 @@ def compile_and_run(engine, mlir_input, args, expected_result):
],
)
def test_compile_and_run_parallel(mlir_input, args, expected_result):
engine = JITLambdaSupport.new()
engine = JITSupport.new()
compile_and_run(engine, mlir_input, args, expected_result)