chore: be more verbose in this assert

closes #729
This commit is contained in:
Benoit Chevallier-Mames
2021-11-09 17:54:37 +01:00
committed by Benoit Chevallier
parent 955470fb89
commit 1aad4d23d1
2 changed files with 48 additions and 0 deletions

View File

@@ -12,6 +12,7 @@ from ..common.common_helpers import check_op_graph_is_integer_program
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
from ..common.data_types import Integer
from ..common.debugging import format_operation_graph
from ..common.debugging.custom_assert import assert_true
from ..common.fhe_circuit import FHECircuit
from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS
from ..common.mlir.utils import (
@@ -83,6 +84,23 @@ def _compile_numpy_function_into_op_graph_internal(
OPGraph: compiled function into a graph
"""
# Check function parameters
wrong_inputs = {
inp: function_parameters[inp]
for inp in function_parameters.keys()
if not isinstance(function_parameters[inp], BaseValue)
}
list_of_possible_basevalue = [
"ClearTensor",
"EncryptedTensor",
"ClearScalar",
"EncryptedScalar",
]
assert_true(
len(wrong_inputs.keys()) == 0,
f"wrong type for inputs {wrong_inputs}, needs to be one of {list_of_possible_basevalue}",
)
# Add the function to compile as an artifact
compilation_artifacts.add_function_to_compile(function_to_compile)

View File

@@ -1481,3 +1481,33 @@ def test_fail_compile_with_random_inputset(default_compilation_configuration):
)
assert str(error) == expected
raise
def test_wrong_inputs(default_compilation_configuration):
"""Test compilation with faulty inputs"""
def data_gen(args):
for prod in itertools.product(*args):
yield prod
# x should have been something like EncryptedScalar(UnsignedInteger(3))
x = [1, 2, 3]
input_ranges = ((0, 10),)
inputset = data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges))
dict_for_inputs = {"x": x}
with pytest.raises(AssertionError) as excinfo:
compile_numpy_function(
lambda x: 2 * x, dict_for_inputs, inputset, default_compilation_configuration
)
list_of_possible_basevalue = [
"ClearTensor",
"EncryptedTensor",
"ClearScalar",
"EncryptedScalar",
]
assert (
str(excinfo.value) == f"wrong type for inputs {dict_for_inputs}, "
f"needs to be one of {list_of_possible_basevalue}"
)