feat(compilation): provide the reason for MLIR incompatibility

This commit is contained in:
Umut
2021-10-15 11:36:54 +03:00
parent 4c9c49ecd2
commit 73769b917e
4 changed files with 88 additions and 29 deletions

View File

@@ -71,6 +71,30 @@ def value_is_scalar_integer(value_to_check: BaseValue) -> bool:
)
def value_is_scalar(value_to_check: BaseValue) -> bool:
"""Check that a value is a scalar.
Args:
value_to_check (BaseValue): The value to check
Returns:
bool: True if the passed value_to_check is a scalar
"""
return isinstance(value_to_check, TensorValue) and value_to_check.is_scalar
def value_is_integer(value_to_check: BaseValue) -> bool:
"""Check that a value is of type Integer.
Args:
value_to_check (BaseValue): The value to check
Returns:
bool: True if the passed value_to_check is of type Integer
"""
return isinstance(value_to_check.dtype, INTEGER_TYPES)
def value_is_encrypted_tensor_integer(value_to_check: BaseValue) -> bool:
"""Check that a value is an encrypted TensorValue of type Integer.

View File

@@ -1,5 +1,5 @@
"""Utilities for MLIR conversion."""
from typing import cast
from typing import Dict, Optional, cast
from ..data_types import Integer
from ..data_types.dtypes_helpers import (
@@ -7,32 +7,41 @@ from ..data_types.dtypes_helpers import (
value_is_clear_tensor_integer,
value_is_encrypted_scalar_integer,
value_is_encrypted_tensor_integer,
value_is_scalar_integer,
value_is_integer,
value_is_scalar,
)
from ..debugging.custom_assert import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import UnivariateFunction
from ..representation.intermediate import IntermediateNode, UnivariateFunction
# TODO: should come from compiler, through an API, #402
ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 7
def is_graph_values_compatible_with_mlir(op_graph: OPGraph) -> bool:
def check_graph_values_compatibility_with_mlir(
op_graph: OPGraph,
) -> Optional[Dict[IntermediateNode, str]]:
"""Make sure the graph outputs are unsigned integers, which is what the compiler supports.
Args:
op_graph: computation graph to check
Returns:
bool: is the graph compatible with the expected MLIR representation
Dict[IntermediateNode, str]: None if the graph is compatible
information about offending nodes otherwise
"""
return all(
all(
value_is_scalar_integer(out) and not cast(Integer, out.dtype).is_signed
for out in out_node.outputs
)
for out_node in op_graph.output_nodes.values()
)
offending_nodes = {}
for out_node in op_graph.output_nodes.values():
for out in out_node.outputs:
if not value_is_scalar(out):
offending_nodes[out_node] = "non scalar outputs aren't supported"
if value_is_integer(out) and cast(Integer, out.dtype).is_signed:
offending_nodes[out_node] = "signed integer outputs aren't supported"
return None if len(offending_nodes) == 0 else offending_nodes
def _set_all_bit_width(op_graph: OPGraph, p: int):

View File

@@ -11,11 +11,12 @@ from ..common.bounds_measurement.inputset_eval import eval_op_graph_bounds_on_in
from ..common.common_helpers import check_op_graph_is_integer_program
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
from ..common.data_types import Integer
from ..common.debugging import get_printable_graph
from ..common.fhe_circuit import FHECircuit
from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
from ..common.mlir.utils import (
check_graph_values_compatibility_with_mlir,
extend_direct_lookup_tables,
is_graph_values_compatible_with_mlir,
update_bit_width_for_mlir,
)
from ..common.operator_graph import OPGraph
@@ -162,8 +163,12 @@ def _compile_numpy_function_into_op_graph_internal(
compilation_artifacts.add_operation_graph("final", op_graph)
# Make sure the graph can be lowered to MLIR
if not is_graph_values_compatible_with_mlir(op_graph):
raise RuntimeError("function you are trying to compile isn't supported for MLIR lowering")
offending_nodes = check_graph_values_compatibility_with_mlir(op_graph)
if offending_nodes is not None:
raise RuntimeError(
"function you are trying to compile isn't supported for MLIR lowering\n\n"
+ get_printable_graph(op_graph, show_data_types=True, highlighted_nodes=offending_nodes)
)
# Update bit_width for MLIR
update_bit_width_for_mlir(op_graph)

View File

@@ -560,29 +560,50 @@ def test_compile_function_with_direct_tlu_overflow():
@pytest.mark.parametrize(
"function,input_ranges,list_of_arg_names",
"function,parameters,inputset,match",
[
pytest.param(lambda x: x - 10, ((-5, 5),), ["x"]),
pytest.param(
lambda x: 1 - x,
{"x": EncryptedScalar(Integer(3, is_signed=False))},
[(i,) for i in range(8)],
(
"function you are trying to compile isn't supported for MLIR lowering\n"
"\n"
"%0 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>\n" # noqa: E501
"%1 = x # EncryptedScalar<Integer<unsigned, 3 bits>>\n" # noqa: E501
"%2 = Sub(%0, %1) # EncryptedScalar<Integer<signed, 4 bits>>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ signed integer outputs aren't supported\n" # noqa: E501
"return(%2)\n"
),
),
pytest.param(
lambda x: x + 1,
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2, 2))},
[(numpy.random.randint(0, 8, size=(2, 2)),) for i in range(10)],
(
"function you are trying to compile isn't supported for MLIR lowering\n"
"\n"
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>\n" # noqa: E501
"%1 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>\n" # noqa: E501
"%2 = Add(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ non scalar outputs aren't supported\n" # noqa: E501
"return(%2)\n"
),
),
],
)
def test_fail_compile(function, input_ranges, list_of_arg_names):
def test_fail_compile(function, parameters, inputset, match):
"""Test function compile_numpy_function_into_op_graph for a program with signed values"""
def data_gen(args):
for prod in itertools.product(*args):
yield prod
function_parameters = {
arg_name: EncryptedScalar(Integer(64, True)) for arg_name in list_of_arg_names
}
with pytest.raises(RuntimeError, match=".*isn't supported for MLIR lowering.*"):
try:
compile_numpy_function(
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
parameters,
inputset,
CompilationConfiguration(dump_artifacts_on_unexpected_failures=False),
)
except RuntimeError as error:
assert str(error) == match
def test_small_inputset():