mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
fix: check widths are supported by concrete-lib
and if not, explain to the user refs #139
This commit is contained in:
committed by
Benoit Chevallier
parent
25d40a4348
commit
6a83b01e92
@@ -9,9 +9,13 @@ from ..data_types.dtypes_helpers import (
|
||||
value_is_encrypted_tensor_integer,
|
||||
value_is_scalar_integer,
|
||||
)
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import ArbitraryFunction
|
||||
|
||||
# TODO: should come from compiler, through an API, #402
|
||||
ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 7
|
||||
|
||||
|
||||
def is_graph_values_compatible_with_mlir(op_graph: OPGraph) -> bool:
|
||||
"""Make sure the graph outputs are unsigned integers, which is what the compiler supports.
|
||||
@@ -55,16 +59,36 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
|
||||
op_graph: graph to update bit_width for
|
||||
"""
|
||||
max_bit_width = 0
|
||||
offending_list = []
|
||||
for node in op_graph.graph.nodes:
|
||||
for value_out in node.outputs:
|
||||
if value_is_clear_scalar_integer(value_out) or value_is_clear_tensor_integer(value_out):
|
||||
max_bit_width = max(max_bit_width, value_out.data_type.bit_width - 1)
|
||||
elif value_is_encrypted_scalar_integer(value_out) or value_is_encrypted_tensor_integer(
|
||||
value_out
|
||||
):
|
||||
max_bit_width = max(max_bit_width, value_out.data_type.bit_width)
|
||||
current_node_out_bit_width = value_out.data_type.bit_width - 1
|
||||
else:
|
||||
|
||||
assert_true(
|
||||
value_is_encrypted_scalar_integer(value_out)
|
||||
or value_is_encrypted_tensor_integer(value_out)
|
||||
)
|
||||
|
||||
current_node_out_bit_width = value_out.data_type.bit_width
|
||||
|
||||
max_bit_width = max(max_bit_width, current_node_out_bit_width)
|
||||
|
||||
# Check that current_node_out_bit_width is supported by the compiler
|
||||
if current_node_out_bit_width > ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB:
|
||||
offending_list.append((node, current_node_out_bit_width))
|
||||
|
||||
_set_all_bit_width(op_graph, max_bit_width)
|
||||
|
||||
# Check that the max_bit_width is supported by the compiler
|
||||
if len(offending_list) != 0:
|
||||
raise RuntimeError(
|
||||
f"max_bit_width of some nodes is too high for the current version of "
|
||||
f"the compiler (maximum must be {ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB} "
|
||||
f"which is not compatible with {offending_list})"
|
||||
)
|
||||
|
||||
|
||||
def extend_direct_lookup_tables(op_graph: OPGraph):
|
||||
"""Extend direct lookup tables to the maximum length the input bit width can support.
|
||||
|
||||
@@ -330,3 +330,45 @@ def test_compile_with_show_mlir(function, input_ranges, list_of_arg_names):
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
show_mlir=True,
|
||||
)
|
||||
|
||||
|
||||
def test_compile_too_high_bitwidth():
|
||||
"""Check that the check of maximal bitwidth of intermediate data works fine."""
|
||||
|
||||
def function(x, y):
|
||||
return x + y
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {
|
||||
"x": EncryptedScalar(Integer(64, False)),
|
||||
"y": EncryptedScalar(Integer(64, False)),
|
||||
}
|
||||
|
||||
# A bit too much
|
||||
input_ranges = [(0, 100), (0, 28)]
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
compile_numpy_function(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
)
|
||||
|
||||
assert (
|
||||
"max_bit_width of some nodes is too high for the current version of the "
|
||||
"compiler (maximum must be 7 which is not compatible with" in str(excinfo.value)
|
||||
)
|
||||
|
||||
assert str(excinfo.value).endswith(", 8)])")
|
||||
|
||||
# Just ok
|
||||
input_ranges = [(0, 99), (0, 28)]
|
||||
|
||||
compile_numpy_function(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user