mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
test: create default_compilation_configuration fixture
- update test code and use it where appropriate - remove duplicate tests that lacked correctness verification
This commit is contained in:
@@ -224,7 +224,12 @@ def check_is_good_execution(compiler_engine, function, args):
|
||||
)
|
||||
|
||||
|
||||
def subtest_compile_and_run_unary_ufunc_correctness(ufunc, upper_function, input_ranges):
|
||||
def subtest_compile_and_run_unary_ufunc_correctness(
|
||||
ufunc,
|
||||
upper_function,
|
||||
input_ranges,
|
||||
default_compilation_configuration,
|
||||
):
|
||||
"""Test correctness of results when running a compiled function"""
|
||||
|
||||
def get_function(ufunc, upper_function):
|
||||
@@ -242,6 +247,7 @@ def subtest_compile_and_run_unary_ufunc_correctness(ufunc, upper_function, input
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
args = [random.randint(low, high) for (low, high) in input_ranges]
|
||||
@@ -249,7 +255,13 @@ def subtest_compile_and_run_unary_ufunc_correctness(ufunc, upper_function, input
|
||||
check_is_good_execution(compiler_engine, function, args)
|
||||
|
||||
|
||||
def subtest_compile_and_run_binary_ufunc_correctness(ufunc, upper_function, c, input_ranges):
|
||||
def subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc,
|
||||
upper_function,
|
||||
c,
|
||||
input_ranges,
|
||||
default_compilation_configuration,
|
||||
):
|
||||
"""Test correctness of results when running a compiled function"""
|
||||
|
||||
def get_function(ufunc, upper_function):
|
||||
@@ -267,6 +279,7 @@ def subtest_compile_and_run_binary_ufunc_correctness(ufunc, upper_function, c, i
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
args = [random.randint(low, high) for (low, high) in input_ranges]
|
||||
@@ -278,54 +291,86 @@ def subtest_compile_and_run_binary_ufunc_correctness(ufunc, upper_function, c, i
|
||||
"ufunc",
|
||||
[f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 2],
|
||||
)
|
||||
def test_binary_ufunc_operations(ufunc):
|
||||
def test_binary_ufunc_operations(ufunc, default_compilation_configuration):
|
||||
"""Test biary functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC."""
|
||||
if ufunc in [numpy.power, numpy.float_power]:
|
||||
# Need small constants to keep results really small
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_binary_f_one, 3, ((0, 4), (0, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_binary_f_one,
|
||||
3,
|
||||
((0, 4), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
elif ufunc in [numpy.lcm, numpy.left_shift]:
|
||||
# Need small constants to keep results sufficiently small
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_binary_f_one, 3, ((0, 5), (0, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_binary_f_one,
|
||||
3,
|
||||
((0, 5), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
else:
|
||||
# General case
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_binary_f_one, 41, ((0, 5), (0, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_binary_f_one,
|
||||
41,
|
||||
((0, 5), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
if ufunc in [numpy.power, numpy.float_power]:
|
||||
# Need small constants to keep results really small
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_binary_f_two, 2, ((0, 4), (0, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_binary_f_two,
|
||||
2,
|
||||
((0, 4), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
elif ufunc in [numpy.floor_divide, numpy.fmod, numpy.remainder, numpy.true_divide]:
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_binary_f_two, 31, ((1, 5), (1, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_binary_f_two,
|
||||
31,
|
||||
((1, 5), (1, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
elif ufunc in [numpy.lcm, numpy.left_shift]:
|
||||
# Need small constants to keep results sufficiently small
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_binary_f_two, 2, ((0, 5), (0, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_binary_f_two,
|
||||
2,
|
||||
((0, 5), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
elif ufunc in [numpy.ldexp]:
|
||||
# Need small constants to keep results sufficiently small
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_binary_f_two, 2, ((0, 5), (0, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_binary_f_two,
|
||||
2,
|
||||
((0, 5), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
else:
|
||||
# General case
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_binary_f_two, 42, ((0, 5), (0, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_binary_f_two,
|
||||
42,
|
||||
((0, 5), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ufunc", [f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1]
|
||||
)
|
||||
def test_unary_ufunc_operations(ufunc):
|
||||
def test_unary_ufunc_operations(ufunc, default_compilation_configuration):
|
||||
"""Test unary functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC."""
|
||||
if ufunc in [
|
||||
numpy.degrees,
|
||||
@@ -333,14 +378,20 @@ def test_unary_ufunc_operations(ufunc):
|
||||
]:
|
||||
# Need to reduce the output value, to avoid to need too much precision
|
||||
subtest_compile_and_run_unary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_f_which_has_large_outputs, ((0, 5), (0, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_f_which_has_large_outputs,
|
||||
((0, 5), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
elif ufunc in [
|
||||
numpy.negative,
|
||||
]:
|
||||
# Need to turn the input into a float
|
||||
subtest_compile_and_run_unary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_f_with_float_inputs, ((0, 5), (0, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_f_with_float_inputs,
|
||||
((0, 5), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
elif ufunc in [
|
||||
numpy.invert,
|
||||
@@ -360,7 +411,10 @@ def test_unary_ufunc_operations(ufunc):
|
||||
]:
|
||||
# No 0 in the domain of definition
|
||||
subtest_compile_and_run_unary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_f_avoid_0_input, ((1, 5), (1, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_f_avoid_0_input,
|
||||
((1, 5), (1, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
elif ufunc in [
|
||||
numpy.cosh,
|
||||
@@ -375,12 +429,18 @@ def test_unary_ufunc_operations(ufunc):
|
||||
]:
|
||||
# Need a small range of inputs, to avoid to need too much precision
|
||||
subtest_compile_and_run_unary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_f_which_expects_small_inputs, ((0, 5), (0, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_f_which_expects_small_inputs,
|
||||
((0, 5), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
else:
|
||||
# Regular case for univariate functions
|
||||
subtest_compile_and_run_unary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_f, ((0, 5), (0, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_f,
|
||||
((0, 5), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
|
||||
@@ -404,7 +464,9 @@ def test_unary_ufunc_operations(ufunc):
|
||||
pytest.param(complicated_topology, ((0, 10),), ["x"]),
|
||||
],
|
||||
)
|
||||
def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_names):
|
||||
def test_compile_function_multiple_outputs(
|
||||
function, input_ranges, list_of_arg_names, default_compilation_configuration
|
||||
):
|
||||
"""Test function compile_numpy_function_into_op_graph for a program with multiple outputs"""
|
||||
|
||||
def data_gen(args):
|
||||
@@ -419,7 +481,7 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
CompilationConfiguration(dump_artifacts_on_unexpected_failures=False),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
# TODO: For the moment, we don't have really checks, but some printfs. Later,
|
||||
@@ -433,52 +495,7 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
[
|
||||
pytest.param(lambda x: x + 42, ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x: x + numpy.int32(42), ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x: x * 2, ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x: 12 - x, ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]),
|
||||
pytest.param(identity_lut_generator(1), ((0, 1),), ["x"]),
|
||||
pytest.param(identity_lut_generator(2), ((0, 3),), ["x"]),
|
||||
pytest.param(identity_lut_generator(3), ((0, 7),), ["x"]),
|
||||
pytest.param(identity_lut_generator(4), ((0, 15),), ["x"]),
|
||||
pytest.param(identity_lut_generator(5), ((0, 31),), ["x"]),
|
||||
pytest.param(identity_lut_generator(6), ((0, 63),), ["x"]),
|
||||
pytest.param(identity_lut_generator(7), ((0, 127),), ["x"]),
|
||||
pytest.param(random_lut_1b, ((0, 1),), ["x"]),
|
||||
pytest.param(random_lut_2b, ((0, 3),), ["x"]),
|
||||
pytest.param(random_lut_3b, ((0, 7),), ["x"]),
|
||||
pytest.param(random_lut_4b, ((0, 15),), ["x"]),
|
||||
pytest.param(random_lut_5b, ((0, 31),), ["x"]),
|
||||
pytest.param(random_lut_6b, ((0, 63),), ["x"]),
|
||||
pytest.param(random_lut_7b, ((0, 127),), ["x"]),
|
||||
pytest.param(small_fused_table, ((0, 31),), ["x"]),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_function_multiple_outputs(function, input_ranges, list_of_arg_names):
|
||||
"""Test function compile_numpy_function for a program with multiple outputs"""
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {
|
||||
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
|
||||
}
|
||||
|
||||
compiler_engine = compile_numpy_function(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
)
|
||||
|
||||
args = [random.randint(low, high) for (low, high) in input_ranges]
|
||||
compiler_engine.run(*args)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
[
|
||||
pytest.param(lambda x: x + 64, ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x: x * 3, ((0, 40),), ["x"]),
|
||||
pytest.param(lambda x: 120 - x, ((40, 80),), ["x"]),
|
||||
@@ -487,7 +504,9 @@ def test_compile_and_run_function_multiple_outputs(function, input_ranges, list_
|
||||
pytest.param(lambda x, y: 50 - y * 2 + x, ((0, 20), (0, 20)), ["x", "y"]),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_correctness(function, input_ranges, list_of_arg_names):
|
||||
def test_compile_and_run_correctness(
|
||||
function, input_ranges, list_of_arg_names, default_compilation_configuration
|
||||
):
|
||||
"""Test correctness of results when running a compiled function"""
|
||||
|
||||
def data_gen(args):
|
||||
@@ -502,6 +521,7 @@ def test_compile_and_run_correctness(function, input_ranges, list_of_arg_names):
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
args = [random.randint(low, high) for (low, high) in input_ranges]
|
||||
@@ -529,7 +549,7 @@ def test_compile_and_run_correctness(function, input_ranges, list_of_arg_names):
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_dot_correctness(size, input_range):
|
||||
def test_compile_and_run_dot_correctness(size, input_range, default_compilation_configuration):
|
||||
"""Test correctness of results when running a compiled function"""
|
||||
|
||||
low, high = input_range
|
||||
@@ -557,6 +577,7 @@ def test_compile_and_run_dot_correctness(size, input_range):
|
||||
function,
|
||||
function_parameters,
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
args = [[random.randint(low, high) for _ in range(size)] for __ in range(2)]
|
||||
@@ -584,7 +605,9 @@ def test_compile_and_run_dot_correctness(size, input_range):
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_constant_dot_correctness(size, input_range):
|
||||
def test_compile_and_run_constant_dot_correctness(
|
||||
size, input_range, default_compilation_configuration
|
||||
):
|
||||
"""Test correctness of results when running a compiled function"""
|
||||
|
||||
low, high = input_range
|
||||
@@ -609,11 +632,13 @@ def test_compile_and_run_constant_dot_correctness(size, input_range):
|
||||
left,
|
||||
{"x": EncryptedTensor(Integer(64, False), shape)},
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
right_circuit = compile_numpy_function(
|
||||
left,
|
||||
{"x": EncryptedTensor(Integer(64, False), shape)},
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
args = (numpy.random.randint(low, high + 1, size=shape).tolist(),)
|
||||
@@ -622,39 +647,49 @@ def test_compile_and_run_constant_dot_correctness(size, input_range):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
"function,input_bits,list_of_arg_names",
|
||||
[
|
||||
pytest.param(identity_lut_generator(1), ((0, 1),), ["x"], id="identity function (1-bit)"),
|
||||
pytest.param(identity_lut_generator(2), ((0, 3),), ["x"], id="identity function (2-bit)"),
|
||||
pytest.param(identity_lut_generator(3), ((0, 7),), ["x"], id="identity function (3-bit)"),
|
||||
pytest.param(identity_lut_generator(4), ((0, 15),), ["x"], id="identity function (4-bit)"),
|
||||
pytest.param(identity_lut_generator(5), ((0, 31),), ["x"], id="identity function (5-bit)"),
|
||||
pytest.param(identity_lut_generator(6), ((0, 63),), ["x"], id="identity function (6-bit)"),
|
||||
pytest.param(identity_lut_generator(7), ((0, 127),), ["x"], id="identity function (7-bit)"),
|
||||
pytest.param(random_lut_1b, ((0, 1),), ["x"], id="random function (1-bit)"),
|
||||
pytest.param(random_lut_2b, ((0, 3),), ["x"], id="random function (2-bit)"),
|
||||
pytest.param(random_lut_3b, ((0, 7),), ["x"], id="random function (3-bit)"),
|
||||
pytest.param(random_lut_4b, ((0, 15),), ["x"], id="random function (4-bit)"),
|
||||
pytest.param(random_lut_5b, ((0, 31),), ["x"], id="random function (5-bit)"),
|
||||
pytest.param(random_lut_6b, ((0, 63),), ["x"], id="random function (6-bit)"),
|
||||
pytest.param(random_lut_7b, ((0, 127),), ["x"], id="random function (7-bit)"),
|
||||
pytest.param(identity_lut_generator(1), (1,), ["x"], id="identity function (1-bit)"),
|
||||
pytest.param(identity_lut_generator(2), (2,), ["x"], id="identity function (2-bit)"),
|
||||
pytest.param(identity_lut_generator(3), (3,), ["x"], id="identity function (3-bit)"),
|
||||
pytest.param(identity_lut_generator(4), (4,), ["x"], id="identity function (4-bit)"),
|
||||
pytest.param(identity_lut_generator(5), (5,), ["x"], id="identity function (5-bit)"),
|
||||
pytest.param(identity_lut_generator(6), (6,), ["x"], id="identity function (6-bit)"),
|
||||
pytest.param(identity_lut_generator(7), (7,), ["x"], id="identity function (7-bit)"),
|
||||
pytest.param(random_lut_1b, (1,), ["x"], id="random function (1-bit)"),
|
||||
pytest.param(random_lut_2b, (2,), ["x"], id="random function (2-bit)"),
|
||||
pytest.param(random_lut_3b, (3,), ["x"], id="random function (3-bit)"),
|
||||
pytest.param(random_lut_4b, (4,), ["x"], id="random function (4-bit)"),
|
||||
pytest.param(random_lut_5b, (5,), ["x"], id="random function (5-bit)"),
|
||||
pytest.param(random_lut_6b, (6,), ["x"], id="random function (6-bit)"),
|
||||
pytest.param(random_lut_7b, (7,), ["x"], id="random function (7-bit)"),
|
||||
pytest.param(small_fused_table, (5,), ["x"], id="small fused table (5-bits)"),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_lut_correctness(function, input_ranges, list_of_arg_names):
|
||||
def test_compile_and_run_lut_correctness(
|
||||
function,
|
||||
input_bits,
|
||||
list_of_arg_names,
|
||||
default_compilation_configuration,
|
||||
):
|
||||
"""Test correctness of results when running a compiled function with LUT"""
|
||||
|
||||
input_ranges = tuple((0, 2 ** input_bit - 1) for input_bit in input_bits)
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {
|
||||
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
|
||||
arg_name: EncryptedScalar(Integer(input_bit, False))
|
||||
for input_bit, arg_name in zip(input_bits, list_of_arg_names)
|
||||
}
|
||||
|
||||
compiler_engine = compile_numpy_function(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
# testing random values
|
||||
@@ -671,7 +706,7 @@ def test_compile_and_run_lut_correctness(function, input_ranges, list_of_arg_nam
|
||||
check_is_good_execution(compiler_engine, function, args)
|
||||
|
||||
|
||||
def test_compile_function_with_direct_tlu():
|
||||
def test_compile_function_with_direct_tlu(default_compilation_configuration):
|
||||
"""Test compile_numpy_function_into_op_graph for a program with direct table lookup"""
|
||||
|
||||
table = LookupTable([9, 2, 4, 11])
|
||||
@@ -683,13 +718,14 @@ def test_compile_function_with_direct_tlu():
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(2, is_signed=False))},
|
||||
[(0,), (1,), (2,), (3,)],
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
|
||||
print(f"\n{str_of_the_graph}\n")
|
||||
|
||||
|
||||
def test_compile_function_with_direct_tlu_overflow():
|
||||
def test_compile_function_with_direct_tlu_overflow(default_compilation_configuration):
|
||||
"""Test compile_numpy_function_into_op_graph for a program with direct table lookup overflow"""
|
||||
|
||||
table = LookupTable([9, 2, 4, 11])
|
||||
@@ -702,7 +738,7 @@ def test_compile_function_with_direct_tlu_overflow():
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(3, is_signed=False))},
|
||||
[(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,)],
|
||||
CompilationConfiguration(dump_artifacts_on_unexpected_failures=False),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
|
||||
@@ -739,7 +775,7 @@ def test_compile_function_with_direct_tlu_overflow():
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_fail_compile(function, parameters, inputset, match):
|
||||
def test_fail_compile(function, parameters, inputset, match, default_compilation_configuration):
|
||||
"""Test function compile_numpy_function_into_op_graph for a program with signed values"""
|
||||
|
||||
try:
|
||||
@@ -747,13 +783,13 @@ def test_fail_compile(function, parameters, inputset, match):
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
CompilationConfiguration(dump_artifacts_on_unexpected_failures=False),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
except RuntimeError as error:
|
||||
assert str(error) == match
|
||||
|
||||
|
||||
def test_small_inputset():
|
||||
def test_small_inputset_no_fail():
|
||||
"""Test function compile_numpy_function_into_op_graph with an unacceptably small inputset"""
|
||||
compile_numpy_function_into_op_graph(
|
||||
lambda x: x + 42,
|
||||
@@ -801,7 +837,9 @@ def test_small_inputset_treat_warnings_as_errors():
|
||||
# pylint: enable=unnecessary-lambda
|
||||
],
|
||||
)
|
||||
def test_compile_function_with_dot(function, params, shape, ref_graph_str):
|
||||
def test_compile_function_with_dot(
|
||||
function, params, shape, ref_graph_str, default_compilation_configuration
|
||||
):
|
||||
"""Test compile_numpy_function_into_op_graph for a program with np.dot"""
|
||||
|
||||
# This is the exhaust, but if ever we have too long inputs (ie, large 'repeat'),
|
||||
@@ -820,6 +858,7 @@ def test_compile_function_with_dot(function, params, shape, ref_graph_str):
|
||||
function,
|
||||
params,
|
||||
data_gen(max_for_ij, repeat),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
@@ -840,7 +879,9 @@ def test_compile_function_with_dot(function, params, shape, ref_graph_str):
|
||||
pytest.param(lambda x, y: 50 - y * 2 + x, ((0, 20), (0, 20)), ["x", "y"]),
|
||||
],
|
||||
)
|
||||
def test_compile_with_show_mlir(function, input_ranges, list_of_arg_names):
|
||||
def test_compile_with_show_mlir(
|
||||
function, input_ranges, list_of_arg_names, default_compilation_configuration
|
||||
):
|
||||
"""Test show_mlir option"""
|
||||
|
||||
def data_gen(args):
|
||||
@@ -855,11 +896,12 @@ def test_compile_with_show_mlir(function, input_ranges, list_of_arg_names):
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
default_compilation_configuration,
|
||||
show_mlir=True,
|
||||
)
|
||||
|
||||
|
||||
def test_compile_too_high_bitwidth():
|
||||
def test_compile_too_high_bitwidth(default_compilation_configuration):
|
||||
"""Check that the check of maximal bitwidth of intermediate data works fine."""
|
||||
|
||||
def function(x, y):
|
||||
@@ -882,6 +924,7 @@ def test_compile_too_high_bitwidth():
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
assert (
|
||||
@@ -898,4 +941,5 @@ def test_compile_too_high_bitwidth():
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user