mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
test: create default_compilation_configuration fixture
- update test code and use it where appropriate - remove duplicate tests that lacked correctness verification
This commit is contained in:
@@ -9,7 +9,7 @@ from concrete.common.values import EncryptedScalar
|
||||
from concrete.numpy.compile import compile_numpy_function
|
||||
|
||||
|
||||
def test_artifacts_export():
|
||||
def test_artifacts_export(default_compilation_configuration):
|
||||
"""Test function to check exporting compilation artifacts"""
|
||||
|
||||
def function(x):
|
||||
@@ -23,6 +23,7 @@ def test_artifacts_export():
|
||||
function,
|
||||
{"x": EncryptedScalar(UnsignedInteger(7))},
|
||||
[(i,) for i in range(10)],
|
||||
default_compilation_configuration,
|
||||
compilation_artifacts=artifacts,
|
||||
)
|
||||
|
||||
|
||||
@@ -40,7 +40,9 @@ def simple_fuse_not_output(x):
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_enable_topological_optimizations(test_helpers, function_to_trace, fused):
|
||||
def test_enable_topological_optimizations(
|
||||
test_helpers, function_to_trace, fused, default_compilation_configuration
|
||||
):
|
||||
"""Test function for enable_topological_optimizations flag of compilation configuration"""
|
||||
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
@@ -50,7 +52,7 @@ def test_enable_topological_optimizations(test_helpers, function_to_trace, fused
|
||||
for param in signature(function_to_trace).parameters.keys()
|
||||
},
|
||||
[(i,) for i in range(10)],
|
||||
CompilationConfiguration(dump_artifacts_on_unexpected_failures=False),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
op_graph_not_optimized = compile_numpy_function_into_op_graph(
|
||||
function_to_trace,
|
||||
@@ -62,6 +64,7 @@ def test_enable_topological_optimizations(test_helpers, function_to_trace, fused
|
||||
CompilationConfiguration(
|
||||
dump_artifacts_on_unexpected_failures=False,
|
||||
enable_topological_optimizations=False,
|
||||
treat_warnings_as_errors=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from concrete.common.values import EncryptedScalar
|
||||
from concrete.numpy.compile import compile_numpy_function_into_op_graph
|
||||
|
||||
|
||||
def test_draw_graph_with_saving():
|
||||
def test_draw_graph_with_saving(default_compilation_configuration):
|
||||
"""Tests drawing and saving a graph"""
|
||||
|
||||
def function(x):
|
||||
@@ -19,6 +19,7 @@ def test_draw_graph_with_saving():
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(7, True))},
|
||||
[(i,) for i in range(-5, 5)],
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
|
||||
@@ -6,7 +6,7 @@ from concrete.common.values import EncryptedScalar
|
||||
from concrete.numpy.compile import compile_numpy_function_into_op_graph
|
||||
|
||||
|
||||
def test_get_printable_graph_with_offending_nodes():
|
||||
def test_get_printable_graph_with_offending_nodes(default_compilation_configuration):
|
||||
"""Test get_printable_graph with offending nodes"""
|
||||
|
||||
def function(x):
|
||||
@@ -16,6 +16,7 @@ def test_get_printable_graph_with_offending_nodes():
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(7, True))},
|
||||
[(i,) for i in range(-5, 5)],
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
highlighted_nodes = {opgraph.input_nodes[0]: "foo"}
|
||||
|
||||
@@ -216,10 +216,15 @@ def datagen(*args):
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_mlir_converter(func, args_dict, args_ranges):
|
||||
def test_mlir_converter(func, args_dict, args_ranges, default_compilation_configuration):
|
||||
"""Test the conversion to MLIR by calling the parser from the compiler"""
|
||||
inputset = datagen(*args_ranges)
|
||||
result_graph = compile_numpy_function_into_op_graph(func, args_dict, inputset)
|
||||
result_graph = compile_numpy_function_into_op_graph(
|
||||
func,
|
||||
args_dict,
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
mlir_result = converter.convert(result_graph)
|
||||
# testing that this doesn't raise an error
|
||||
@@ -247,7 +252,9 @@ def test_mlir_converter(func, args_dict, args_ranges):
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_mlir_converter_dot_between_vectors(func, args_dict, args_ranges):
|
||||
def test_mlir_converter_dot_between_vectors(
|
||||
func, args_dict, args_ranges, default_compilation_configuration
|
||||
):
|
||||
"""Test the conversion to MLIR by calling the parser from the compiler"""
|
||||
assert len(args_dict["x"].shape) == 1
|
||||
assert len(args_dict["y"].shape) == 1
|
||||
@@ -261,6 +268,7 @@ def test_mlir_converter_dot_between_vectors(func, args_dict, args_ranges):
|
||||
(numpy.array([data[0]] * n), numpy.array([data[1]] * n))
|
||||
for data in datagen(*args_ranges)
|
||||
),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
mlir_result = converter.convert(result_graph)
|
||||
@@ -268,7 +276,7 @@ def test_mlir_converter_dot_between_vectors(func, args_dict, args_ranges):
|
||||
compiler.round_trip(mlir_result)
|
||||
|
||||
|
||||
def test_mlir_converter_dot_vector_and_constant():
|
||||
def test_mlir_converter_dot_vector_and_constant(default_compilation_configuration):
|
||||
"""Test the conversion to MLIR by calling the parser from the compiler"""
|
||||
|
||||
def left_dot_with_constant(x):
|
||||
@@ -281,6 +289,7 @@ def test_mlir_converter_dot_vector_and_constant():
|
||||
left_dot_with_constant,
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2,))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)],
|
||||
default_compilation_configuration,
|
||||
)
|
||||
left_converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
left_mlir = left_converter.convert(left_graph)
|
||||
@@ -289,6 +298,7 @@ def test_mlir_converter_dot_vector_and_constant():
|
||||
right_dot_with_constant,
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2,))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)],
|
||||
default_compilation_configuration,
|
||||
)
|
||||
right_converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
right_mlir = right_converter.convert(right_graph)
|
||||
|
||||
@@ -6,7 +6,7 @@ import concrete.numpy as hnp
|
||||
from concrete.common.debugging import draw_graph, get_printable_graph
|
||||
|
||||
|
||||
def test_circuit_str():
|
||||
def test_circuit_str(default_compilation_configuration):
|
||||
"""Test function for `__str__` method of `Circuit`"""
|
||||
|
||||
def f(x):
|
||||
@@ -15,12 +15,12 @@ def test_circuit_str():
|
||||
x = hnp.EncryptedScalar(hnp.UnsignedInteger(3))
|
||||
|
||||
inputset = [(i,) for i in range(2 ** 3)]
|
||||
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset)
|
||||
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration)
|
||||
|
||||
assert str(circuit) == get_printable_graph(circuit.opgraph, show_data_types=True)
|
||||
|
||||
|
||||
def test_circuit_draw():
|
||||
def test_circuit_draw(default_compilation_configuration):
|
||||
"""Test function for `draw` method of `Circuit`"""
|
||||
|
||||
def f(x):
|
||||
@@ -29,13 +29,13 @@ def test_circuit_draw():
|
||||
x = hnp.EncryptedScalar(hnp.UnsignedInteger(3))
|
||||
|
||||
inputset = [(i,) for i in range(2 ** 3)]
|
||||
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset)
|
||||
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration)
|
||||
|
||||
assert filecmp.cmp(circuit.draw(), draw_graph(circuit.opgraph))
|
||||
assert filecmp.cmp(circuit.draw(vertical=False), draw_graph(circuit.opgraph, vertical=False))
|
||||
|
||||
|
||||
def test_circuit_run():
|
||||
def test_circuit_run(default_compilation_configuration):
|
||||
"""Test function for `run` method of `Circuit`"""
|
||||
|
||||
def f(x):
|
||||
@@ -44,7 +44,7 @@ def test_circuit_run():
|
||||
x = hnp.EncryptedScalar(hnp.UnsignedInteger(3))
|
||||
|
||||
inputset = [(i,) for i in range(2 ** 3)]
|
||||
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset)
|
||||
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration)
|
||||
|
||||
for x in inputset:
|
||||
assert circuit.run(*x) == circuit.engine.run(*x)
|
||||
|
||||
@@ -8,6 +8,7 @@ import networkx as nx
|
||||
import networkx.algorithms.isomorphism as iso
|
||||
import pytest
|
||||
|
||||
from concrete.common.compilation import CompilationConfiguration
|
||||
from concrete.common.representation.intermediate import (
|
||||
ALL_IR_NODES,
|
||||
Add,
|
||||
@@ -228,3 +229,12 @@ class TestHelpers:
|
||||
def test_helpers():
|
||||
"""Fixture to return the static helper class"""
|
||||
return TestHelpers
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_compilation_configuration():
|
||||
"""Return the default test compilation configuration"""
|
||||
return CompilationConfiguration(
|
||||
dump_artifacts_on_unexpected_failures=False,
|
||||
treat_warnings_as_errors=True,
|
||||
)
|
||||
|
||||
@@ -224,7 +224,12 @@ def check_is_good_execution(compiler_engine, function, args):
|
||||
)
|
||||
|
||||
|
||||
def subtest_compile_and_run_unary_ufunc_correctness(ufunc, upper_function, input_ranges):
|
||||
def subtest_compile_and_run_unary_ufunc_correctness(
|
||||
ufunc,
|
||||
upper_function,
|
||||
input_ranges,
|
||||
default_compilation_configuration,
|
||||
):
|
||||
"""Test correctness of results when running a compiled function"""
|
||||
|
||||
def get_function(ufunc, upper_function):
|
||||
@@ -242,6 +247,7 @@ def subtest_compile_and_run_unary_ufunc_correctness(ufunc, upper_function, input
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
args = [random.randint(low, high) for (low, high) in input_ranges]
|
||||
@@ -249,7 +255,13 @@ def subtest_compile_and_run_unary_ufunc_correctness(ufunc, upper_function, input
|
||||
check_is_good_execution(compiler_engine, function, args)
|
||||
|
||||
|
||||
def subtest_compile_and_run_binary_ufunc_correctness(ufunc, upper_function, c, input_ranges):
|
||||
def subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc,
|
||||
upper_function,
|
||||
c,
|
||||
input_ranges,
|
||||
default_compilation_configuration,
|
||||
):
|
||||
"""Test correctness of results when running a compiled function"""
|
||||
|
||||
def get_function(ufunc, upper_function):
|
||||
@@ -267,6 +279,7 @@ def subtest_compile_and_run_binary_ufunc_correctness(ufunc, upper_function, c, i
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
args = [random.randint(low, high) for (low, high) in input_ranges]
|
||||
@@ -278,54 +291,86 @@ def subtest_compile_and_run_binary_ufunc_correctness(ufunc, upper_function, c, i
|
||||
"ufunc",
|
||||
[f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 2],
|
||||
)
|
||||
def test_binary_ufunc_operations(ufunc):
|
||||
def test_binary_ufunc_operations(ufunc, default_compilation_configuration):
|
||||
"""Test biary functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC."""
|
||||
if ufunc in [numpy.power, numpy.float_power]:
|
||||
# Need small constants to keep results really small
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_binary_f_one, 3, ((0, 4), (0, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_binary_f_one,
|
||||
3,
|
||||
((0, 4), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
elif ufunc in [numpy.lcm, numpy.left_shift]:
|
||||
# Need small constants to keep results sufficiently small
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_binary_f_one, 3, ((0, 5), (0, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_binary_f_one,
|
||||
3,
|
||||
((0, 5), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
else:
|
||||
# General case
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_binary_f_one, 41, ((0, 5), (0, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_binary_f_one,
|
||||
41,
|
||||
((0, 5), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
if ufunc in [numpy.power, numpy.float_power]:
|
||||
# Need small constants to keep results really small
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_binary_f_two, 2, ((0, 4), (0, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_binary_f_two,
|
||||
2,
|
||||
((0, 4), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
elif ufunc in [numpy.floor_divide, numpy.fmod, numpy.remainder, numpy.true_divide]:
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_binary_f_two, 31, ((1, 5), (1, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_binary_f_two,
|
||||
31,
|
||||
((1, 5), (1, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
elif ufunc in [numpy.lcm, numpy.left_shift]:
|
||||
# Need small constants to keep results sufficiently small
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_binary_f_two, 2, ((0, 5), (0, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_binary_f_two,
|
||||
2,
|
||||
((0, 5), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
elif ufunc in [numpy.ldexp]:
|
||||
# Need small constants to keep results sufficiently small
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_binary_f_two, 2, ((0, 5), (0, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_binary_f_two,
|
||||
2,
|
||||
((0, 5), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
else:
|
||||
# General case
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_binary_f_two, 42, ((0, 5), (0, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_binary_f_two,
|
||||
42,
|
||||
((0, 5), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ufunc", [f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1]
|
||||
)
|
||||
def test_unary_ufunc_operations(ufunc):
|
||||
def test_unary_ufunc_operations(ufunc, default_compilation_configuration):
|
||||
"""Test unary functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC."""
|
||||
if ufunc in [
|
||||
numpy.degrees,
|
||||
@@ -333,14 +378,20 @@ def test_unary_ufunc_operations(ufunc):
|
||||
]:
|
||||
# Need to reduce the output value, to avoid to need too much precision
|
||||
subtest_compile_and_run_unary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_f_which_has_large_outputs, ((0, 5), (0, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_f_which_has_large_outputs,
|
||||
((0, 5), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
elif ufunc in [
|
||||
numpy.negative,
|
||||
]:
|
||||
# Need to turn the input into a float
|
||||
subtest_compile_and_run_unary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_f_with_float_inputs, ((0, 5), (0, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_f_with_float_inputs,
|
||||
((0, 5), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
elif ufunc in [
|
||||
numpy.invert,
|
||||
@@ -360,7 +411,10 @@ def test_unary_ufunc_operations(ufunc):
|
||||
]:
|
||||
# No 0 in the domain of definition
|
||||
subtest_compile_and_run_unary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_f_avoid_0_input, ((1, 5), (1, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_f_avoid_0_input,
|
||||
((1, 5), (1, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
elif ufunc in [
|
||||
numpy.cosh,
|
||||
@@ -375,12 +429,18 @@ def test_unary_ufunc_operations(ufunc):
|
||||
]:
|
||||
# Need a small range of inputs, to avoid to need too much precision
|
||||
subtest_compile_and_run_unary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_f_which_expects_small_inputs, ((0, 5), (0, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_f_which_expects_small_inputs,
|
||||
((0, 5), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
else:
|
||||
# Regular case for univariate functions
|
||||
subtest_compile_and_run_unary_ufunc_correctness(
|
||||
ufunc, mix_x_and_y_and_call_f, ((0, 5), (0, 5))
|
||||
ufunc,
|
||||
mix_x_and_y_and_call_f,
|
||||
((0, 5), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
|
||||
@@ -404,7 +464,9 @@ def test_unary_ufunc_operations(ufunc):
|
||||
pytest.param(complicated_topology, ((0, 10),), ["x"]),
|
||||
],
|
||||
)
|
||||
def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_names):
|
||||
def test_compile_function_multiple_outputs(
|
||||
function, input_ranges, list_of_arg_names, default_compilation_configuration
|
||||
):
|
||||
"""Test function compile_numpy_function_into_op_graph for a program with multiple outputs"""
|
||||
|
||||
def data_gen(args):
|
||||
@@ -419,7 +481,7 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
CompilationConfiguration(dump_artifacts_on_unexpected_failures=False),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
# TODO: For the moment, we don't have really checks, but some printfs. Later,
|
||||
@@ -433,52 +495,7 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
[
|
||||
pytest.param(lambda x: x + 42, ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x: x + numpy.int32(42), ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x: x * 2, ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x: 12 - x, ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]),
|
||||
pytest.param(identity_lut_generator(1), ((0, 1),), ["x"]),
|
||||
pytest.param(identity_lut_generator(2), ((0, 3),), ["x"]),
|
||||
pytest.param(identity_lut_generator(3), ((0, 7),), ["x"]),
|
||||
pytest.param(identity_lut_generator(4), ((0, 15),), ["x"]),
|
||||
pytest.param(identity_lut_generator(5), ((0, 31),), ["x"]),
|
||||
pytest.param(identity_lut_generator(6), ((0, 63),), ["x"]),
|
||||
pytest.param(identity_lut_generator(7), ((0, 127),), ["x"]),
|
||||
pytest.param(random_lut_1b, ((0, 1),), ["x"]),
|
||||
pytest.param(random_lut_2b, ((0, 3),), ["x"]),
|
||||
pytest.param(random_lut_3b, ((0, 7),), ["x"]),
|
||||
pytest.param(random_lut_4b, ((0, 15),), ["x"]),
|
||||
pytest.param(random_lut_5b, ((0, 31),), ["x"]),
|
||||
pytest.param(random_lut_6b, ((0, 63),), ["x"]),
|
||||
pytest.param(random_lut_7b, ((0, 127),), ["x"]),
|
||||
pytest.param(small_fused_table, ((0, 31),), ["x"]),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_function_multiple_outputs(function, input_ranges, list_of_arg_names):
|
||||
"""Test function compile_numpy_function for a program with multiple outputs"""
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {
|
||||
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
|
||||
}
|
||||
|
||||
compiler_engine = compile_numpy_function(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
)
|
||||
|
||||
args = [random.randint(low, high) for (low, high) in input_ranges]
|
||||
compiler_engine.run(*args)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
[
|
||||
pytest.param(lambda x: x + 64, ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x: x * 3, ((0, 40),), ["x"]),
|
||||
pytest.param(lambda x: 120 - x, ((40, 80),), ["x"]),
|
||||
@@ -487,7 +504,9 @@ def test_compile_and_run_function_multiple_outputs(function, input_ranges, list_
|
||||
pytest.param(lambda x, y: 50 - y * 2 + x, ((0, 20), (0, 20)), ["x", "y"]),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_correctness(function, input_ranges, list_of_arg_names):
|
||||
def test_compile_and_run_correctness(
|
||||
function, input_ranges, list_of_arg_names, default_compilation_configuration
|
||||
):
|
||||
"""Test correctness of results when running a compiled function"""
|
||||
|
||||
def data_gen(args):
|
||||
@@ -502,6 +521,7 @@ def test_compile_and_run_correctness(function, input_ranges, list_of_arg_names):
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
args = [random.randint(low, high) for (low, high) in input_ranges]
|
||||
@@ -529,7 +549,7 @@ def test_compile_and_run_correctness(function, input_ranges, list_of_arg_names):
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_dot_correctness(size, input_range):
|
||||
def test_compile_and_run_dot_correctness(size, input_range, default_compilation_configuration):
|
||||
"""Test correctness of results when running a compiled function"""
|
||||
|
||||
low, high = input_range
|
||||
@@ -557,6 +577,7 @@ def test_compile_and_run_dot_correctness(size, input_range):
|
||||
function,
|
||||
function_parameters,
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
args = [[random.randint(low, high) for _ in range(size)] for __ in range(2)]
|
||||
@@ -584,7 +605,9 @@ def test_compile_and_run_dot_correctness(size, input_range):
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_constant_dot_correctness(size, input_range):
|
||||
def test_compile_and_run_constant_dot_correctness(
|
||||
size, input_range, default_compilation_configuration
|
||||
):
|
||||
"""Test correctness of results when running a compiled function"""
|
||||
|
||||
low, high = input_range
|
||||
@@ -609,11 +632,13 @@ def test_compile_and_run_constant_dot_correctness(size, input_range):
|
||||
left,
|
||||
{"x": EncryptedTensor(Integer(64, False), shape)},
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
right_circuit = compile_numpy_function(
|
||||
left,
|
||||
{"x": EncryptedTensor(Integer(64, False), shape)},
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
args = (numpy.random.randint(low, high + 1, size=shape).tolist(),)
|
||||
@@ -622,39 +647,49 @@ def test_compile_and_run_constant_dot_correctness(size, input_range):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
"function,input_bits,list_of_arg_names",
|
||||
[
|
||||
pytest.param(identity_lut_generator(1), ((0, 1),), ["x"], id="identity function (1-bit)"),
|
||||
pytest.param(identity_lut_generator(2), ((0, 3),), ["x"], id="identity function (2-bit)"),
|
||||
pytest.param(identity_lut_generator(3), ((0, 7),), ["x"], id="identity function (3-bit)"),
|
||||
pytest.param(identity_lut_generator(4), ((0, 15),), ["x"], id="identity function (4-bit)"),
|
||||
pytest.param(identity_lut_generator(5), ((0, 31),), ["x"], id="identity function (5-bit)"),
|
||||
pytest.param(identity_lut_generator(6), ((0, 63),), ["x"], id="identity function (6-bit)"),
|
||||
pytest.param(identity_lut_generator(7), ((0, 127),), ["x"], id="identity function (7-bit)"),
|
||||
pytest.param(random_lut_1b, ((0, 1),), ["x"], id="random function (1-bit)"),
|
||||
pytest.param(random_lut_2b, ((0, 3),), ["x"], id="random function (2-bit)"),
|
||||
pytest.param(random_lut_3b, ((0, 7),), ["x"], id="random function (3-bit)"),
|
||||
pytest.param(random_lut_4b, ((0, 15),), ["x"], id="random function (4-bit)"),
|
||||
pytest.param(random_lut_5b, ((0, 31),), ["x"], id="random function (5-bit)"),
|
||||
pytest.param(random_lut_6b, ((0, 63),), ["x"], id="random function (6-bit)"),
|
||||
pytest.param(random_lut_7b, ((0, 127),), ["x"], id="random function (7-bit)"),
|
||||
pytest.param(identity_lut_generator(1), (1,), ["x"], id="identity function (1-bit)"),
|
||||
pytest.param(identity_lut_generator(2), (2,), ["x"], id="identity function (2-bit)"),
|
||||
pytest.param(identity_lut_generator(3), (3,), ["x"], id="identity function (3-bit)"),
|
||||
pytest.param(identity_lut_generator(4), (4,), ["x"], id="identity function (4-bit)"),
|
||||
pytest.param(identity_lut_generator(5), (5,), ["x"], id="identity function (5-bit)"),
|
||||
pytest.param(identity_lut_generator(6), (6,), ["x"], id="identity function (6-bit)"),
|
||||
pytest.param(identity_lut_generator(7), (7,), ["x"], id="identity function (7-bit)"),
|
||||
pytest.param(random_lut_1b, (1,), ["x"], id="random function (1-bit)"),
|
||||
pytest.param(random_lut_2b, (2,), ["x"], id="random function (2-bit)"),
|
||||
pytest.param(random_lut_3b, (3,), ["x"], id="random function (3-bit)"),
|
||||
pytest.param(random_lut_4b, (4,), ["x"], id="random function (4-bit)"),
|
||||
pytest.param(random_lut_5b, (5,), ["x"], id="random function (5-bit)"),
|
||||
pytest.param(random_lut_6b, (6,), ["x"], id="random function (6-bit)"),
|
||||
pytest.param(random_lut_7b, (7,), ["x"], id="random function (7-bit)"),
|
||||
pytest.param(small_fused_table, (5,), ["x"], id="small fused table (5-bits)"),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_lut_correctness(function, input_ranges, list_of_arg_names):
|
||||
def test_compile_and_run_lut_correctness(
|
||||
function,
|
||||
input_bits,
|
||||
list_of_arg_names,
|
||||
default_compilation_configuration,
|
||||
):
|
||||
"""Test correctness of results when running a compiled function with LUT"""
|
||||
|
||||
input_ranges = tuple((0, 2 ** input_bit - 1) for input_bit in input_bits)
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {
|
||||
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
|
||||
arg_name: EncryptedScalar(Integer(input_bit, False))
|
||||
for input_bit, arg_name in zip(input_bits, list_of_arg_names)
|
||||
}
|
||||
|
||||
compiler_engine = compile_numpy_function(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
# testing random values
|
||||
@@ -671,7 +706,7 @@ def test_compile_and_run_lut_correctness(function, input_ranges, list_of_arg_nam
|
||||
check_is_good_execution(compiler_engine, function, args)
|
||||
|
||||
|
||||
def test_compile_function_with_direct_tlu():
|
||||
def test_compile_function_with_direct_tlu(default_compilation_configuration):
|
||||
"""Test compile_numpy_function_into_op_graph for a program with direct table lookup"""
|
||||
|
||||
table = LookupTable([9, 2, 4, 11])
|
||||
@@ -683,13 +718,14 @@ def test_compile_function_with_direct_tlu():
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(2, is_signed=False))},
|
||||
[(0,), (1,), (2,), (3,)],
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
|
||||
print(f"\n{str_of_the_graph}\n")
|
||||
|
||||
|
||||
def test_compile_function_with_direct_tlu_overflow():
|
||||
def test_compile_function_with_direct_tlu_overflow(default_compilation_configuration):
|
||||
"""Test compile_numpy_function_into_op_graph for a program with direct table lookup overflow"""
|
||||
|
||||
table = LookupTable([9, 2, 4, 11])
|
||||
@@ -702,7 +738,7 @@ def test_compile_function_with_direct_tlu_overflow():
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(3, is_signed=False))},
|
||||
[(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,)],
|
||||
CompilationConfiguration(dump_artifacts_on_unexpected_failures=False),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
|
||||
@@ -739,7 +775,7 @@ def test_compile_function_with_direct_tlu_overflow():
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_fail_compile(function, parameters, inputset, match):
|
||||
def test_fail_compile(function, parameters, inputset, match, default_compilation_configuration):
|
||||
"""Test function compile_numpy_function_into_op_graph for a program with signed values"""
|
||||
|
||||
try:
|
||||
@@ -747,13 +783,13 @@ def test_fail_compile(function, parameters, inputset, match):
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
CompilationConfiguration(dump_artifacts_on_unexpected_failures=False),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
except RuntimeError as error:
|
||||
assert str(error) == match
|
||||
|
||||
|
||||
def test_small_inputset():
|
||||
def test_small_inputset_no_fail():
|
||||
"""Test function compile_numpy_function_into_op_graph with an unacceptably small inputset"""
|
||||
compile_numpy_function_into_op_graph(
|
||||
lambda x: x + 42,
|
||||
@@ -801,7 +837,9 @@ def test_small_inputset_treat_warnings_as_errors():
|
||||
# pylint: enable=unnecessary-lambda
|
||||
],
|
||||
)
|
||||
def test_compile_function_with_dot(function, params, shape, ref_graph_str):
|
||||
def test_compile_function_with_dot(
|
||||
function, params, shape, ref_graph_str, default_compilation_configuration
|
||||
):
|
||||
"""Test compile_numpy_function_into_op_graph for a program with np.dot"""
|
||||
|
||||
# This is the exhaust, but if ever we have too long inputs (ie, large 'repeat'),
|
||||
@@ -820,6 +858,7 @@ def test_compile_function_with_dot(function, params, shape, ref_graph_str):
|
||||
function,
|
||||
params,
|
||||
data_gen(max_for_ij, repeat),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
@@ -840,7 +879,9 @@ def test_compile_function_with_dot(function, params, shape, ref_graph_str):
|
||||
pytest.param(lambda x, y: 50 - y * 2 + x, ((0, 20), (0, 20)), ["x", "y"]),
|
||||
],
|
||||
)
|
||||
def test_compile_with_show_mlir(function, input_ranges, list_of_arg_names):
|
||||
def test_compile_with_show_mlir(
|
||||
function, input_ranges, list_of_arg_names, default_compilation_configuration
|
||||
):
|
||||
"""Test show_mlir option"""
|
||||
|
||||
def data_gen(args):
|
||||
@@ -855,11 +896,12 @@ def test_compile_with_show_mlir(function, input_ranges, list_of_arg_names):
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
default_compilation_configuration,
|
||||
show_mlir=True,
|
||||
)
|
||||
|
||||
|
||||
def test_compile_too_high_bitwidth():
|
||||
def test_compile_too_high_bitwidth(default_compilation_configuration):
|
||||
"""Check that the check of maximal bitwidth of intermediate data works fine."""
|
||||
|
||||
def function(x, y):
|
||||
@@ -882,6 +924,7 @@ def test_compile_too_high_bitwidth():
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
assert (
|
||||
@@ -898,4 +941,5 @@ def test_compile_too_high_bitwidth():
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user