mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: check inputset size
This commit is contained in:
@@ -12,7 +12,7 @@ def eval_op_graph_bounds_on_inputset(
|
||||
inputset: Iterable[Tuple[Any, ...]],
|
||||
min_func: Callable[[Any, Any], Any] = min,
|
||||
max_func: Callable[[Any, Any], Any] = max,
|
||||
) -> Dict[IntermediateNode, Dict[str, Any]]:
|
||||
) -> Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]:
|
||||
"""Evaluate the bounds with a inputset.
|
||||
|
||||
Evaluate the bounds for all output values of the operators in the graph op_graph over data
|
||||
@@ -31,8 +31,9 @@ def eval_op_graph_bounds_on_inputset(
|
||||
tensors). Defaults to max.
|
||||
|
||||
Returns:
|
||||
Dict[IntermediateNode, Dict[str, Any]]: dict containing the bounds for each node from
|
||||
op_graph, stored with the node as key and a dict with keys "min" and "max" as value.
|
||||
Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]: number of inputs in the inputset and
|
||||
a dict containing the bounds for each node from op_graph, stored with the node
|
||||
as key and a dict with keys "min" and "max" as value.
|
||||
"""
|
||||
|
||||
def check_inputset_input_len_is_valid(data_to_check):
|
||||
@@ -48,9 +49,12 @@ def eval_op_graph_bounds_on_inputset(
|
||||
# TODO: do we want to check coherence between the input data type and the corresponding Input ir
|
||||
# node expected data type ? Not considering bit_width as they may not make sense at this stage
|
||||
|
||||
inputset_size = 0
|
||||
inputset_iterator = iter(inputset)
|
||||
|
||||
first_input_data = dict(enumerate(next(inputset_iterator)))
|
||||
inputset_size += 1
|
||||
|
||||
check_inputset_input_len_is_valid(first_input_data.values())
|
||||
first_output = op_graph.evaluate(first_input_data)
|
||||
|
||||
@@ -62,6 +66,7 @@ def eval_op_graph_bounds_on_inputset(
|
||||
}
|
||||
|
||||
for input_data in inputset_iterator:
|
||||
inputset_size += 1
|
||||
current_input_data = dict(enumerate(input_data))
|
||||
check_inputset_input_len_is_valid(current_input_data.values())
|
||||
current_output = op_graph.evaluate(current_input_data)
|
||||
@@ -69,4 +74,4 @@ def eval_op_graph_bounds_on_inputset(
|
||||
node_bounds[node]["min"] = min_func(node_bounds[node]["min"], value)
|
||||
node_bounds[node]["max"] = max_func(node_bounds[node]["max"], value)
|
||||
|
||||
return node_bounds
|
||||
return inputset_size, node_bounds
|
||||
|
||||
@@ -108,7 +108,10 @@ def extend_direct_lookup_tables(op_graph: OPGraph):
|
||||
table = node.op_kwargs["table"]
|
||||
bit_width = cast(Integer, node.inputs[0].data_type).bit_width
|
||||
expected_length = 2 ** bit_width
|
||||
if len(table) > expected_length:
|
||||
|
||||
# TODO: remove no cover once the table length workaround is removed
|
||||
# (https://github.com/zama-ai/concretefhe-internal/issues/359)
|
||||
if len(table) > expected_length: # pragma: no cover
|
||||
node.op_kwargs["table"] = table[:expected_length]
|
||||
else:
|
||||
repeat = expected_length // len(table)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""numpy compilation function."""
|
||||
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
@@ -9,6 +10,7 @@ from zamalang import CompilerEngine
|
||||
from ..common.bounds_measurement.inputset_eval import eval_op_graph_bounds_on_inputset
|
||||
from ..common.common_helpers import check_op_graph_is_integer_program
|
||||
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
|
||||
from ..common.data_types import Integer
|
||||
from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
|
||||
from ..common.mlir.utils import (
|
||||
extend_direct_lookup_tables,
|
||||
@@ -107,13 +109,36 @@ def _compile_numpy_function_into_op_graph_internal(
|
||||
)
|
||||
|
||||
# Find bounds with the inputset
|
||||
node_bounds = eval_op_graph_bounds_on_inputset(
|
||||
inputset_size, node_bounds = eval_op_graph_bounds_on_inputset(
|
||||
op_graph,
|
||||
inputset,
|
||||
min_func=numpy_min_func,
|
||||
max_func=numpy_max_func,
|
||||
)
|
||||
|
||||
# Check inputset size
|
||||
inputset_size_upper_limit = 1
|
||||
|
||||
# this loop will determine the number of possible inputs of the function
|
||||
# if a function have a single 3-bit input, for example, `inputset_size_upper_limit` will be 8
|
||||
for parameter_value in function_parameters.values():
|
||||
if isinstance(parameter_value.data_type, Integer):
|
||||
# multiple parameter bit-widths are multiplied as they can be combined into an input
|
||||
inputset_size_upper_limit *= 2 ** parameter_value.data_type.bit_width
|
||||
|
||||
# if the upper limit of the inputset size goes above 10,
|
||||
# break the loop as we will require at least 10 inputs in this case
|
||||
if inputset_size_upper_limit > 10:
|
||||
break
|
||||
|
||||
minimum_required_inputset_size = min(inputset_size_upper_limit, 10)
|
||||
if inputset_size < minimum_required_inputset_size:
|
||||
sys.stderr.write(
|
||||
f"Provided inputset contains too few inputs "
|
||||
f"(it should have had at least {minimum_required_inputset_size} "
|
||||
f"but it only had {inputset_size})\n"
|
||||
)
|
||||
|
||||
# Add the bounds as an artifact
|
||||
compilation_artifacts.add_final_operation_graph_bounds(node_bounds)
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ y = hnp.EncryptedScalar(hnp.UnsignedInteger(3))
|
||||
|
||||
In this configuration, both `x` and `y` are 3-bit unsigned integers, so they have the range of `[0, 2**3 - 1]`
|
||||
|
||||
We also need an inputset. This latter is not to be confused with the dataset, which is used in training and contains labels. It is to determine the bit-widths of the intermediate results so only the inputs are necessary. It should be an iterable yielding tuples in the same order as the inputs of the function to compile.
|
||||
We also need an inputset. It is to determine the bit-widths of the intermediate results. It should be an iterable yielding tuples in the same order as the inputs of the function to compile. There should be at least 10 inputs in the input set to avoid warnings (except for functions with less than 10 possible inputs). The warning is there because the bigger the input set, the better the bounds will be.
|
||||
|
||||
```python
|
||||
inputset = [(2, 3), (0, 0), (1, 6), (7, 7), (7, 1)]
|
||||
|
||||
@@ -281,7 +281,7 @@ def test_eval_op_graph_bounds_on_inputset_multiple_output(
|
||||
for y_gen in range_y:
|
||||
yield (x_gen, y_gen)
|
||||
|
||||
node_bounds = eval_op_graph_bounds_on_inputset(
|
||||
_, node_bounds = eval_op_graph_bounds_on_inputset(
|
||||
op_graph, data_gen(*tuple(range(x[0], x[1] + 1) for x in input_ranges))
|
||||
)
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ def test_artifacts_export():
|
||||
compile_numpy_function(
|
||||
function,
|
||||
{"x": EncryptedScalar(UnsignedInteger(7))},
|
||||
[(0,), (1,), (2,)],
|
||||
[(i,) for i in range(10)],
|
||||
compilation_artifacts=artifacts,
|
||||
)
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ def test_enable_topological_optimizations(test_helpers, function_to_trace, fused
|
||||
param: EncryptedScalar(Integer(32, is_signed=False))
|
||||
for param in signature(function_to_trace).parameters.keys()
|
||||
},
|
||||
[(1,), (2,), (3,)],
|
||||
[(i,) for i in range(10)],
|
||||
CompilationConfiguration(dump_artifacts_on_unexpected_failures=False),
|
||||
)
|
||||
op_graph_not_optimized = compile_numpy_function_into_op_graph(
|
||||
@@ -58,7 +58,7 @@ def test_enable_topological_optimizations(test_helpers, function_to_trace, fused
|
||||
param: EncryptedScalar(Integer(32, is_signed=False))
|
||||
for param in signature(function_to_trace).parameters.keys()
|
||||
},
|
||||
[(1,), (2,), (3,)],
|
||||
[(i,) for i in range(10)],
|
||||
CompilationConfiguration(
|
||||
dump_artifacts_on_unexpected_failures=False,
|
||||
enable_topological_optimizations=False,
|
||||
|
||||
@@ -18,7 +18,7 @@ def test_draw_graph_with_saving():
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(7, True))},
|
||||
[(-2,), (-1,), (0,), (1,), (2,)],
|
||||
[(i,) for i in range(-5, 5)],
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
|
||||
@@ -33,7 +33,7 @@ def sub(x, y):
|
||||
|
||||
def constant_sub(x):
|
||||
"""Test constant sub"""
|
||||
return 8 - x
|
||||
return 12 - x
|
||||
|
||||
|
||||
def mul(x, y):
|
||||
@@ -108,7 +108,7 @@ def datagen(*args):
|
||||
{
|
||||
"x": EncryptedScalar(Integer(64, is_signed=False)),
|
||||
},
|
||||
(range(0, 8),),
|
||||
(range(0, 10),),
|
||||
),
|
||||
(
|
||||
add,
|
||||
@@ -139,7 +139,7 @@ def datagen(*args):
|
||||
{
|
||||
"x": EncryptedScalar(Integer(64, is_signed=False)),
|
||||
},
|
||||
(range(0, 5),),
|
||||
(range(0, 10),),
|
||||
),
|
||||
(
|
||||
mul,
|
||||
@@ -154,7 +154,7 @@ def datagen(*args):
|
||||
{
|
||||
"x": EncryptedScalar(Integer(64, is_signed=False)),
|
||||
},
|
||||
(range(0, 8),),
|
||||
(range(0, 10),),
|
||||
),
|
||||
(
|
||||
mul,
|
||||
@@ -194,7 +194,7 @@ def datagen(*args):
|
||||
(
|
||||
lut,
|
||||
{
|
||||
"x": EncryptedScalar(Integer(64, is_signed=False)),
|
||||
"x": EncryptedScalar(Integer(3, is_signed=False)),
|
||||
},
|
||||
(range(0, 8),),
|
||||
),
|
||||
@@ -209,7 +209,7 @@ def datagen(*args):
|
||||
(
|
||||
lut_less_bits_than_table_length,
|
||||
{
|
||||
"x": EncryptedScalar(Integer(64, is_signed=False)),
|
||||
"x": EncryptedScalar(Integer(3, is_signed=False)),
|
||||
},
|
||||
(range(0, 8),),
|
||||
),
|
||||
|
||||
@@ -44,9 +44,9 @@ def small_fused_table(x):
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
[
|
||||
pytest.param(lambda x: x + 42, ((-2, 2),), ["x"]),
|
||||
pytest.param(lambda x: x + 42, ((-5, 5),), ["x"]),
|
||||
pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]),
|
||||
pytest.param(lambda x, y: (x + 1, y + 10), ((-1, 1), (3, 4)), ["x", "y"]),
|
||||
pytest.param(lambda x, y: (x + 1, y + 10), ((-1, 1), (3, 8)), ["x", "y"]),
|
||||
pytest.param(
|
||||
lambda x, y, z: (x + y + 1 - z, x * y + 42, z, z + 99),
|
||||
((4, 8), (3, 4), (0, 4)),
|
||||
@@ -89,10 +89,10 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
[
|
||||
pytest.param(lambda x: x + 42, ((0, 2),), ["x"]),
|
||||
pytest.param(lambda x: x + numpy.int32(42), ((0, 2),), ["x"]),
|
||||
pytest.param(lambda x: x * 2, ((0, 2),), ["x"]),
|
||||
pytest.param(lambda x: 8 - x, ((0, 2),), ["x"]),
|
||||
pytest.param(lambda x: x + 42, ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x: x + numpy.int32(42), ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x: x * 2, ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x: 12 - x, ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]),
|
||||
pytest.param(lut, ((0, 127),), ["x"]),
|
||||
pytest.param(small_lut, ((0, 31),), ["x"]),
|
||||
@@ -240,7 +240,7 @@ def test_compile_function_with_direct_tlu_overflow():
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
[
|
||||
pytest.param(lambda x: x - 10, ((-2, 2),), ["x"]),
|
||||
pytest.param(lambda x: x - 10, ((-5, 5),), ["x"]),
|
||||
],
|
||||
)
|
||||
def test_fail_compile(function, input_ranges, list_of_arg_names):
|
||||
@@ -263,6 +263,16 @@ def test_fail_compile(function, input_ranges, list_of_arg_names):
|
||||
)
|
||||
|
||||
|
||||
def test_small_inputset():
|
||||
"""Test function compile_numpy_function_into_op_graph with an unacceptably small inputset"""
|
||||
compile_numpy_function_into_op_graph(
|
||||
lambda x: x + 42,
|
||||
{"x": EncryptedScalar(Integer(5, is_signed=False))},
|
||||
[(0,), (3,)],
|
||||
CompilationConfiguration(dump_artifacts_on_unexpected_failures=False),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,params,shape,ref_graph_str",
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user