diff --git a/concrete/common/optimization/topological.py b/concrete/common/optimization/topological.py index 77bd5f947..e83af0ae1 100644 --- a/concrete/common/optimization/topological.py +++ b/concrete/common/optimization/topological.py @@ -1,13 +1,14 @@ """File holding topological optimization/simplification code.""" +import itertools from copy import deepcopy -from typing import Dict, List, Optional, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple, cast import networkx as nx from ..compilation.artifacts import CompilationArtifacts from ..data_types.floats import Float from ..data_types.integers import Integer -from ..debugging.custom_assert import custom_assert +from ..debugging.custom_assert import assert_true, custom_assert from ..operator_graph import OPGraph from ..representation.intermediate import ArbitraryFunction, Constant, Input, IntermediateNode from ..values import TensorValue @@ -39,10 +40,6 @@ def fuse_float_operations( float_subgraph_start_nodes, terminal_node, subgraph_all_nodes = float_subgraph_search_result processed_terminal_nodes.add(terminal_node) - # TODO: #199 To be removed when doing tensor management - if not subgraph_is_scalar_only(subgraph_all_nodes): - continue - subgraph_conversion_result = convert_float_subgraph_to_fused_node( op_graph, float_subgraph_start_nodes, @@ -111,16 +108,20 @@ def convert_float_subgraph_to_fused_node( output must be plugged as the input to the subgraph. """ - if not subgraph_has_unique_variable_input(float_subgraph_start_nodes): + subgraph_can_be_fused = subgraph_has_unique_variable_input( + float_subgraph_start_nodes + ) and subgraph_values_allow_fusing(float_subgraph_start_nodes, subgraph_all_nodes) + + if not subgraph_can_be_fused: return None # Only one variable input node, find which node feeds its input - non_constant_start_nodes = [ + variable_input_nodes = [ node for node in float_subgraph_start_nodes if not isinstance(node, Constant) ] - custom_assert(len(non_constant_start_nodes) == 1) + custom_assert(len(variable_input_nodes) == 1) - current_subgraph_variable_input = non_constant_start_nodes[0] + current_subgraph_variable_input = variable_input_nodes[0] new_input_value = deepcopy(current_subgraph_variable_input.outputs[0]) nx_graph = op_graph.graph @@ -244,20 +245,89 @@ def find_float_subgraph_with_unique_terminal_node( return float_subgraph_start_nodes, terminal_node, subgraph_all_nodes -# TODO: #199 To be removed when doing tensor management -def subgraph_is_scalar_only(subgraph_all_nodes: Set[IntermediateNode]) -> bool: - """Check subgraph only processes scalars. +def subgraph_values_allow_fusing( + float_subgraph_start_nodes: Set[IntermediateNode], + subgraph_all_nodes: Set[IntermediateNode], +): + """Check if a subgraph's values are compatible with fusing. + + A fused subgraph for example only works on an input tensor if the resulting ArbitraryFunction + can be applied per cell, hence shuffling or tensor shape changes make fusing impossible. Args: - subgraph_all_nodes (Set[IntermediateNode]): The nodes of the float subgraph. + float_subgraph_start_nodes (Set[IntermediateNode]): The nodes starting the float subgraph. + subgraph_all_nodes (Set[IntermediateNode]): All the nodes in the float subgraph. Returns: - bool: True if all inputs and outputs of the nodes in the subgraph are scalars. + bool: True if all inputs and outputs of the nodes in the subgraph are compatible with fusing + i.e. outputs have the same shapes equal to the variable input. 
""" - return all( - all(isinstance(input_, TensorValue) and input_.is_scalar for input_ in node.inputs) - and all(isinstance(output, TensorValue) and output.is_scalar for output in node.outputs) + + variable_input_nodes = [ + node for node in float_subgraph_start_nodes if not isinstance(node, Constant) + ] + + assert_true( + (num_variable_input_nodes := len(variable_input_nodes)) == 1, + f"{subgraph_values_allow_fusing.__name__} " + f"only works for subgraphs with 1 variable input node, got {num_variable_input_nodes}", + ) + + # Some ArbitraryFunction nodes have baked constants that need to be taken into account for the + # max size computation + baked_constants_ir_nodes = [ + baked_constant_base_value for node in subgraph_all_nodes + if isinstance(node, ArbitraryFunction) + if (baked_constant_base_value := node.op_attributes.get("baked_constant_ir_node", None)) + is not None + ] + + all_values_are_tensors = all( + all(isinstance(input_, TensorValue) for input_ in node.inputs) + and all(isinstance(output, TensorValue) for output in node.outputs) + for node in itertools.chain(subgraph_all_nodes, baked_constants_ir_nodes) + ) + + if not all_values_are_tensors: + # This cannot be reached today as scalars are Tensors with shape == () (numpy convention) + return False # pragma: no cover + + variable_input_node = variable_input_nodes[0] + + # A cheap check is that the variable input node must have the biggest size, i.e. have the most + # elements, meaning all constants will broadcast to its shape. This is because the + # ArbitraryFunction input and output must have the same shape so that it can be applied to each + # of the input tensor cells. + # There *may* be a way to manage the other case by simulating the broadcast of the smaller input + # array and then concatenating/stacking the results. This is not currently doable as we don't + # have a concatenate operator on the compiler side. + # TODO: #587 https://github.com/zama-ai/concretefhe-internal/issues/587 + + variable_input_node_output = cast(TensorValue, variable_input_node.outputs[0]) + variable_input_node_output_size, variable_input_node_output_shape = ( + variable_input_node_output.size, + variable_input_node_output.shape, + ) + max_inputs_size = max( + cast(TensorValue, input_node.outputs[0]).size + for input_node in itertools.chain(subgraph_all_nodes, baked_constants_ir_nodes) + ) + + if variable_input_node_output_size < max_inputs_size: + return False + + # Now that we know the variable input node has the biggest size we can check shapes are + # consistent throughout the subgraph: outputs of ir nodes that are not constant must be equal. + + non_constant_nodes = (node for node in subgraph_all_nodes if not isinstance(node, Constant)) + + return all( + all( + isinstance(output, TensorValue) and output.shape == variable_input_node_output_shape + for output in node.outputs + ) + for node in non_constant_nodes ) diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py index da66a20d5..e5d5328ab 100644 --- a/concrete/common/representation/intermediate.py +++ b/concrete/common/representation/intermediate.py @@ -202,9 +202,10 @@ class ArbitraryFunction(IntermediateNode): # The arbitrary_func is not optional but mypy has a long standing bug and is not able to # understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623 arbitrary_func: Optional[Callable] + op_name: str op_args: Tuple[Any, ...] 
op_kwargs: Dict[str, Any] - op_name: str + op_attributes: Dict[str, Any] _n_in: int = 1 def __init__( @@ -215,12 +216,14 @@ class ArbitraryFunction(IntermediateNode): op_name: Optional[str] = None, op_args: Optional[Tuple[Any, ...]] = None, op_kwargs: Optional[Dict[str, Any]] = None, + op_attributes: Optional[Dict[str, Any]] = None, ) -> None: super().__init__([input_base_value]) custom_assert(len(self.inputs) == 1) self.arbitrary_func = arbitrary_func self.op_args = op_args if op_args is not None else () self.op_kwargs = op_kwargs if op_kwargs is not None else {} + self.op_attributes = op_attributes if op_attributes is not None else {} output = deepcopy(input_base_value) output.dtype = output_dtype diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index d6e866bfe..1c62ace0a 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -208,6 +208,15 @@ class NPTracer(BaseTracer): op_kwargs = deepcopy(kwargs) op_kwargs["baked_constant"] = baked_constant + # Store info on the operation being treated + # Currently: the base value and type corresponding to the baked constant and which input idx + # it was feeding + op_attributes = { + "baked_constant_ir_node": deepcopy( + input_tracers[in_which_input_is_constant].traced_computation + ), + "in_which_input_is_constant": in_which_input_is_constant, + } traced_computation = ArbitraryFunction( input_base_value=input_tracers[in_which_input_is_variable].output, @@ -215,6 +224,7 @@ class NPTracer(BaseTracer): output_dtype=common_output_dtypes[0], op_kwargs=op_kwargs, op_name=binary_operator_string, + op_attributes=op_attributes, ) output_tracer = cls( (input_tracers[in_which_input_is_variable],), diff --git a/tests/common/optimization/test_float_fusing.py b/tests/common/optimization/test_float_fusing.py index a44234d2f..40826ddeb 100644 --- a/tests/common/optimization/test_float_fusing.py +++ b/tests/common/optimization/test_float_fusing.py @@ -7,6 +7,7 @@ import numpy import pytest from concrete.common.data_types.integers import Integer +from concrete.common.debugging.custom_assert import assert_not_reached from concrete.common.optimization.topological import fuse_float_operations from concrete.common.values import EncryptedScalar, EncryptedTensor from concrete.numpy import tracing @@ -134,17 +135,27 @@ def test_fuse_float_operations(function_to_trace, fused, input_): assert function_to_trace(*inputs) == op_graph(*inputs) -# TODO: #199 To be removed when doing tensor management -def test_tensor_no_fuse(): +def subtest_tensor_no_fuse(fun, tensor_shape): """Test case to verify float fusing is only applied on functions on scalars.""" - ndim = random.randint(1, 3) - tensor_shape = tuple(random.randint(1, 10) for _ in range(ndim + 1)) + if tensor_shape == (): + # We want tensors + return + + if fun in LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES: + # We need at least one input of the bivariate function to be float + return + + # Float fusing currently cannot work if the constant in a bivariate operator is bigger than the + # variable input. 
+ # Make a broadcastable shape but with the constant being bigger + variable_tensor_shape = (1,) + tensor_shape + constant_bigger_shape = (random.randint(2, 10),) + tensor_shape def tensor_no_fuse(x): intermediate = x.astype(numpy.float64) - intermediate = intermediate.astype(numpy.int32) - return intermediate + numpy.ones(tensor_shape) + intermediate = fun(intermediate, numpy.ones(constant_bigger_shape)) + return intermediate.astype(numpy.int32) function_to_trace = tensor_no_fuse params_names = signature(function_to_trace).parameters.keys() @@ -152,7 +163,7 @@ def test_tensor_no_fuse(): op_graph = trace_numpy_function( function_to_trace, { - param_name: EncryptedTensor(Integer(32, True), shape=tensor_shape) + param_name: EncryptedTensor(Integer(32, True), shape=variable_tensor_shape) for param_name in params_names }, ) @@ -163,7 +174,24 @@ def test_tensor_no_fuse(): assert orig_num_nodes == fused_num_nodes -def subtest_fuse_float_unary_operations_correctness(fun): +def check_results_are_equal(function_result, op_graph_result): + """Check the output of function execution and OPGraph evaluation are equal.""" + + if isinstance(function_result, tuple) and isinstance(op_graph_result, tuple): + assert len(function_result) == len(op_graph_result) + are_equal = ( + function_output == op_graph_output + for function_output, op_graph_output in zip(function_result, op_graph_result) + ) + elif not isinstance(function_result, tuple) and not isinstance(op_graph_result, tuple): + are_equal = (function_result == op_graph_result,) + else: + assert_not_reached(f"Incompatible outputs: {function_result}, {op_graph_result}") + + return all(value.all() if isinstance(value, numpy.ndarray) else value for value in are_equal) + + +def subtest_fuse_float_unary_operations_correctness(fun, tensor_shape): """Test a unary function with fuse_float_operations.""" # Some manipulation to avoid issues with domain of definitions of functions @@ -193,7 +221,10 @@ def subtest_fuse_float_unary_operations_correctness(fun): op_graph = trace_numpy_function( function_to_trace, - {param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names}, + { + param_name: EncryptedTensor(Integer(32, True), tensor_shape) + for param_name in params_names + }, ) orig_num_nodes = len(op_graph.graph) fuse_float_operations(op_graph) @@ -201,12 +232,20 @@ def subtest_fuse_float_unary_operations_correctness(fun): assert fused_num_nodes < orig_num_nodes - input_ = numpy.int32(input_) + ones_input = ( + numpy.ones(tensor_shape, dtype=numpy.dtype(type(input_))) + if tensor_shape != () + else 1 + ) + input_ = numpy.int32(input_ * ones_input) num_params = len(params_names) inputs = (input_,) * num_params - assert function_to_trace(*inputs) == op_graph(*inputs) + function_result = function_to_trace(*inputs) + op_graph_result = op_graph(*inputs) + + assert check_results_are_equal(function_result, op_graph_result) LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = { @@ -227,7 +266,7 @@ LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = { } -def subtest_fuse_float_binary_operations_correctness(fun): +def subtest_fuse_float_binary_operations_correctness(fun, tensor_shape): """Test a binary functions with fuse_float_operations, with a constant as a source.""" for i in range(4): @@ -248,23 +287,37 @@ def subtest_fuse_float_binary_operations_correctness(fun): # For bivariate functions: fix one of the inputs if i == 0: # With an integer in first position + ones_0 = numpy.ones(tensor_shape, dtype=numpy.int64) if tensor_shape != () else 1 + def 
get_function_to_trace(): - return lambda x, y: fun(3, x + y).astype(numpy.float64).astype(numpy.int32) + return lambda x, y: fun(3 * ones_0, x + y).astype(numpy.float64).astype(numpy.int32) elif i == 1: # With a float in first position + ones_1 = numpy.ones(tensor_shape, dtype=numpy.float64) if tensor_shape != () else 1 + def get_function_to_trace(): - return lambda x, y: fun(2.3, x + y).astype(numpy.float64).astype(numpy.int32) + return ( + lambda x, y: fun(2.3 * ones_1, x + y).astype(numpy.float64).astype(numpy.int32) + ) elif i == 2: # With an integer in second position + ones_2 = numpy.ones(tensor_shape, dtype=numpy.int64) if tensor_shape != () else 1 + def get_function_to_trace(): - return lambda x, y: fun(x + y, 4).astype(numpy.float64).astype(numpy.int32) + return lambda x, y: fun(x + y, 4 * ones_2).astype(numpy.float64).astype(numpy.int32) else: # With a float in second position + ones_else = numpy.ones(tensor_shape, dtype=numpy.float64) if tensor_shape != () else 1 + def get_function_to_trace(): - return lambda x, y: fun(x + y, 5.7).astype(numpy.float64).astype(numpy.int32) + return ( + lambda x, y: fun(x + y, 5.7 * ones_else) + .astype(numpy.float64) + .astype(numpy.int32) + ) input_list = [0, 2, 42, 44] @@ -273,6 +326,12 @@ def subtest_fuse_float_binary_operations_correctness(fun): input_list = [2, 42, 44] for input_ in input_list: + ones_input = ( + numpy.ones(tensor_shape, dtype=numpy.dtype(type(input_))) + if tensor_shape != () + else 1 + ) + input_ = input_ * ones_input function_to_trace = get_function_to_trace() @@ -280,7 +339,10 @@ def subtest_fuse_float_binary_operations_correctness(fun): op_graph = trace_numpy_function( function_to_trace, - {param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names}, + { + param_name: EncryptedTensor(Integer(32, True), tensor_shape) + for param_name in params_names + }, ) orig_num_nodes = len(op_graph.graph) fuse_float_operations(op_graph) @@ -293,10 +355,13 @@ def subtest_fuse_float_binary_operations_correctness(fun): num_params = len(params_names) inputs = (input_,) * num_params - assert function_to_trace(*inputs) == op_graph(*inputs) + function_result = function_to_trace(*inputs) + op_graph_result = op_graph(*inputs) + + assert check_results_are_equal(function_result, op_graph_result) -def subtest_fuse_float_binary_operations_dont_support_two_variables(fun): +def subtest_fuse_float_binary_operations_dont_support_two_variables(fun, tensor_shape): """Test a binary function with fuse_float_operations, with no constant as a source.""" @@ -310,18 +375,23 @@ def subtest_fuse_float_binary_operations_dont_support_two_variables(fun): with pytest.raises(NotImplementedError, match=r"Can't manage binary operator"): trace_numpy_function( function_to_trace, - {param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names}, + { + param_name: EncryptedTensor(Integer(32, True), tensor_shape) + for param_name in params_names + }, ) @pytest.mark.parametrize("fun", tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC) -def test_ufunc_operations(fun): +@pytest.mark.parametrize("tensor_shape", [(), (3, 1, 2)]) +def test_ufunc_operations(fun, tensor_shape): """Test functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC.""" if fun.nin == 1: - subtest_fuse_float_unary_operations_correctness(fun) + subtest_fuse_float_unary_operations_correctness(fun, tensor_shape) elif fun.nin == 2: - subtest_fuse_float_binary_operations_correctness(fun) - subtest_fuse_float_binary_operations_dont_support_two_variables(fun) + 
subtest_fuse_float_binary_operations_correctness(fun, tensor_shape) + subtest_fuse_float_binary_operations_dont_support_two_variables(fun, tensor_shape) + subtest_tensor_no_fuse(fun, tensor_shape) else: raise NotImplementedError("Only unary and binary functions are tested for now")
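
A minimal numpy-only sketch of the broadcasting rule that the new subgraph_values_allow_fusing size check encodes (the shapes below are illustrative, not taken from the patch): the fused ArbitraryFunction is applied cell by cell on the variable input, so every constant in the subgraph must broadcast to the variable input's shape without enlarging it.

import numpy

# Variable (encrypted) input of the float subgraph: shape (1, 3, 2), size 6.
variable_input = numpy.arange(6).reshape(1, 3, 2)

# Constant no bigger than the variable input: broadcasting keeps the variable
# input's shape, so a per-cell fused function can reproduce the result.
small_constant = numpy.ones((3, 2))
assert numpy.add(variable_input, small_constant).shape == variable_input.shape

# Constant bigger than the variable input (the situation subtest_tensor_no_fuse
# builds with constant_bigger_shape): broadcasting enlarges the result to
# (5, 3, 2), which a per-cell function of the variable input alone cannot
# produce, so subgraph_values_allow_fusing refuses to fuse.
big_constant = numpy.ones((5, 3, 2))
assert numpy.add(variable_input, big_constant).shape == (5, 3, 2)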
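
A short end-to-end sketch of the behaviour the patch enables, written against the helpers the test file already imports; it assumes trace_numpy_function is reachable as tracing.trace_numpy_function and that numpy.add is among NPTracer.LIST_OF_SUPPORTED_UFUNC. A tensor function whose baked constant matches the variable input's shape is now fused, where fusing previously required scalar-only subgraphs.

import numpy

from concrete.common.data_types.integers import Integer
from concrete.common.optimization.topological import fuse_float_operations
from concrete.common.values import EncryptedTensor
from concrete.numpy import tracing


def with_matching_constant(x):
    # Same pattern as tensor_no_fuse above, except the baked numpy.ones constant
    # has the same shape as the variable input, so the new
    # subgraph_values_allow_fusing check accepts it.
    intermediate = x.astype(numpy.float64)
    intermediate = numpy.add(intermediate, numpy.ones((3, 2)))
    return intermediate.astype(numpy.int32)


op_graph = tracing.trace_numpy_function(  # assumed location of the tracing helper
    with_matching_constant,
    {"x": EncryptedTensor(Integer(32, True), shape=(3, 2))},
)
orig_num_nodes = len(op_graph.graph)
fuse_float_operations(op_graph)
fused_num_nodes = len(op_graph.graph)

# The float subgraph (astype float64, add with the baked constant, astype int32)
# collapses into a single ArbitraryFunction node, so the graph shrinks.
assert fused_num_nodes < orig_num_nodes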