feat(float_fusing): restrict to scalars before supporting tensors

This commit is contained in:
Arthur Meyre
2021-09-30 10:00:37 +02:00
parent 42d5b66b69
commit a9d44f4758
2 changed files with 53 additions and 1 deletions

View File

@@ -10,6 +10,7 @@ from ..data_types.integers import Integer
from ..debugging.custom_assert import custom_assert
from ..operator_graph import OPGraph
from ..representation.intermediate import ArbitraryFunction, Constant, Input, IntermediateNode
from ..values import TensorValue
def fuse_float_operations(
@@ -38,6 +39,10 @@ def fuse_float_operations(
float_subgraph_start_nodes, terminal_node, subgraph_all_nodes = float_subgraph_search_result
processed_terminal_nodes.add(terminal_node)
# TODO: #199 To be removed when doing tensor management
if not subgraph_is_scalar_only(subgraph_all_nodes):
continue
subgraph_conversion_result = convert_float_subgraph_to_fused_node(
op_graph,
float_subgraph_start_nodes,
@@ -239,6 +244,23 @@ def find_float_subgraph_with_unique_terminal_node(
return float_subgraph_start_nodes, terminal_node, subgraph_all_nodes
# TODO: #199 To be removed when doing tensor management
def subgraph_is_scalar_only(subgraph_all_nodes: Set[IntermediateNode]) -> bool:
"""Check subgraph only processes scalars.
Args:
subgraph_all_nodes (Set[IntermediateNode]): The nodes of the float subgraph.
Returns:
bool: True if all inputs and outputs of the nodes in the subgraph are scalars.
"""
return all(
all(isinstance(input_, TensorValue) and input_.is_scalar for input_ in node.inputs)
and all(isinstance(output, TensorValue) and output.is_scalar for output in node.outputs)
for node in subgraph_all_nodes
)
def subgraph_has_unique_variable_input(
float_subgraph_start_nodes: Set[IntermediateNode],
) -> bool:

View File

@@ -1,5 +1,6 @@
"""Test file for float subgraph fusing"""
import random
from inspect import signature
import numpy
@@ -7,7 +8,7 @@ import pytest
from concrete.common.data_types.integers import Integer
from concrete.common.optimization.topological import fuse_float_operations
from concrete.common.values import EncryptedScalar
from concrete.common.values import EncryptedScalar, EncryptedTensor
from concrete.numpy import tracing
from concrete.numpy.tracing import trace_numpy_function
@@ -116,6 +117,35 @@ def test_fuse_float_operations(function_to_trace, fused, input_):
assert function_to_trace(*inputs) == op_graph(*inputs)
# TODO: #199 To be removed when doing tensor management
def test_tensor_no_fuse():
"""Test case to verify float fusing is only applied on functions on scalars."""
ndim = random.randint(1, 3)
tensor_shape = tuple(random.randint(1, 10) for _ in range(ndim + 1))
def tensor_no_fuse(x):
intermediate = x.astype(numpy.float64)
intermediate = intermediate.astype(numpy.int32)
return intermediate + numpy.ones(tensor_shape)
function_to_trace = tensor_no_fuse
params_names = signature(function_to_trace).parameters.keys()
op_graph = trace_numpy_function(
function_to_trace,
{
param_name: EncryptedTensor(Integer(32, True), shape=tensor_shape)
for param_name in params_names
},
)
orig_num_nodes = len(op_graph.graph)
fuse_float_operations(op_graph)
fused_num_nodes = len(op_graph.graph)
assert orig_num_nodes == fused_num_nodes
def test_fuse_float_operations_correctness():
"""Test functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC
with fuse_float_operations."""