feat: manage float fusing for tensors

- add op_attributes on ArbitraryFunction
- fusing works as long as no constant is bigger than the variable input tensor,
i.e. every constant broadcasts to the input's shape
- handling bigger constants requires more advanced code and a concatenate operator which
currently does not exist
Arthur Meyre
2021-10-08 09:59:56 +02:00
parent fb9cc79128
commit 44016cc80c
4 changed files with 196 additions and 43 deletions
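
To illustrate what this commit enables and what it still rejects, here is a minimal numpy sketch of the two situations described in the message above (shapes and variable names are illustrative only, not taken from the compiler):

import numpy

x = numpy.arange(6, dtype=numpy.int32).reshape(2, 3)  # the variable input tensor

# Fusable: the constant broadcasts to x's shape, so the whole float subgraph can be
# collapsed into a single per-cell function applied to x.
fusable = (x.astype(numpy.float64) + numpy.ones((3,))).astype(numpy.int32)
assert fusable.shape == x.shape

# Not fusable today: the constant has more elements than x, the result outgrows x's shape,
# and a per-cell fused function cannot reproduce it. Handling this would need broadcasting
# of the variable input plus a concatenate operator.
not_fusable = (x.astype(numpy.float64) + numpy.ones((4, 2, 3))).astype(numpy.int32)
assert not_fusable.shape == (4, 2, 3)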

View File

@@ -1,13 +1,14 @@
"""File holding topological optimization/simplification code."""
import itertools
from copy import deepcopy
from typing import Dict, List, Optional, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple, cast
import networkx as nx
from ..compilation.artifacts import CompilationArtifacts
from ..data_types.floats import Float
from ..data_types.integers import Integer
from ..debugging.custom_assert import custom_assert
from ..debugging.custom_assert import assert_true, custom_assert
from ..operator_graph import OPGraph
from ..representation.intermediate import ArbitraryFunction, Constant, Input, IntermediateNode
from ..values import TensorValue
@@ -39,10 +40,6 @@ def fuse_float_operations(
float_subgraph_start_nodes, terminal_node, subgraph_all_nodes = float_subgraph_search_result
processed_terminal_nodes.add(terminal_node)
# TODO: #199 To be removed when doing tensor management
if not subgraph_is_scalar_only(subgraph_all_nodes):
continue
subgraph_conversion_result = convert_float_subgraph_to_fused_node(
op_graph,
float_subgraph_start_nodes,
@@ -111,16 +108,20 @@ def convert_float_subgraph_to_fused_node(
output must be plugged as the input to the subgraph.
"""
if not subgraph_has_unique_variable_input(float_subgraph_start_nodes):
subgraph_can_be_fused = subgraph_has_unique_variable_input(
float_subgraph_start_nodes
) and subgraph_values_allow_fusing(float_subgraph_start_nodes, subgraph_all_nodes)
if not subgraph_can_be_fused:
return None
# Only one variable input node, find which node feeds its input
non_constant_start_nodes = [
variable_input_nodes = [
node for node in float_subgraph_start_nodes if not isinstance(node, Constant)
]
custom_assert(len(non_constant_start_nodes) == 1)
custom_assert(len(variable_input_nodes) == 1)
current_subgraph_variable_input = non_constant_start_nodes[0]
current_subgraph_variable_input = variable_input_nodes[0]
new_input_value = deepcopy(current_subgraph_variable_input.outputs[0])
nx_graph = op_graph.graph
@@ -244,20 +245,89 @@ def find_float_subgraph_with_unique_terminal_node(
return float_subgraph_start_nodes, terminal_node, subgraph_all_nodes
# TODO: #199 To be removed when doing tensor management
def subgraph_is_scalar_only(subgraph_all_nodes: Set[IntermediateNode]) -> bool:
"""Check subgraph only processes scalars.
def subgraph_values_allow_fusing(
float_subgraph_start_nodes: Set[IntermediateNode],
subgraph_all_nodes: Set[IntermediateNode],
):
"""Check if a subgraph's values are compatible with fusing.
For example, a fused subgraph only works on an input tensor if the resulting ArbitraryFunction
can be applied cell by cell; shuffling or tensor shape changes therefore make fusing impossible.
Args:
subgraph_all_nodes (Set[IntermediateNode]): The nodes of the float subgraph.
float_subgraph_start_nodes (Set[IntermediateNode]): The nodes starting the float subgraph.
subgraph_all_nodes (Set[IntermediateNode]): All the nodes in the float subgraph.
Returns:
bool: True if all inputs and outputs of the nodes in the subgraph are scalars.
bool: True if all inputs and outputs of the nodes in the subgraph are compatible with fusing,
i.e. every output has the same shape as the variable input.
"""
return all(
all(isinstance(input_, TensorValue) and input_.is_scalar for input_ in node.inputs)
and all(isinstance(output, TensorValue) and output.is_scalar for output in node.outputs)
variable_input_nodes = [
node for node in float_subgraph_start_nodes if not isinstance(node, Constant)
]
assert_true(
(num_variable_input_nodes := len(variable_input_nodes)) == 1,
f"{subgraph_values_allow_fusing.__name__} "
f"only works for subgraphs with 1 variable input node, got {num_variable_input_nodes}",
)
# Some ArbitraryFunction nodes have baked constants that need to be taken into account for the
# max size computation
baked_constants_ir_nodes = [
baked_constant_base_value
for node in subgraph_all_nodes
if isinstance(node, ArbitraryFunction)
if (baked_constant_base_value := node.op_attributes.get("baked_constant_ir_node", None))
is not None
]
all_values_are_tensors = all(
all(isinstance(input_, TensorValue) for input_ in node.inputs)
and all(isinstance(output, TensorValue) for output in node.outputs)
for node in itertools.chain(subgraph_all_nodes, baked_constants_ir_nodes)
)
if not all_values_are_tensors:
# This cannot be reached today as scalars are Tensors with shape == () (numpy convention)
return False # pragma: no cover
variable_input_node = variable_input_nodes[0]
# A cheap check is that the variable input node must have the biggest size, i.e. the most
# elements, so that all constants broadcast to its shape. This is required because the
# ArbitraryFunction's input and output must have the same shape for it to be applied to each
# cell of the input tensor.
# There *may* be a way to manage the other case by simulating the broadcast of the smaller input
# array and then concatenating/stacking the results. This is not currently doable as we don't
# have a concatenate operator on the compiler side.
# TODO: #587 https://github.com/zama-ai/concretefhe-internal/issues/587
variable_input_node_output = cast(TensorValue, variable_input_node.outputs[0])
variable_input_node_output_size, variable_input_node_output_shape = (
variable_input_node_output.size,
variable_input_node_output.shape,
)
max_inputs_size = max(
cast(TensorValue, input_node.outputs[0]).size
for input_node in itertools.chain(subgraph_all_nodes, baked_constants_ir_nodes)
)
if variable_input_node_output_size < max_inputs_size:
return False
# Now that we know the variable input node has the biggest size, we can check that shapes are
# consistent throughout the subgraph: outputs of IR nodes that are not constant must all have the
# variable input's shape.
non_constant_nodes = (node for node in subgraph_all_nodes if not isinstance(node, Constant))
return all(
all(
isinstance(output, TensorValue) and output.shape == variable_input_node_output_shape
for output in node.outputs
)
for node in non_constant_nodes
)
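
The size and shape reasoning in subgraph_values_allow_fusing boils down to: constants may broadcast up to the variable input, never the other way around. A rough standalone sketch of that intuition in plain numpy (names are illustrative; the actual check above compares element counts and output shapes of the IR nodes; numpy.broadcast_shapes needs numpy >= 1.20):

import numpy

def constant_allows_fusing(variable_shape, constant_shape):
    """True when the constant broadcasts to the variable input's shape without growing
    the result, i.e. the fused function keeps the input's shape."""
    return numpy.broadcast_shapes(variable_shape, constant_shape) == tuple(variable_shape)

assert constant_allows_fusing((2, 3), (3,))              # constant broadcasts up to the input
assert constant_allows_fusing((2, 3), ())                # scalar constants always work
assert not constant_allows_fusing((1, 2, 3), (4, 2, 3))  # bigger constant: fusing is refused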

View File

@@ -202,9 +202,10 @@ class ArbitraryFunction(IntermediateNode):
# The arbitrary_func is not optional, but mypy has a long-standing bug and is not able to
# understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623
arbitrary_func: Optional[Callable]
op_name: str
op_args: Tuple[Any, ...]
op_kwargs: Dict[str, Any]
op_name: str
op_attributes: Dict[str, Any]
_n_in: int = 1
def __init__(
@@ -215,12 +216,14 @@ class ArbitraryFunction(IntermediateNode):
op_name: Optional[str] = None,
op_args: Optional[Tuple[Any, ...]] = None,
op_kwargs: Optional[Dict[str, Any]] = None,
op_attributes: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__([input_base_value])
custom_assert(len(self.inputs) == 1)
self.arbitrary_func = arbitrary_func
self.op_args = op_args if op_args is not None else ()
self.op_kwargs = op_kwargs if op_kwargs is not None else {}
self.op_attributes = op_attributes if op_attributes is not None else {}
output = deepcopy(input_base_value)
output.dtype = output_dtype

View File

@@ -208,6 +208,15 @@ class NPTracer(BaseTracer):
op_kwargs = deepcopy(kwargs)
op_kwargs["baked_constant"] = baked_constant
# Store info on the operation being traced
# Currently: the IR node corresponding to the baked constant and the index of the input
# it was feeding
op_attributes = {
"baked_constant_ir_node": deepcopy(
input_tracers[in_which_input_is_constant].traced_computation
),
"in_which_input_is_constant": in_which_input_is_constant,
}
traced_computation = ArbitraryFunction(
input_base_value=input_tracers[in_which_input_is_variable].output,
@@ -215,6 +224,7 @@ class NPTracer(BaseTracer):
output_dtype=common_output_dtypes[0],
op_kwargs=op_kwargs,
op_name=binary_operator_string,
op_attributes=op_attributes,
)
output_tracer = cls(
(input_tracers[in_which_input_is_variable],),

View File

@@ -7,6 +7,7 @@ import numpy
import pytest
from concrete.common.data_types.integers import Integer
from concrete.common.debugging.custom_assert import assert_not_reached
from concrete.common.optimization.topological import fuse_float_operations
from concrete.common.values import EncryptedScalar, EncryptedTensor
from concrete.numpy import tracing
@@ -134,17 +135,27 @@ def test_fuse_float_operations(function_to_trace, fused, input_):
assert function_to_trace(*inputs) == op_graph(*inputs)
# TODO: #199 To be removed when doing tensor management
def test_tensor_no_fuse():
def subtest_tensor_no_fuse(fun, tensor_shape):
"""Test case to verify float fusing is only applied on functions on scalars."""
ndim = random.randint(1, 3)
tensor_shape = tuple(random.randint(1, 10) for _ in range(ndim + 1))
if tensor_shape == ():
# We want tensors
return
if fun in LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES:
# We need at least one input of the bivariate function to be float
return
# Float fusing currently cannot work if the constant in a bivariate operator is bigger than the
# variable input.
# Make a broadcastable shape but with the constant being bigger
variable_tensor_shape = (1,) + tensor_shape
constant_bigger_shape = (random.randint(2, 10),) + tensor_shape
def tensor_no_fuse(x):
intermediate = x.astype(numpy.float64)
intermediate = intermediate.astype(numpy.int32)
return intermediate + numpy.ones(tensor_shape)
intermediate = fun(intermediate, numpy.ones(constant_bigger_shape))
return intermediate.astype(numpy.int32)
function_to_trace = tensor_no_fuse
params_names = signature(function_to_trace).parameters.keys()
@@ -152,7 +163,7 @@ def test_tensor_no_fuse():
op_graph = trace_numpy_function(
function_to_trace,
{
param_name: EncryptedTensor(Integer(32, True), shape=tensor_shape)
param_name: EncryptedTensor(Integer(32, True), shape=variable_tensor_shape)
for param_name in params_names
},
)
@@ -163,7 +174,24 @@ def test_tensor_no_fuse():
assert orig_num_nodes == fused_num_nodes
def subtest_fuse_float_unary_operations_correctness(fun):
def check_results_are_equal(function_result, op_graph_result):
"""Check the output of function execution and OPGraph evaluation are equal."""
if isinstance(function_result, tuple) and isinstance(op_graph_result, tuple):
assert len(function_result) == len(op_graph_result)
are_equal = (
function_output == op_graph_output
for function_output, op_graph_output in zip(function_result, op_graph_result)
)
elif not isinstance(function_result, tuple) and not isinstance(op_graph_result, tuple):
are_equal = (function_result == op_graph_result,)
else:
assert_not_reached(f"Incompatible outputs: {function_result}, {op_graph_result}")
return all(value.all() if isinstance(value, numpy.ndarray) else value for value in are_equal)
def subtest_fuse_float_unary_operations_correctness(fun, tensor_shape):
"""Test a unary function with fuse_float_operations."""
# Some manipulation to avoid issues with domain of definitions of functions
@@ -193,7 +221,10 @@ def subtest_fuse_float_unary_operations_correctness(fun):
op_graph = trace_numpy_function(
function_to_trace,
{param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names},
{
param_name: EncryptedTensor(Integer(32, True), tensor_shape)
for param_name in params_names
},
)
orig_num_nodes = len(op_graph.graph)
fuse_float_operations(op_graph)
@@ -201,12 +232,20 @@ def subtest_fuse_float_unary_operations_correctness(fun):
assert fused_num_nodes < orig_num_nodes
input_ = numpy.int32(input_)
ones_input = (
numpy.ones(tensor_shape, dtype=numpy.dtype(type(input_)))
if tensor_shape != ()
else 1
)
input_ = numpy.int32(input_ * ones_input)
num_params = len(params_names)
inputs = (input_,) * num_params
assert function_to_trace(*inputs) == op_graph(*inputs)
function_result = function_to_trace(*inputs)
op_graph_result = op_graph(*inputs)
assert check_results_are_equal(function_result, op_graph_result)
LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = {
@@ -227,7 +266,7 @@ LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = {
}
def subtest_fuse_float_binary_operations_correctness(fun):
def subtest_fuse_float_binary_operations_correctness(fun, tensor_shape):
"""Test a binary functions with fuse_float_operations, with a constant as a source."""
for i in range(4):
@@ -248,23 +287,37 @@ def subtest_fuse_float_binary_operations_correctness(fun):
# For bivariate functions: fix one of the inputs
if i == 0:
# With an integer in first position
ones_0 = numpy.ones(tensor_shape, dtype=numpy.int64) if tensor_shape != () else 1
def get_function_to_trace():
return lambda x, y: fun(3, x + y).astype(numpy.float64).astype(numpy.int32)
return lambda x, y: fun(3 * ones_0, x + y).astype(numpy.float64).astype(numpy.int32)
elif i == 1:
# With a float in first position
ones_1 = numpy.ones(tensor_shape, dtype=numpy.float64) if tensor_shape != () else 1
def get_function_to_trace():
return lambda x, y: fun(2.3, x + y).astype(numpy.float64).astype(numpy.int32)
return (
lambda x, y: fun(2.3 * ones_1, x + y).astype(numpy.float64).astype(numpy.int32)
)
elif i == 2:
# With an integer in second position
ones_2 = numpy.ones(tensor_shape, dtype=numpy.int64) if tensor_shape != () else 1
def get_function_to_trace():
return lambda x, y: fun(x + y, 4).astype(numpy.float64).astype(numpy.int32)
return lambda x, y: fun(x + y, 4 * ones_2).astype(numpy.float64).astype(numpy.int32)
else:
# With a float in second position
ones_else = numpy.ones(tensor_shape, dtype=numpy.float64) if tensor_shape != () else 1
def get_function_to_trace():
return lambda x, y: fun(x + y, 5.7).astype(numpy.float64).astype(numpy.int32)
return (
lambda x, y: fun(x + y, 5.7 * ones_else)
.astype(numpy.float64)
.astype(numpy.int32)
)
input_list = [0, 2, 42, 44]
@@ -273,6 +326,12 @@ def subtest_fuse_float_binary_operations_correctness(fun):
input_list = [2, 42, 44]
for input_ in input_list:
ones_input = (
numpy.ones(tensor_shape, dtype=numpy.dtype(type(input_)))
if tensor_shape != ()
else 1
)
input_ = input_ * ones_input
function_to_trace = get_function_to_trace()
@@ -280,7 +339,10 @@ def subtest_fuse_float_binary_operations_correctness(fun):
op_graph = trace_numpy_function(
function_to_trace,
{param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names},
{
param_name: EncryptedTensor(Integer(32, True), tensor_shape)
for param_name in params_names
},
)
orig_num_nodes = len(op_graph.graph)
fuse_float_operations(op_graph)
@@ -293,10 +355,13 @@ def subtest_fuse_float_binary_operations_correctness(fun):
num_params = len(params_names)
inputs = (input_,) * num_params
assert function_to_trace(*inputs) == op_graph(*inputs)
function_result = function_to_trace(*inputs)
op_graph_result = op_graph(*inputs)
assert check_results_are_equal(function_result, op_graph_result)
def subtest_fuse_float_binary_operations_dont_support_two_variables(fun):
def subtest_fuse_float_binary_operations_dont_support_two_variables(fun, tensor_shape):
"""Test a binary function with fuse_float_operations, with no constant as
a source."""
@@ -310,18 +375,23 @@ def subtest_fuse_float_binary_operations_dont_support_two_variables(fun):
with pytest.raises(NotImplementedError, match=r"Can't manage binary operator"):
trace_numpy_function(
function_to_trace,
{param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names},
{
param_name: EncryptedTensor(Integer(32, True), tensor_shape)
for param_name in params_names
},
)
@pytest.mark.parametrize("fun", tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC)
def test_ufunc_operations(fun):
@pytest.mark.parametrize("tensor_shape", [(), (3, 1, 2)])
def test_ufunc_operations(fun, tensor_shape):
"""Test functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC."""
if fun.nin == 1:
subtest_fuse_float_unary_operations_correctness(fun)
subtest_fuse_float_unary_operations_correctness(fun, tensor_shape)
elif fun.nin == 2:
subtest_fuse_float_binary_operations_correctness(fun)
subtest_fuse_float_binary_operations_dont_support_two_variables(fun)
subtest_fuse_float_binary_operations_correctness(fun, tensor_shape)
subtest_fuse_float_binary_operations_dont_support_two_variables(fun, tensor_shape)
subtest_tensor_no_fuse(fun, tensor_shape)
else:
raise NotImplementedError("Only unary and binary functions are tested for now")