fix: check widths are supported by concrete-lib

and if not, explain to the user
refs #139
This commit is contained in:
Benoit Chevallier-Mames
2021-09-16 17:42:45 +02:00
committed by Benoit Chevallier
parent 25d40a4348
commit 6a83b01e92
2 changed files with 71 additions and 5 deletions

View File

@@ -9,9 +9,13 @@ from ..data_types.dtypes_helpers import (
value_is_encrypted_tensor_integer,
value_is_scalar_integer,
)
from ..debugging.custom_assert import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import ArbitraryFunction
# TODO: should come from compiler, through an API, #402
ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 7
def is_graph_values_compatible_with_mlir(op_graph: OPGraph) -> bool:
"""Make sure the graph outputs are unsigned integers, which is what the compiler supports.
@@ -55,16 +59,36 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
op_graph: graph to update bit_width for
"""
max_bit_width = 0
offending_list = []
for node in op_graph.graph.nodes:
for value_out in node.outputs:
if value_is_clear_scalar_integer(value_out) or value_is_clear_tensor_integer(value_out):
max_bit_width = max(max_bit_width, value_out.data_type.bit_width - 1)
elif value_is_encrypted_scalar_integer(value_out) or value_is_encrypted_tensor_integer(
value_out
):
max_bit_width = max(max_bit_width, value_out.data_type.bit_width)
current_node_out_bit_width = value_out.data_type.bit_width - 1
else:
assert_true(
value_is_encrypted_scalar_integer(value_out)
or value_is_encrypted_tensor_integer(value_out)
)
current_node_out_bit_width = value_out.data_type.bit_width
max_bit_width = max(max_bit_width, current_node_out_bit_width)
# Check that current_node_out_bit_width is supported by the compiler
if current_node_out_bit_width > ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB:
offending_list.append((node, current_node_out_bit_width))
_set_all_bit_width(op_graph, max_bit_width)
# Check that the max_bit_width is supported by the compiler
if len(offending_list) != 0:
raise RuntimeError(
f"max_bit_width of some nodes is too high for the current version of "
f"the compiler (maximum must be {ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB} "
f"which is not compatible with {offending_list})"
)
def extend_direct_lookup_tables(op_graph: OPGraph):
"""Extend direct lookup tables to the maximum length the input bit width can support.

View File

@@ -330,3 +330,45 @@ def test_compile_with_show_mlir(function, input_ranges, list_of_arg_names):
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
show_mlir=True,
)
def test_compile_too_high_bitwidth():
"""Check that the check of maximal bitwidth of intermediate data works fine."""
def function(x, y):
return x + y
def data_gen(args):
for prod in itertools.product(*args):
yield prod
function_parameters = {
"x": EncryptedScalar(Integer(64, False)),
"y": EncryptedScalar(Integer(64, False)),
}
# A bit too much
input_ranges = [(0, 100), (0, 28)]
with pytest.raises(RuntimeError) as excinfo:
compile_numpy_function(
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
)
assert (
"max_bit_width of some nodes is too high for the current version of the "
"compiler (maximum must be 7 which is not compatible with" in str(excinfo.value)
)
assert str(excinfo.value).endswith(", 8)])")
# Just ok
input_ranges = [(0, 99), (0, 28)]
compile_numpy_function(
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
)