mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-10 04:35:03 -05:00
committed by
Benoit Chevallier
parent
955470fb89
commit
1aad4d23d1
@@ -12,6 +12,7 @@ from ..common.common_helpers import check_op_graph_is_integer_program
|
||||
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
|
||||
from ..common.data_types import Integer
|
||||
from ..common.debugging import format_operation_graph
|
||||
from ..common.debugging.custom_assert import assert_true
|
||||
from ..common.fhe_circuit import FHECircuit
|
||||
from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS
|
||||
from ..common.mlir.utils import (
|
||||
@@ -83,6 +84,23 @@ def _compile_numpy_function_into_op_graph_internal(
|
||||
OPGraph: compiled function into a graph
|
||||
"""
|
||||
|
||||
# Check function parameters
|
||||
wrong_inputs = {
|
||||
inp: function_parameters[inp]
|
||||
for inp in function_parameters.keys()
|
||||
if not isinstance(function_parameters[inp], BaseValue)
|
||||
}
|
||||
list_of_possible_basevalue = [
|
||||
"ClearTensor",
|
||||
"EncryptedTensor",
|
||||
"ClearScalar",
|
||||
"EncryptedScalar",
|
||||
]
|
||||
assert_true(
|
||||
len(wrong_inputs.keys()) == 0,
|
||||
f"wrong type for inputs {wrong_inputs}, needs to be one of {list_of_possible_basevalue}",
|
||||
)
|
||||
|
||||
# Add the function to compile as an artifact
|
||||
compilation_artifacts.add_function_to_compile(function_to_compile)
|
||||
|
||||
|
||||
@@ -1481,3 +1481,33 @@ def test_fail_compile_with_random_inputset(default_compilation_configuration):
|
||||
)
|
||||
assert str(error) == expected
|
||||
raise
|
||||
|
||||
|
||||
def test_wrong_inputs(default_compilation_configuration):
|
||||
"""Test compilation with faulty inputs"""
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
# x should have been something like EncryptedScalar(UnsignedInteger(3))
|
||||
x = [1, 2, 3]
|
||||
input_ranges = ((0, 10),)
|
||||
inputset = data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges))
|
||||
dict_for_inputs = {"x": x}
|
||||
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
compile_numpy_function(
|
||||
lambda x: 2 * x, dict_for_inputs, inputset, default_compilation_configuration
|
||||
)
|
||||
|
||||
list_of_possible_basevalue = [
|
||||
"ClearTensor",
|
||||
"EncryptedTensor",
|
||||
"ClearScalar",
|
||||
"EncryptedScalar",
|
||||
]
|
||||
assert (
|
||||
str(excinfo.value) == f"wrong type for inputs {dict_for_inputs}, "
|
||||
f"needs to be one of {list_of_possible_basevalue}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user