feat: check inputset size

Umut
2021-09-20 12:35:46 +03:00
parent f4d7cab359
commit 441c4f9e7d
10 changed files with 68 additions and 25 deletions

View File

@@ -12,7 +12,7 @@ def eval_op_graph_bounds_on_inputset(
     inputset: Iterable[Tuple[Any, ...]],
     min_func: Callable[[Any, Any], Any] = min,
     max_func: Callable[[Any, Any], Any] = max,
-) -> Dict[IntermediateNode, Dict[str, Any]]:
+) -> Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]:
     """Evaluate the bounds with an inputset.

     Evaluate the bounds for all output values of the operators in the graph op_graph over data
@@ -31,8 +31,9 @@ def eval_op_graph_bounds_on_inputset(
             tensors). Defaults to max.

     Returns:
-        Dict[IntermediateNode, Dict[str, Any]]: dict containing the bounds for each node from
-            op_graph, stored with the node as key and a dict with keys "min" and "max" as value.
+        Tuple[int, Dict[IntermediateNode, Dict[str, Any]]]: number of inputs in the inputset and
+            a dict containing the bounds for each node from op_graph, stored with the node
+            as key and a dict with keys "min" and "max" as value.
     """

     def check_inputset_input_len_is_valid(data_to_check):
@@ -48,9 +49,12 @@ def eval_op_graph_bounds_on_inputset(
     # TODO: do we want to check coherence between the input data type and the corresponding Input
     # IR node expected data type? Not considering bit_width as they may not make sense at this stage
+    inputset_size = 0
     inputset_iterator = iter(inputset)

     first_input_data = dict(enumerate(next(inputset_iterator)))
+    inputset_size += 1

     check_inputset_input_len_is_valid(first_input_data.values())
     first_output = op_graph.evaluate(first_input_data)
@@ -62,6 +66,7 @@ def eval_op_graph_bounds_on_inputset(
     }

     for input_data in inputset_iterator:
+        inputset_size += 1
         current_input_data = dict(enumerate(input_data))
         check_inputset_input_len_is_valid(current_input_data.values())
         current_output = op_graph.evaluate(current_input_data)
@@ -69,4 +74,4 @@ def eval_op_graph_bounds_on_inputset(
             node_bounds[node]["min"] = min_func(node_bounds[node]["min"], value)
             node_bounds[node]["max"] = max_func(node_bounds[node]["max"], value)

-    return node_bounds
+    return inputset_size, node_bounds
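The counter is threaded through the single pass that already evaluates the bounds because the inputset may be a one-shot iterator that cannot be traversed twice (note the explicit `iter(inputset)` call). A minimal sketch of the pattern, with the hypothetical `evaluate_one` standing in for the per-input bounds update:

```python
from typing import Any, Callable, Iterable, Tuple

def count_while_evaluating(
    inputset: Iterable[Tuple[Any, ...]],
    evaluate_one: Callable[[Tuple[Any, ...]], None],
) -> int:
    """Consume the inputset exactly once, evaluating and counting each input."""
    inputset_size = 0
    for input_data in inputset:
        inputset_size += 1  # count in the same pass; a generator cannot be replayed
        evaluate_one(input_data)
    return inputset_size
```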

View File

@@ -108,7 +108,10 @@ def extend_direct_lookup_tables(op_graph: OPGraph):
         table = node.op_kwargs["table"]
         bit_width = cast(Integer, node.inputs[0].data_type).bit_width
         expected_length = 2 ** bit_width
-        if len(table) > expected_length:
+        # TODO: remove no cover once the table length workaround is removed
+        # (https://github.com/zama-ai/concretefhe-internal/issues/359)
+        if len(table) > expected_length:  # pragma: no cover
             node.op_kwargs["table"] = table[:expected_length]
         else:
             repeat = expected_length // len(table)
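The rule being enforced here: a direct lookup table for an input of `bit_width` bits must have exactly `2 ** bit_width` entries, since that is how many distinct values the input can take. A standalone illustration of the truncation branch (not the library's code):

```python
# A 3-bit encrypted input can take 2**3 == 8 distinct values, so its lookup
# table needs exactly 8 entries; an over-long table is truncated to fit.
bit_width = 3
expected_length = 2 ** bit_width
table = list(range(10))  # 10 entries: 2 too many for a 3-bit input
if len(table) > expected_length:
    table = table[:expected_length]
assert table == [0, 1, 2, 3, 4, 5, 6, 7]
```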

View File

@@ -1,5 +1,6 @@
"""numpy compilation function."""
import sys
import traceback
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
@@ -9,6 +10,7 @@ from zamalang import CompilerEngine
 from ..common.bounds_measurement.inputset_eval import eval_op_graph_bounds_on_inputset
 from ..common.common_helpers import check_op_graph_is_integer_program
 from ..common.compilation import CompilationArtifacts, CompilationConfiguration
+from ..common.data_types import Integer
 from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
 from ..common.mlir.utils import (
     extend_direct_lookup_tables,
@@ -107,13 +109,36 @@ def _compile_numpy_function_into_op_graph_internal(
     )

     # Find bounds with the inputset
-    node_bounds = eval_op_graph_bounds_on_inputset(
+    inputset_size, node_bounds = eval_op_graph_bounds_on_inputset(
         op_graph,
         inputset,
         min_func=numpy_min_func,
         max_func=numpy_max_func,
     )

+    # Check inputset size
+    inputset_size_upper_limit = 1
+
+    # this loop determines the number of possible inputs of the function;
+    # if a function has a single 3-bit input, for example, `inputset_size_upper_limit` will be 8
+    for parameter_value in function_parameters.values():
+        if isinstance(parameter_value.data_type, Integer):
+            # the parameter bit-widths are multiplied, as the parameters are combined into an input
+            inputset_size_upper_limit *= 2 ** parameter_value.data_type.bit_width
+
+            # if the upper limit of the inputset size goes above 10,
+            # break out of the loop, as we will require at least 10 inputs in this case
+            if inputset_size_upper_limit > 10:
+                break
+
+    minimum_required_inputset_size = min(inputset_size_upper_limit, 10)
+    if inputset_size < minimum_required_inputset_size:
+        sys.stderr.write(
+            f"Provided inputset contains too few inputs "
+            f"(it should have had at least {minimum_required_inputset_size} "
+            f"but it only had {inputset_size})\n"
+        )

     # Add the bounds as an artifact
     compilation_artifacts.add_final_operation_graph_bounds(node_bounds)
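To make the arithmetic of the new check concrete: a function with a 3-bit and a 2-bit encrypted parameter has `2**3 * 2**2 == 32` possible input combinations, so at least `min(32, 10) == 10` inputs are required, while a single 2-bit parameter admits only 4. A standalone sketch of the same logic (hypothetical helper, not part of the commit):

```python
def minimum_required_inputset_size(parameter_bit_widths, cap=10):
    """Multiply 2**bit_width over the integer parameters, capped at 10."""
    upper_limit = 1
    for bit_width in parameter_bit_widths:
        upper_limit *= 2 ** bit_width
        if upper_limit > cap:  # no point multiplying further past the cap
            break
    return min(upper_limit, cap)

assert minimum_required_inputset_size([3, 2]) == 10  # 32 possible inputs, capped
assert minimum_required_inputset_size([2]) == 4      # only 4 possible inputs
```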

View File

@@ -28,7 +28,7 @@ y = hnp.EncryptedScalar(hnp.UnsignedInteger(3))

 In this configuration, both `x` and `y` are 3-bit unsigned integers, so they have the range `[0, 2**3 - 1]`.

-We also need an inputset. This latter is not to be confused with the dataset, which is used in training and contains labels. It is to determine the bit-widths of the intermediate results so only the inputs are necessary. It should be an iterable yielding tuples in the same order as the inputs of the function to compile.
+We also need an inputset. It is used to determine the bit-widths of the intermediate results, so only the inputs are necessary. It should be an iterable yielding tuples in the same order as the inputs of the function to compile. There should be at least 10 inputs in the inputset to avoid a warning (except for functions with fewer than 10 possible inputs). The warning exists because the bigger the inputset, the better the measured bounds will be.

 ```python
 inputset = [(2, 3), (0, 0), (1, 6), (7, 7), (7, 1)]
 ```
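Since `x` and `y` are 3-bit, the function has `2**3 * 2**3 == 64` possible input pairs, well above 10, so the new check expects at least 10 tuples. One hypothetical way to build a conforming inputset:

```python
import random

# ten random input pairs for two 3-bit arguments (values 0 through 7),
# enough to satisfy the new minimum inputset size
inputset = [(random.randint(0, 7), random.randint(0, 7)) for _ in range(10)]
```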

View File

@@ -281,7 +281,7 @@ def test_eval_op_graph_bounds_on_inputset_multiple_output(
         for y_gen in range_y:
             yield (x_gen, y_gen)

-    node_bounds = eval_op_graph_bounds_on_inputset(
+    _, node_bounds = eval_op_graph_bounds_on_inputset(
         op_graph, data_gen(*tuple(range(x[0], x[1] + 1) for x in input_ranges))
     )

View File

@@ -22,7 +22,7 @@ def test_artifacts_export():
     compile_numpy_function(
         function,
         {"x": EncryptedScalar(UnsignedInteger(7))},
-        [(0,), (1,), (2,)],
+        [(i,) for i in range(10)],
         compilation_artifacts=artifacts,
     )

View File

@@ -49,7 +49,7 @@ def test_enable_topological_optimizations(test_helpers, function_to_trace, fused
             param: EncryptedScalar(Integer(32, is_signed=False))
             for param in signature(function_to_trace).parameters.keys()
         },
-        [(1,), (2,), (3,)],
+        [(i,) for i in range(10)],
         CompilationConfiguration(dump_artifacts_on_unexpected_failures=False),
     )

     op_graph_not_optimized = compile_numpy_function_into_op_graph(
@@ -58,7 +58,7 @@ def test_enable_topological_optimizations(test_helpers, function_to_trace, fused
             param: EncryptedScalar(Integer(32, is_signed=False))
             for param in signature(function_to_trace).parameters.keys()
         },
-        [(1,), (2,), (3,)],
+        [(i,) for i in range(10)],
         CompilationConfiguration(
             dump_artifacts_on_unexpected_failures=False,
             enable_topological_optimizations=False,

View File

@@ -18,7 +18,7 @@ def test_draw_graph_with_saving():
     op_graph = compile_numpy_function_into_op_graph(
         function,
         {"x": EncryptedScalar(Integer(7, True))},
-        [(-2,), (-1,), (0,), (1,), (2,)],
+        [(i,) for i in range(-5, 5)],
     )

     with tempfile.TemporaryDirectory() as tmp:

View File

@@ -33,7 +33,7 @@ def sub(x, y):

 def constant_sub(x):
     """Test constant sub"""
-    return 8 - x
+    return 12 - x


 def mul(x, y):
@@ -108,7 +108,7 @@ def datagen(*args):
         {
             "x": EncryptedScalar(Integer(64, is_signed=False)),
         },
-        (range(0, 8),),
+        (range(0, 10),),
     ),
     (
         add,
@@ -139,7 +139,7 @@ def datagen(*args):
         {
             "x": EncryptedScalar(Integer(64, is_signed=False)),
         },
-        (range(0, 5),),
+        (range(0, 10),),
     ),
     (
         mul,
@@ -154,7 +154,7 @@ def datagen(*args):
         {
             "x": EncryptedScalar(Integer(64, is_signed=False)),
         },
-        (range(0, 8),),
+        (range(0, 10),),
     ),
     (
         mul,
@@ -194,7 +194,7 @@ def datagen(*args):
     (
         lut,
         {
-            "x": EncryptedScalar(Integer(64, is_signed=False)),
+            "x": EncryptedScalar(Integer(3, is_signed=False)),
         },
         (range(0, 8),),
     ),
@@ -209,7 +209,7 @@ def datagen(*args):
     (
         lut_less_bits_than_table_length,
         {
-            "x": EncryptedScalar(Integer(64, is_signed=False)),
+            "x": EncryptedScalar(Integer(3, is_signed=False)),
         },
         (range(0, 8),),
     ),
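The LUT entries above narrow the parameter from 64 bits to 3 bits so the existing 8-element ranges remain valid: with only `2**3 == 8` possible inputs, the required minimum drops to `min(8, 10) == 8`. A quick sanity check of that reasoning:

```python
assert len(range(0, 8)) >= min(2 ** 3, 10)  # 8 inputs cover a 3-bit parameter
```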

View File

@@ -44,9 +44,9 @@ def small_fused_table(x):
 @pytest.mark.parametrize(
     "function,input_ranges,list_of_arg_names",
     [
-        pytest.param(lambda x: x + 42, ((-2, 2),), ["x"]),
+        pytest.param(lambda x: x + 42, ((-5, 5),), ["x"]),
         pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]),
-        pytest.param(lambda x, y: (x + 1, y + 10), ((-1, 1), (3, 4)), ["x", "y"]),
+        pytest.param(lambda x, y: (x + 1, y + 10), ((-1, 1), (3, 8)), ["x", "y"]),
         pytest.param(
             lambda x, y, z: (x + y + 1 - z, x * y + 42, z, z + 99),
             ((4, 8), (3, 4), (0, 4)),
@@ -89,10 +89,10 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
 @pytest.mark.parametrize(
     "function,input_ranges,list_of_arg_names",
     [
-        pytest.param(lambda x: x + 42, ((0, 2),), ["x"]),
-        pytest.param(lambda x: x + numpy.int32(42), ((0, 2),), ["x"]),
-        pytest.param(lambda x: x * 2, ((0, 2),), ["x"]),
-        pytest.param(lambda x: 8 - x, ((0, 2),), ["x"]),
+        pytest.param(lambda x: x + 42, ((0, 10),), ["x"]),
+        pytest.param(lambda x: x + numpy.int32(42), ((0, 10),), ["x"]),
+        pytest.param(lambda x: x * 2, ((0, 10),), ["x"]),
+        pytest.param(lambda x: 12 - x, ((0, 10),), ["x"]),
         pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]),
         pytest.param(lut, ((0, 127),), ["x"]),
         pytest.param(small_lut, ((0, 31),), ["x"]),
@@ -240,7 +240,7 @@ def test_compile_function_with_direct_tlu_overflow():
 @pytest.mark.parametrize(
     "function,input_ranges,list_of_arg_names",
     [
-        pytest.param(lambda x: x - 10, ((-2, 2),), ["x"]),
+        pytest.param(lambda x: x - 10, ((-5, 5),), ["x"]),
     ],
 )
 def test_fail_compile(function, input_ranges, list_of_arg_names):
@@ -263,6 +263,16 @@ def test_fail_compile(function, input_ranges, list_of_arg_names):
     )

+def test_small_inputset():
+    """Test function compile_numpy_function_into_op_graph with an unacceptably small inputset"""
+    compile_numpy_function_into_op_graph(
+        lambda x: x + 42,
+        {"x": EncryptedScalar(Integer(5, is_signed=False))},
+        [(0,), (3,)],
+        CompilationConfiguration(dump_artifacts_on_unexpected_failures=False),
+    )
+
+
 @pytest.mark.parametrize(
     "function,params,shape,ref_graph_str",
     [
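The new `test_small_inputset` drives the warning path (a 5-bit input admits 32 values, so the 2 provided inputs fall below the minimum of 10) but does not assert on the emitted message. A hypothetical variant, not part of this commit, that also checks the stderr output via pytest's built-in `capsys` fixture:

```python
def test_small_inputset_warns(capsys):
    """Hypothetical follow-up test asserting the warning text is emitted."""
    compile_numpy_function_into_op_graph(
        lambda x: x + 42,
        {"x": EncryptedScalar(Integer(5, is_signed=False))},
        [(0,), (3,)],
        CompilationConfiguration(dump_artifacts_on_unexpected_failures=False),
    )
    captured = capsys.readouterr()
    assert "at least 10" in captured.err
    assert "only had 2" in captured.err
```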