From a9d44f4758b61e616b2df31865d239fb045dff12 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Thu, 30 Sep 2021 10:00:37 +0200 Subject: [PATCH] feat(float_fusing): restrict to scalars before supporting tensors --- concrete/common/optimization/topological.py | 22 +++++++++++++ .../common/optimization/test_float_fusing.py | 32 ++++++++++++++++++- 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/concrete/common/optimization/topological.py b/concrete/common/optimization/topological.py index ab62a940a..77bd5f947 100644 --- a/concrete/common/optimization/topological.py +++ b/concrete/common/optimization/topological.py @@ -10,6 +10,7 @@ from ..data_types.integers import Integer from ..debugging.custom_assert import custom_assert from ..operator_graph import OPGraph from ..representation.intermediate import ArbitraryFunction, Constant, Input, IntermediateNode +from ..values import TensorValue def fuse_float_operations( @@ -38,6 +39,10 @@ def fuse_float_operations( float_subgraph_start_nodes, terminal_node, subgraph_all_nodes = float_subgraph_search_result processed_terminal_nodes.add(terminal_node) + # TODO: #199 To be removed when doing tensor management + if not subgraph_is_scalar_only(subgraph_all_nodes): + continue + subgraph_conversion_result = convert_float_subgraph_to_fused_node( op_graph, float_subgraph_start_nodes, @@ -239,6 +244,23 @@ def find_float_subgraph_with_unique_terminal_node( return float_subgraph_start_nodes, terminal_node, subgraph_all_nodes +# TODO: #199 To be removed when doing tensor management +def subgraph_is_scalar_only(subgraph_all_nodes: Set[IntermediateNode]) -> bool: + """Check subgraph only processes scalars. + + Args: + subgraph_all_nodes (Set[IntermediateNode]): The nodes of the float subgraph. + + Returns: + bool: True if all inputs and outputs of the nodes in the subgraph are scalars. + """ + return all( + all(isinstance(input_, TensorValue) and input_.is_scalar for input_ in node.inputs) + and all(isinstance(output, TensorValue) and output.is_scalar for output in node.outputs) + for node in subgraph_all_nodes + ) + + def subgraph_has_unique_variable_input( float_subgraph_start_nodes: Set[IntermediateNode], ) -> bool: diff --git a/tests/common/optimization/test_float_fusing.py b/tests/common/optimization/test_float_fusing.py index 57cef96d3..f10535aab 100644 --- a/tests/common/optimization/test_float_fusing.py +++ b/tests/common/optimization/test_float_fusing.py @@ -1,5 +1,6 @@ """Test file for float subgraph fusing""" +import random from inspect import signature import numpy @@ -7,7 +8,7 @@ import pytest from concrete.common.data_types.integers import Integer from concrete.common.optimization.topological import fuse_float_operations -from concrete.common.values import EncryptedScalar +from concrete.common.values import EncryptedScalar, EncryptedTensor from concrete.numpy import tracing from concrete.numpy.tracing import trace_numpy_function @@ -116,6 +117,35 @@ def test_fuse_float_operations(function_to_trace, fused, input_): assert function_to_trace(*inputs) == op_graph(*inputs) +# TODO: #199 To be removed when doing tensor management +def test_tensor_no_fuse(): + """Test case to verify float fusing is only applied on functions on scalars.""" + + ndim = random.randint(1, 3) + tensor_shape = tuple(random.randint(1, 10) for _ in range(ndim + 1)) + + def tensor_no_fuse(x): + intermediate = x.astype(numpy.float64) + intermediate = intermediate.astype(numpy.int32) + return intermediate + numpy.ones(tensor_shape) + + function_to_trace = tensor_no_fuse + params_names = signature(function_to_trace).parameters.keys() + + op_graph = trace_numpy_function( + function_to_trace, + { + param_name: EncryptedTensor(Integer(32, True), shape=tensor_shape) + for param_name in params_names + }, + ) + orig_num_nodes = len(op_graph.graph) + fuse_float_operations(op_graph) + fused_num_nodes = len(op_graph.graph) + + assert orig_num_nodes == fused_num_nodes + + def test_fuse_float_operations_correctness(): """Test functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC with fuse_float_operations."""