Files
concrete/tests/common/compilation/test_configuration.py
Arthur Meyre 48aaed9ee8 refactor: disallow tuple input in inputset for 1-input functions
- handle the case properly once we add tuple support in #772

closes #971
2021-12-02 11:59:42 +01:00

78 lines
2.3 KiB
Python

"""Test file for compilation configuration"""
from inspect import signature
import numpy
import pytest
from concrete.common.compilation import CompilationConfiguration
from concrete.common.data_types.integers import Integer
from concrete.common.values import EncryptedScalar
from concrete.numpy.compile import compile_numpy_function_into_op_graph_and_measure_bounds
def no_fuse(x):
    """Add the constant 2 to ``x``.

    Used as the parametrized case flagged ``fused=False``: the body is a
    single addition on the input, with no intermediate type conversions.
    """
    shifted = x + 2
    return shifted
def simple_fuse_not_output(x):
    """Round-trip ``x`` through float64 and uint32, then add 2.

    Used as the parametrized case flagged ``fused=True``. The two ``astype``
    conversions happen before the final addition, so the converted value is
    an intermediate result and not the function's output.
    """
    as_float = x.astype(numpy.float64)
    as_unsigned = as_float.astype(numpy.uint32)
    return as_unsigned + 2
@pytest.mark.parametrize(
    "function_to_trace,fused",
    [
        pytest.param(no_fuse, False, id="no_fuse"),
        pytest.param(simple_fuse_not_output, True, id="simple_fuse_not_output"),
    ],
)
def test_enable_topological_optimizations(
    test_helpers, function_to_trace, fused, default_compilation_configuration
):
    """Check the effect of the enable_topological_optimizations configuration flag.

    The same function is compiled twice with identical parameter descriptions
    and inputset: once with ``default_compilation_configuration`` and once with
    a configuration whose ``enable_topological_optimizations`` is ``False``.

    When ``fused`` is true, the first graph must differ from the second and
    contain fewer nodes; otherwise both graphs must be equivalent and have the
    same size.
    """

    def compile_with(configuration):
        # Build fresh arguments on every call so the two compilations never
        # share a parameter dict or inputset list.
        parameter_names = signature(function_to_trace).parameters
        function_parameters = {
            name: EncryptedScalar(Integer(32, is_signed=False)) for name in parameter_names
        }
        inputset = [numpy.array(value) for value in range(10)]
        return compile_numpy_function_into_op_graph_and_measure_bounds(
            function_to_trace,
            function_parameters,
            inputset,
            configuration,
        )

    # NOTE(review): presumably the default fixture keeps topological
    # optimizations enabled -- confirm against the fixture definition.
    graph = compile_with(default_compilation_configuration).graph

    unoptimized_configuration = CompilationConfiguration(
        dump_artifacts_on_unexpected_failures=False,
        enable_topological_optimizations=False,
        treat_warnings_as_errors=True,
    )
    not_optimized_graph = compile_with(unoptimized_configuration).graph

    if not fused:
        assert test_helpers.digraphs_are_equivalent(graph, not_optimized_graph)
        assert len(graph) == len(not_optimized_graph)
        return

    assert not test_helpers.digraphs_are_equivalent(graph, not_optimized_graph)
    assert len(graph) < len(not_optimized_graph)