mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
feat(compilation): provide the reason for MLIR incompatibility
This commit is contained in:
@@ -71,6 +71,30 @@ def value_is_scalar_integer(value_to_check: BaseValue) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def value_is_scalar(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is a scalar.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is a scalar
|
||||
"""
|
||||
return isinstance(value_to_check, TensorValue) and value_to_check.is_scalar
|
||||
|
||||
|
||||
def value_is_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is of type Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is of type Integer
|
||||
"""
|
||||
return isinstance(value_to_check.dtype, INTEGER_TYPES)
|
||||
|
||||
|
||||
def value_is_encrypted_tensor_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is an encrypted TensorValue of type Integer.
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""Utilities for MLIR conversion."""
|
||||
from typing import cast
|
||||
from typing import Dict, Optional, cast
|
||||
|
||||
from ..data_types import Integer
|
||||
from ..data_types.dtypes_helpers import (
|
||||
@@ -7,32 +7,41 @@ from ..data_types.dtypes_helpers import (
|
||||
value_is_clear_tensor_integer,
|
||||
value_is_encrypted_scalar_integer,
|
||||
value_is_encrypted_tensor_integer,
|
||||
value_is_scalar_integer,
|
||||
value_is_integer,
|
||||
value_is_scalar,
|
||||
)
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import UnivariateFunction
|
||||
from ..representation.intermediate import IntermediateNode, UnivariateFunction
|
||||
|
||||
# TODO: should come from compiler, through an API, #402
|
||||
ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 7
|
||||
|
||||
|
||||
def is_graph_values_compatible_with_mlir(op_graph: OPGraph) -> bool:
|
||||
def check_graph_values_compatibility_with_mlir(
|
||||
op_graph: OPGraph,
|
||||
) -> Optional[Dict[IntermediateNode, str]]:
|
||||
"""Make sure the graph outputs are unsigned integers, which is what the compiler supports.
|
||||
|
||||
Args:
|
||||
op_graph: computation graph to check
|
||||
|
||||
Returns:
|
||||
bool: is the graph compatible with the expected MLIR representation
|
||||
Dict[IntermediateNode, str]: None if the graph is compatible
|
||||
information about offending nodes otherwise
|
||||
"""
|
||||
return all(
|
||||
all(
|
||||
value_is_scalar_integer(out) and not cast(Integer, out.dtype).is_signed
|
||||
for out in out_node.outputs
|
||||
)
|
||||
for out_node in op_graph.output_nodes.values()
|
||||
)
|
||||
|
||||
offending_nodes = {}
|
||||
|
||||
for out_node in op_graph.output_nodes.values():
|
||||
for out in out_node.outputs:
|
||||
if not value_is_scalar(out):
|
||||
offending_nodes[out_node] = "non scalar outputs aren't supported"
|
||||
|
||||
if value_is_integer(out) and cast(Integer, out.dtype).is_signed:
|
||||
offending_nodes[out_node] = "signed integer outputs aren't supported"
|
||||
|
||||
return None if len(offending_nodes) == 0 else offending_nodes
|
||||
|
||||
|
||||
def _set_all_bit_width(op_graph: OPGraph, p: int):
|
||||
|
||||
@@ -11,11 +11,12 @@ from ..common.bounds_measurement.inputset_eval import eval_op_graph_bounds_on_in
|
||||
from ..common.common_helpers import check_op_graph_is_integer_program
|
||||
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
|
||||
from ..common.data_types import Integer
|
||||
from ..common.debugging import get_printable_graph
|
||||
from ..common.fhe_circuit import FHECircuit
|
||||
from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
|
||||
from ..common.mlir.utils import (
|
||||
check_graph_values_compatibility_with_mlir,
|
||||
extend_direct_lookup_tables,
|
||||
is_graph_values_compatible_with_mlir,
|
||||
update_bit_width_for_mlir,
|
||||
)
|
||||
from ..common.operator_graph import OPGraph
|
||||
@@ -162,8 +163,12 @@ def _compile_numpy_function_into_op_graph_internal(
|
||||
compilation_artifacts.add_operation_graph("final", op_graph)
|
||||
|
||||
# Make sure the graph can be lowered to MLIR
|
||||
if not is_graph_values_compatible_with_mlir(op_graph):
|
||||
raise RuntimeError("function you are trying to compile isn't supported for MLIR lowering")
|
||||
offending_nodes = check_graph_values_compatibility_with_mlir(op_graph)
|
||||
if offending_nodes is not None:
|
||||
raise RuntimeError(
|
||||
"function you are trying to compile isn't supported for MLIR lowering\n\n"
|
||||
+ get_printable_graph(op_graph, show_data_types=True, highlighted_nodes=offending_nodes)
|
||||
)
|
||||
|
||||
# Update bit_width for MLIR
|
||||
update_bit_width_for_mlir(op_graph)
|
||||
|
||||
@@ -560,29 +560,50 @@ def test_compile_function_with_direct_tlu_overflow():
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
"function,parameters,inputset,match",
|
||||
[
|
||||
pytest.param(lambda x: x - 10, ((-5, 5),), ["x"]),
|
||||
pytest.param(
|
||||
lambda x: 1 - x,
|
||||
{"x": EncryptedScalar(Integer(3, is_signed=False))},
|
||||
[(i,) for i in range(8)],
|
||||
(
|
||||
"function you are trying to compile isn't supported for MLIR lowering\n"
|
||||
"\n"
|
||||
"%0 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>\n" # noqa: E501
|
||||
"%1 = x # EncryptedScalar<Integer<unsigned, 3 bits>>\n" # noqa: E501
|
||||
"%2 = Sub(%0, %1) # EncryptedScalar<Integer<signed, 4 bits>>\n" # noqa: E501
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ signed integer outputs aren't supported\n" # noqa: E501
|
||||
"return(%2)\n"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x + 1,
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2, 2))},
|
||||
[(numpy.random.randint(0, 8, size=(2, 2)),) for i in range(10)],
|
||||
(
|
||||
"function you are trying to compile isn't supported for MLIR lowering\n"
|
||||
"\n"
|
||||
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>\n" # noqa: E501
|
||||
"%1 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>\n" # noqa: E501
|
||||
"%2 = Add(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>\n" # noqa: E501
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ non scalar outputs aren't supported\n" # noqa: E501
|
||||
"return(%2)\n"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_fail_compile(function, input_ranges, list_of_arg_names):
|
||||
def test_fail_compile(function, parameters, inputset, match):
|
||||
"""Test function compile_numpy_function_into_op_graph for a program with signed values"""
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {
|
||||
arg_name: EncryptedScalar(Integer(64, True)) for arg_name in list_of_arg_names
|
||||
}
|
||||
|
||||
with pytest.raises(RuntimeError, match=".*isn't supported for MLIR lowering.*"):
|
||||
try:
|
||||
compile_numpy_function(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
parameters,
|
||||
inputset,
|
||||
CompilationConfiguration(dump_artifacts_on_unexpected_failures=False),
|
||||
)
|
||||
except RuntimeError as error:
|
||||
assert str(error) == match
|
||||
|
||||
|
||||
def test_small_inputset():
|
||||
|
||||
Reference in New Issue
Block a user