test: create default_compilation_configuration fixture

- update test code and use it where appropriate
- remove duplicate tests that lacked correctness verification
This commit is contained in:
Arthur Meyre
2021-10-18 16:09:11 +02:00
parent 82688206f7
commit 384026364e
8 changed files with 178 additions and 108 deletions

View File

@@ -9,7 +9,7 @@ from concrete.common.values import EncryptedScalar
from concrete.numpy.compile import compile_numpy_function
def test_artifacts_export():
def test_artifacts_export(default_compilation_configuration):
"""Test function to check exporting compilation artifacts"""
def function(x):
@@ -23,6 +23,7 @@ def test_artifacts_export():
function,
{"x": EncryptedScalar(UnsignedInteger(7))},
[(i,) for i in range(10)],
default_compilation_configuration,
compilation_artifacts=artifacts,
)

View File

@@ -40,7 +40,9 @@ def simple_fuse_not_output(x):
),
],
)
def test_enable_topological_optimizations(test_helpers, function_to_trace, fused):
def test_enable_topological_optimizations(
test_helpers, function_to_trace, fused, default_compilation_configuration
):
"""Test function for enable_topological_optimizations flag of compilation configuration"""
op_graph = compile_numpy_function_into_op_graph(
@@ -50,7 +52,7 @@ def test_enable_topological_optimizations(test_helpers, function_to_trace, fused
for param in signature(function_to_trace).parameters.keys()
},
[(i,) for i in range(10)],
CompilationConfiguration(dump_artifacts_on_unexpected_failures=False),
default_compilation_configuration,
)
op_graph_not_optimized = compile_numpy_function_into_op_graph(
function_to_trace,
@@ -62,6 +64,7 @@ def test_enable_topological_optimizations(test_helpers, function_to_trace, fused
CompilationConfiguration(
dump_artifacts_on_unexpected_failures=False,
enable_topological_optimizations=False,
treat_warnings_as_errors=True,
),
)

View File

@@ -9,7 +9,7 @@ from concrete.common.values import EncryptedScalar
from concrete.numpy.compile import compile_numpy_function_into_op_graph
def test_draw_graph_with_saving():
def test_draw_graph_with_saving(default_compilation_configuration):
"""Tests drawing and saving a graph"""
def function(x):
@@ -19,6 +19,7 @@ def test_draw_graph_with_saving():
function,
{"x": EncryptedScalar(Integer(7, True))},
[(i,) for i in range(-5, 5)],
default_compilation_configuration,
)
with tempfile.TemporaryDirectory() as tmp:

View File

@@ -6,7 +6,7 @@ from concrete.common.values import EncryptedScalar
from concrete.numpy.compile import compile_numpy_function_into_op_graph
def test_get_printable_graph_with_offending_nodes():
def test_get_printable_graph_with_offending_nodes(default_compilation_configuration):
"""Test get_printable_graph with offending nodes"""
def function(x):
@@ -16,6 +16,7 @@ def test_get_printable_graph_with_offending_nodes():
function,
{"x": EncryptedScalar(Integer(7, True))},
[(i,) for i in range(-5, 5)],
default_compilation_configuration,
)
highlighted_nodes = {opgraph.input_nodes[0]: "foo"}

View File

@@ -216,10 +216,15 @@ def datagen(*args):
),
],
)
def test_mlir_converter(func, args_dict, args_ranges):
def test_mlir_converter(func, args_dict, args_ranges, default_compilation_configuration):
"""Test the conversion to MLIR by calling the parser from the compiler"""
inputset = datagen(*args_ranges)
result_graph = compile_numpy_function_into_op_graph(func, args_dict, inputset)
result_graph = compile_numpy_function_into_op_graph(
func,
args_dict,
inputset,
default_compilation_configuration,
)
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
mlir_result = converter.convert(result_graph)
# testing that this doesn't raise an error
@@ -247,7 +252,9 @@ def test_mlir_converter(func, args_dict, args_ranges):
),
],
)
def test_mlir_converter_dot_between_vectors(func, args_dict, args_ranges):
def test_mlir_converter_dot_between_vectors(
func, args_dict, args_ranges, default_compilation_configuration
):
"""Test the conversion to MLIR by calling the parser from the compiler"""
assert len(args_dict["x"].shape) == 1
assert len(args_dict["y"].shape) == 1
@@ -261,6 +268,7 @@ def test_mlir_converter_dot_between_vectors(func, args_dict, args_ranges):
(numpy.array([data[0]] * n), numpy.array([data[1]] * n))
for data in datagen(*args_ranges)
),
default_compilation_configuration,
)
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
mlir_result = converter.convert(result_graph)
@@ -268,7 +276,7 @@ def test_mlir_converter_dot_between_vectors(func, args_dict, args_ranges):
compiler.round_trip(mlir_result)
def test_mlir_converter_dot_vector_and_constant():
def test_mlir_converter_dot_vector_and_constant(default_compilation_configuration):
"""Test the conversion to MLIR by calling the parser from the compiler"""
def left_dot_with_constant(x):
@@ -281,6 +289,7 @@ def test_mlir_converter_dot_vector_and_constant():
left_dot_with_constant,
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2,))},
[(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)],
default_compilation_configuration,
)
left_converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
left_mlir = left_converter.convert(left_graph)
@@ -289,6 +298,7 @@ def test_mlir_converter_dot_vector_and_constant():
right_dot_with_constant,
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2,))},
[(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)],
default_compilation_configuration,
)
right_converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
right_mlir = right_converter.convert(right_graph)

View File

@@ -6,7 +6,7 @@ import concrete.numpy as hnp
from concrete.common.debugging import draw_graph, get_printable_graph
def test_circuit_str():
def test_circuit_str(default_compilation_configuration):
"""Test function for `__str__` method of `Circuit`"""
def f(x):
@@ -15,12 +15,12 @@ def test_circuit_str():
x = hnp.EncryptedScalar(hnp.UnsignedInteger(3))
inputset = [(i,) for i in range(2 ** 3)]
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset)
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration)
assert str(circuit) == get_printable_graph(circuit.opgraph, show_data_types=True)
def test_circuit_draw():
def test_circuit_draw(default_compilation_configuration):
"""Test function for `draw` method of `Circuit`"""
def f(x):
@@ -29,13 +29,13 @@ def test_circuit_draw():
x = hnp.EncryptedScalar(hnp.UnsignedInteger(3))
inputset = [(i,) for i in range(2 ** 3)]
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset)
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration)
assert filecmp.cmp(circuit.draw(), draw_graph(circuit.opgraph))
assert filecmp.cmp(circuit.draw(vertical=False), draw_graph(circuit.opgraph, vertical=False))
def test_circuit_run():
def test_circuit_run(default_compilation_configuration):
"""Test function for `run` method of `Circuit`"""
def f(x):
@@ -44,7 +44,7 @@ def test_circuit_run():
x = hnp.EncryptedScalar(hnp.UnsignedInteger(3))
inputset = [(i,) for i in range(2 ** 3)]
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset)
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration)
for x in inputset:
assert circuit.run(*x) == circuit.engine.run(*x)

View File

@@ -8,6 +8,7 @@ import networkx as nx
import networkx.algorithms.isomorphism as iso
import pytest
from concrete.common.compilation import CompilationConfiguration
from concrete.common.representation.intermediate import (
ALL_IR_NODES,
Add,
@@ -228,3 +229,12 @@ class TestHelpers:
def test_helpers():
"""Fixture to return the static helper class"""
return TestHelpers
@pytest.fixture
def default_compilation_configuration():
"""Return the default test compilation configuration"""
return CompilationConfiguration(
dump_artifacts_on_unexpected_failures=False,
treat_warnings_as_errors=True,
)

View File

@@ -224,7 +224,12 @@ def check_is_good_execution(compiler_engine, function, args):
)
def subtest_compile_and_run_unary_ufunc_correctness(ufunc, upper_function, input_ranges):
def subtest_compile_and_run_unary_ufunc_correctness(
ufunc,
upper_function,
input_ranges,
default_compilation_configuration,
):
"""Test correctness of results when running a compiled function"""
def get_function(ufunc, upper_function):
@@ -242,6 +247,7 @@ def subtest_compile_and_run_unary_ufunc_correctness(ufunc, upper_function, input
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
default_compilation_configuration,
)
args = [random.randint(low, high) for (low, high) in input_ranges]
@@ -249,7 +255,13 @@ def subtest_compile_and_run_unary_ufunc_correctness(ufunc, upper_function, input
check_is_good_execution(compiler_engine, function, args)
def subtest_compile_and_run_binary_ufunc_correctness(ufunc, upper_function, c, input_ranges):
def subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
upper_function,
c,
input_ranges,
default_compilation_configuration,
):
"""Test correctness of results when running a compiled function"""
def get_function(ufunc, upper_function):
@@ -267,6 +279,7 @@ def subtest_compile_and_run_binary_ufunc_correctness(ufunc, upper_function, c, i
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
default_compilation_configuration,
)
args = [random.randint(low, high) for (low, high) in input_ranges]
@@ -278,54 +291,86 @@ def subtest_compile_and_run_binary_ufunc_correctness(ufunc, upper_function, c, i
"ufunc",
[f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 2],
)
def test_binary_ufunc_operations(ufunc):
def test_binary_ufunc_operations(ufunc, default_compilation_configuration):
"""Test biary functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC."""
if ufunc in [numpy.power, numpy.float_power]:
# Need small constants to keep results really small
subtest_compile_and_run_binary_ufunc_correctness(
ufunc, mix_x_and_y_and_call_binary_f_one, 3, ((0, 4), (0, 5))
ufunc,
mix_x_and_y_and_call_binary_f_one,
3,
((0, 4), (0, 5)),
default_compilation_configuration,
)
elif ufunc in [numpy.lcm, numpy.left_shift]:
# Need small constants to keep results sufficiently small
subtest_compile_and_run_binary_ufunc_correctness(
ufunc, mix_x_and_y_and_call_binary_f_one, 3, ((0, 5), (0, 5))
ufunc,
mix_x_and_y_and_call_binary_f_one,
3,
((0, 5), (0, 5)),
default_compilation_configuration,
)
else:
# General case
subtest_compile_and_run_binary_ufunc_correctness(
ufunc, mix_x_and_y_and_call_binary_f_one, 41, ((0, 5), (0, 5))
ufunc,
mix_x_and_y_and_call_binary_f_one,
41,
((0, 5), (0, 5)),
default_compilation_configuration,
)
if ufunc in [numpy.power, numpy.float_power]:
# Need small constants to keep results really small
subtest_compile_and_run_binary_ufunc_correctness(
ufunc, mix_x_and_y_and_call_binary_f_two, 2, ((0, 4), (0, 5))
ufunc,
mix_x_and_y_and_call_binary_f_two,
2,
((0, 4), (0, 5)),
default_compilation_configuration,
)
elif ufunc in [numpy.floor_divide, numpy.fmod, numpy.remainder, numpy.true_divide]:
subtest_compile_and_run_binary_ufunc_correctness(
ufunc, mix_x_and_y_and_call_binary_f_two, 31, ((1, 5), (1, 5))
ufunc,
mix_x_and_y_and_call_binary_f_two,
31,
((1, 5), (1, 5)),
default_compilation_configuration,
)
elif ufunc in [numpy.lcm, numpy.left_shift]:
# Need small constants to keep results sufficiently small
subtest_compile_and_run_binary_ufunc_correctness(
ufunc, mix_x_and_y_and_call_binary_f_two, 2, ((0, 5), (0, 5))
ufunc,
mix_x_and_y_and_call_binary_f_two,
2,
((0, 5), (0, 5)),
default_compilation_configuration,
)
elif ufunc in [numpy.ldexp]:
# Need small constants to keep results sufficiently small
subtest_compile_and_run_binary_ufunc_correctness(
ufunc, mix_x_and_y_and_call_binary_f_two, 2, ((0, 5), (0, 5))
ufunc,
mix_x_and_y_and_call_binary_f_two,
2,
((0, 5), (0, 5)),
default_compilation_configuration,
)
else:
# General case
subtest_compile_and_run_binary_ufunc_correctness(
ufunc, mix_x_and_y_and_call_binary_f_two, 42, ((0, 5), (0, 5))
ufunc,
mix_x_and_y_and_call_binary_f_two,
42,
((0, 5), (0, 5)),
default_compilation_configuration,
)
@pytest.mark.parametrize(
"ufunc", [f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1]
)
def test_unary_ufunc_operations(ufunc):
def test_unary_ufunc_operations(ufunc, default_compilation_configuration):
"""Test unary functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC."""
if ufunc in [
numpy.degrees,
@@ -333,14 +378,20 @@ def test_unary_ufunc_operations(ufunc):
]:
# Need to reduce the output value, to avoid to need too much precision
subtest_compile_and_run_unary_ufunc_correctness(
ufunc, mix_x_and_y_and_call_f_which_has_large_outputs, ((0, 5), (0, 5))
ufunc,
mix_x_and_y_and_call_f_which_has_large_outputs,
((0, 5), (0, 5)),
default_compilation_configuration,
)
elif ufunc in [
numpy.negative,
]:
# Need to turn the input into a float
subtest_compile_and_run_unary_ufunc_correctness(
ufunc, mix_x_and_y_and_call_f_with_float_inputs, ((0, 5), (0, 5))
ufunc,
mix_x_and_y_and_call_f_with_float_inputs,
((0, 5), (0, 5)),
default_compilation_configuration,
)
elif ufunc in [
numpy.invert,
@@ -360,7 +411,10 @@ def test_unary_ufunc_operations(ufunc):
]:
# No 0 in the domain of definition
subtest_compile_and_run_unary_ufunc_correctness(
ufunc, mix_x_and_y_and_call_f_avoid_0_input, ((1, 5), (1, 5))
ufunc,
mix_x_and_y_and_call_f_avoid_0_input,
((1, 5), (1, 5)),
default_compilation_configuration,
)
elif ufunc in [
numpy.cosh,
@@ -375,12 +429,18 @@ def test_unary_ufunc_operations(ufunc):
]:
# Need a small range of inputs, to avoid to need too much precision
subtest_compile_and_run_unary_ufunc_correctness(
ufunc, mix_x_and_y_and_call_f_which_expects_small_inputs, ((0, 5), (0, 5))
ufunc,
mix_x_and_y_and_call_f_which_expects_small_inputs,
((0, 5), (0, 5)),
default_compilation_configuration,
)
else:
# Regular case for univariate functions
subtest_compile_and_run_unary_ufunc_correctness(
ufunc, mix_x_and_y_and_call_f, ((0, 5), (0, 5))
ufunc,
mix_x_and_y_and_call_f,
((0, 5), (0, 5)),
default_compilation_configuration,
)
@@ -404,7 +464,9 @@ def test_unary_ufunc_operations(ufunc):
pytest.param(complicated_topology, ((0, 10),), ["x"]),
],
)
def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_names):
def test_compile_function_multiple_outputs(
function, input_ranges, list_of_arg_names, default_compilation_configuration
):
"""Test function compile_numpy_function_into_op_graph for a program with multiple outputs"""
def data_gen(args):
@@ -419,7 +481,7 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
CompilationConfiguration(dump_artifacts_on_unexpected_failures=False),
default_compilation_configuration,
)
# TODO: For the moment, we don't have really checks, but some printfs. Later,
@@ -433,52 +495,7 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
@pytest.mark.parametrize(
"function,input_ranges,list_of_arg_names",
[
pytest.param(lambda x: x + 42, ((0, 10),), ["x"]),
pytest.param(lambda x: x + numpy.int32(42), ((0, 10),), ["x"]),
pytest.param(lambda x: x * 2, ((0, 10),), ["x"]),
pytest.param(lambda x: 12 - x, ((0, 10),), ["x"]),
pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]),
pytest.param(identity_lut_generator(1), ((0, 1),), ["x"]),
pytest.param(identity_lut_generator(2), ((0, 3),), ["x"]),
pytest.param(identity_lut_generator(3), ((0, 7),), ["x"]),
pytest.param(identity_lut_generator(4), ((0, 15),), ["x"]),
pytest.param(identity_lut_generator(5), ((0, 31),), ["x"]),
pytest.param(identity_lut_generator(6), ((0, 63),), ["x"]),
pytest.param(identity_lut_generator(7), ((0, 127),), ["x"]),
pytest.param(random_lut_1b, ((0, 1),), ["x"]),
pytest.param(random_lut_2b, ((0, 3),), ["x"]),
pytest.param(random_lut_3b, ((0, 7),), ["x"]),
pytest.param(random_lut_4b, ((0, 15),), ["x"]),
pytest.param(random_lut_5b, ((0, 31),), ["x"]),
pytest.param(random_lut_6b, ((0, 63),), ["x"]),
pytest.param(random_lut_7b, ((0, 127),), ["x"]),
pytest.param(small_fused_table, ((0, 31),), ["x"]),
],
)
def test_compile_and_run_function_multiple_outputs(function, input_ranges, list_of_arg_names):
"""Test function compile_numpy_function for a program with multiple outputs"""
def data_gen(args):
for prod in itertools.product(*args):
yield prod
function_parameters = {
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
}
compiler_engine = compile_numpy_function(
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
)
args = [random.randint(low, high) for (low, high) in input_ranges]
compiler_engine.run(*args)
@pytest.mark.parametrize(
"function,input_ranges,list_of_arg_names",
[
pytest.param(lambda x: x + 64, ((0, 10),), ["x"]),
pytest.param(lambda x: x * 3, ((0, 40),), ["x"]),
pytest.param(lambda x: 120 - x, ((40, 80),), ["x"]),
@@ -487,7 +504,9 @@ def test_compile_and_run_function_multiple_outputs(function, input_ranges, list_
pytest.param(lambda x, y: 50 - y * 2 + x, ((0, 20), (0, 20)), ["x", "y"]),
],
)
def test_compile_and_run_correctness(function, input_ranges, list_of_arg_names):
def test_compile_and_run_correctness(
function, input_ranges, list_of_arg_names, default_compilation_configuration
):
"""Test correctness of results when running a compiled function"""
def data_gen(args):
@@ -502,6 +521,7 @@ def test_compile_and_run_correctness(function, input_ranges, list_of_arg_names):
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
default_compilation_configuration,
)
args = [random.randint(low, high) for (low, high) in input_ranges]
@@ -529,7 +549,7 @@ def test_compile_and_run_correctness(function, input_ranges, list_of_arg_names):
),
],
)
def test_compile_and_run_dot_correctness(size, input_range):
def test_compile_and_run_dot_correctness(size, input_range, default_compilation_configuration):
"""Test correctness of results when running a compiled function"""
low, high = input_range
@@ -557,6 +577,7 @@ def test_compile_and_run_dot_correctness(size, input_range):
function,
function_parameters,
inputset,
default_compilation_configuration,
)
args = [[random.randint(low, high) for _ in range(size)] for __ in range(2)]
@@ -584,7 +605,9 @@ def test_compile_and_run_dot_correctness(size, input_range):
),
],
)
def test_compile_and_run_constant_dot_correctness(size, input_range):
def test_compile_and_run_constant_dot_correctness(
size, input_range, default_compilation_configuration
):
"""Test correctness of results when running a compiled function"""
low, high = input_range
@@ -609,11 +632,13 @@ def test_compile_and_run_constant_dot_correctness(size, input_range):
left,
{"x": EncryptedTensor(Integer(64, False), shape)},
inputset,
default_compilation_configuration,
)
right_circuit = compile_numpy_function(
left,
{"x": EncryptedTensor(Integer(64, False), shape)},
inputset,
default_compilation_configuration,
)
args = (numpy.random.randint(low, high + 1, size=shape).tolist(),)
@@ -622,39 +647,49 @@ def test_compile_and_run_constant_dot_correctness(size, input_range):
@pytest.mark.parametrize(
"function,input_ranges,list_of_arg_names",
"function,input_bits,list_of_arg_names",
[
pytest.param(identity_lut_generator(1), ((0, 1),), ["x"], id="identity function (1-bit)"),
pytest.param(identity_lut_generator(2), ((0, 3),), ["x"], id="identity function (2-bit)"),
pytest.param(identity_lut_generator(3), ((0, 7),), ["x"], id="identity function (3-bit)"),
pytest.param(identity_lut_generator(4), ((0, 15),), ["x"], id="identity function (4-bit)"),
pytest.param(identity_lut_generator(5), ((0, 31),), ["x"], id="identity function (5-bit)"),
pytest.param(identity_lut_generator(6), ((0, 63),), ["x"], id="identity function (6-bit)"),
pytest.param(identity_lut_generator(7), ((0, 127),), ["x"], id="identity function (7-bit)"),
pytest.param(random_lut_1b, ((0, 1),), ["x"], id="random function (1-bit)"),
pytest.param(random_lut_2b, ((0, 3),), ["x"], id="random function (2-bit)"),
pytest.param(random_lut_3b, ((0, 7),), ["x"], id="random function (3-bit)"),
pytest.param(random_lut_4b, ((0, 15),), ["x"], id="random function (4-bit)"),
pytest.param(random_lut_5b, ((0, 31),), ["x"], id="random function (5-bit)"),
pytest.param(random_lut_6b, ((0, 63),), ["x"], id="random function (6-bit)"),
pytest.param(random_lut_7b, ((0, 127),), ["x"], id="random function (7-bit)"),
pytest.param(identity_lut_generator(1), (1,), ["x"], id="identity function (1-bit)"),
pytest.param(identity_lut_generator(2), (2,), ["x"], id="identity function (2-bit)"),
pytest.param(identity_lut_generator(3), (3,), ["x"], id="identity function (3-bit)"),
pytest.param(identity_lut_generator(4), (4,), ["x"], id="identity function (4-bit)"),
pytest.param(identity_lut_generator(5), (5,), ["x"], id="identity function (5-bit)"),
pytest.param(identity_lut_generator(6), (6,), ["x"], id="identity function (6-bit)"),
pytest.param(identity_lut_generator(7), (7,), ["x"], id="identity function (7-bit)"),
pytest.param(random_lut_1b, (1,), ["x"], id="random function (1-bit)"),
pytest.param(random_lut_2b, (2,), ["x"], id="random function (2-bit)"),
pytest.param(random_lut_3b, (3,), ["x"], id="random function (3-bit)"),
pytest.param(random_lut_4b, (4,), ["x"], id="random function (4-bit)"),
pytest.param(random_lut_5b, (5,), ["x"], id="random function (5-bit)"),
pytest.param(random_lut_6b, (6,), ["x"], id="random function (6-bit)"),
pytest.param(random_lut_7b, (7,), ["x"], id="random function (7-bit)"),
pytest.param(small_fused_table, (5,), ["x"], id="small fused table (5-bits)"),
],
)
def test_compile_and_run_lut_correctness(function, input_ranges, list_of_arg_names):
def test_compile_and_run_lut_correctness(
function,
input_bits,
list_of_arg_names,
default_compilation_configuration,
):
"""Test correctness of results when running a compiled function with LUT"""
input_ranges = tuple((0, 2 ** input_bit - 1) for input_bit in input_bits)
def data_gen(args):
for prod in itertools.product(*args):
yield prod
function_parameters = {
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
arg_name: EncryptedScalar(Integer(input_bit, False))
for input_bit, arg_name in zip(input_bits, list_of_arg_names)
}
compiler_engine = compile_numpy_function(
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
default_compilation_configuration,
)
# testing random values
@@ -671,7 +706,7 @@ def test_compile_and_run_lut_correctness(function, input_ranges, list_of_arg_nam
check_is_good_execution(compiler_engine, function, args)
def test_compile_function_with_direct_tlu():
def test_compile_function_with_direct_tlu(default_compilation_configuration):
"""Test compile_numpy_function_into_op_graph for a program with direct table lookup"""
table = LookupTable([9, 2, 4, 11])
@@ -683,13 +718,14 @@ def test_compile_function_with_direct_tlu():
function,
{"x": EncryptedScalar(Integer(2, is_signed=False))},
[(0,), (1,), (2,), (3,)],
default_compilation_configuration,
)
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
print(f"\n{str_of_the_graph}\n")
def test_compile_function_with_direct_tlu_overflow():
def test_compile_function_with_direct_tlu_overflow(default_compilation_configuration):
"""Test compile_numpy_function_into_op_graph for a program with direct table lookup overflow"""
table = LookupTable([9, 2, 4, 11])
@@ -702,7 +738,7 @@ def test_compile_function_with_direct_tlu_overflow():
function,
{"x": EncryptedScalar(Integer(3, is_signed=False))},
[(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,)],
CompilationConfiguration(dump_artifacts_on_unexpected_failures=False),
default_compilation_configuration,
)
@@ -739,7 +775,7 @@ def test_compile_function_with_direct_tlu_overflow():
),
],
)
def test_fail_compile(function, parameters, inputset, match):
def test_fail_compile(function, parameters, inputset, match, default_compilation_configuration):
"""Test function compile_numpy_function_into_op_graph for a program with signed values"""
try:
@@ -747,13 +783,13 @@ def test_fail_compile(function, parameters, inputset, match):
function,
parameters,
inputset,
CompilationConfiguration(dump_artifacts_on_unexpected_failures=False),
default_compilation_configuration,
)
except RuntimeError as error:
assert str(error) == match
def test_small_inputset():
def test_small_inputset_no_fail():
"""Test function compile_numpy_function_into_op_graph with an unacceptably small inputset"""
compile_numpy_function_into_op_graph(
lambda x: x + 42,
@@ -801,7 +837,9 @@ def test_small_inputset_treat_warnings_as_errors():
# pylint: enable=unnecessary-lambda
],
)
def test_compile_function_with_dot(function, params, shape, ref_graph_str):
def test_compile_function_with_dot(
function, params, shape, ref_graph_str, default_compilation_configuration
):
"""Test compile_numpy_function_into_op_graph for a program with np.dot"""
# This is the exhaust, but if ever we have too long inputs (ie, large 'repeat'),
@@ -820,6 +858,7 @@ def test_compile_function_with_dot(function, params, shape, ref_graph_str):
function,
params,
data_gen(max_for_ij, repeat),
default_compilation_configuration,
)
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
assert str_of_the_graph == ref_graph_str, (
@@ -840,7 +879,9 @@ def test_compile_function_with_dot(function, params, shape, ref_graph_str):
pytest.param(lambda x, y: 50 - y * 2 + x, ((0, 20), (0, 20)), ["x", "y"]),
],
)
def test_compile_with_show_mlir(function, input_ranges, list_of_arg_names):
def test_compile_with_show_mlir(
function, input_ranges, list_of_arg_names, default_compilation_configuration
):
"""Test show_mlir option"""
def data_gen(args):
@@ -855,11 +896,12 @@ def test_compile_with_show_mlir(function, input_ranges, list_of_arg_names):
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
default_compilation_configuration,
show_mlir=True,
)
def test_compile_too_high_bitwidth():
def test_compile_too_high_bitwidth(default_compilation_configuration):
"""Check that the check of maximal bitwidth of intermediate data works fine."""
def function(x, y):
@@ -882,6 +924,7 @@ def test_compile_too_high_bitwidth():
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
default_compilation_configuration,
)
assert (
@@ -898,4 +941,5 @@ def test_compile_too_high_bitwidth():
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
default_compilation_configuration,
)