mirror of https://github.com/zama-ai/concrete.git
feat: manage float fusing for tensors
- add op_attributes on ArbitraryFunction
- this works if all constants are smaller than the input tensor
- otherwise it requires more advanced code and a concatenate operator, which currently does not exist
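A minimal numpy sketch (hypothetical code, not the library's own) of the condition in the second bullet: when every constant in the float subgraph broadcasts to the variable input's shape, the whole chain is equivalent to one function applied to each cell of the input, which is exactly what the fused ArbitraryFunction has to be.

import numpy

# A float subgraph over an integer tensor x: cast to float, add a constant, cast back.
x = numpy.arange(6, dtype=numpy.int32).reshape(2, 3)
constant = numpy.ones((3,))  # smaller than x, so it broadcasts to x's shape


def subgraph(t):
    return (t.astype(numpy.float64) + constant).astype(numpy.int32)


def per_cell(cell):
    # The fused equivalent, applied to one cell at a time.
    return numpy.int32(numpy.float64(cell) + 1.0)


# Same result cell by cell, and the output keeps the input's shape: fusable.
assert numpy.array_equal(subgraph(x), numpy.vectorize(per_cell)(x))
assert subgraph(x).shape == x.shape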
@@ -1,13 +1,14 @@
"""File holding topological optimization/simplification code."""

import itertools
from copy import deepcopy
from typing import Dict, List, Optional, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple, cast

import networkx as nx

from ..compilation.artifacts import CompilationArtifacts
from ..data_types.floats import Float
from ..data_types.integers import Integer
from ..debugging.custom_assert import custom_assert
from ..debugging.custom_assert import assert_true, custom_assert
from ..operator_graph import OPGraph
from ..representation.intermediate import ArbitraryFunction, Constant, Input, IntermediateNode
from ..values import TensorValue
@@ -39,10 +40,6 @@ def fuse_float_operations(
        float_subgraph_start_nodes, terminal_node, subgraph_all_nodes = float_subgraph_search_result
        processed_terminal_nodes.add(terminal_node)

        # TODO: #199 To be removed when doing tensor management
        if not subgraph_is_scalar_only(subgraph_all_nodes):
            continue

        subgraph_conversion_result = convert_float_subgraph_to_fused_node(
            op_graph,
            float_subgraph_start_nodes,
@@ -111,16 +108,20 @@ def convert_float_subgraph_to_fused_node(
        output must be plugged as the input to the subgraph.
    """

    if not subgraph_has_unique_variable_input(float_subgraph_start_nodes):
    subgraph_can_be_fused = subgraph_has_unique_variable_input(
        float_subgraph_start_nodes
    ) and subgraph_values_allow_fusing(float_subgraph_start_nodes, subgraph_all_nodes)

    if not subgraph_can_be_fused:
        return None

    # Only one variable input node, find which node feeds its input
    non_constant_start_nodes = [
    variable_input_nodes = [
        node for node in float_subgraph_start_nodes if not isinstance(node, Constant)
    ]
    custom_assert(len(non_constant_start_nodes) == 1)
    custom_assert(len(variable_input_nodes) == 1)

    current_subgraph_variable_input = non_constant_start_nodes[0]
    current_subgraph_variable_input = variable_input_nodes[0]
    new_input_value = deepcopy(current_subgraph_variable_input.outputs[0])

    nx_graph = op_graph.graph
@@ -244,20 +245,89 @@ def find_float_subgraph_with_unique_terminal_node(
    return float_subgraph_start_nodes, terminal_node, subgraph_all_nodes


# TODO: #199 To be removed when doing tensor management
def subgraph_is_scalar_only(subgraph_all_nodes: Set[IntermediateNode]) -> bool:
    """Check subgraph only processes scalars.
def subgraph_values_allow_fusing(
    float_subgraph_start_nodes: Set[IntermediateNode],
    subgraph_all_nodes: Set[IntermediateNode],
):
    """Check if a subgraph's values are compatible with fusing.

    For example, a fused subgraph only works on an input tensor if the resulting ArbitraryFunction
    can be applied per cell; hence shuffling or tensor shape changes make fusing impossible.

    Args:
        subgraph_all_nodes (Set[IntermediateNode]): The nodes of the float subgraph.
        float_subgraph_start_nodes (Set[IntermediateNode]): The nodes starting the float subgraph.
        subgraph_all_nodes (Set[IntermediateNode]): All the nodes in the float subgraph.

    Returns:
        bool: True if all inputs and outputs of the nodes in the subgraph are scalars.
        bool: True if all inputs and outputs of the nodes in the subgraph are compatible with
            fusing, i.e. they all have the same shape as the variable input.
    """
    return all(
        all(isinstance(input_, TensorValue) and input_.is_scalar for input_ in node.inputs)
        and all(isinstance(output, TensorValue) and output.is_scalar for output in node.outputs)

    variable_input_nodes = [
        node for node in float_subgraph_start_nodes if not isinstance(node, Constant)
    ]

    assert_true(
        (num_variable_input_nodes := len(variable_input_nodes)) == 1,
        f"{subgraph_values_allow_fusing.__name__} "
        f"only works for subgraphs with 1 variable input node, got {num_variable_input_nodes}",
    )

    # Some ArbitraryFunction nodes have baked constants that need to be taken into account for the
    # max size computation
    baked_constants_ir_nodes = [
        baked_constant_base_value
        for node in subgraph_all_nodes
        if isinstance(node, ArbitraryFunction)
        if (baked_constant_base_value := node.op_attributes.get("baked_constant_ir_node", None))
        is not None
    ]

    all_values_are_tensors = all(
        all(isinstance(input_, TensorValue) for input_ in node.inputs)
        and all(isinstance(output, TensorValue) for output in node.outputs)
        for node in itertools.chain(subgraph_all_nodes, baked_constants_ir_nodes)
    )

    if not all_values_are_tensors:
        # This cannot be reached today as scalars are Tensors with shape == () (numpy convention)
        return False  # pragma: no cover

    variable_input_node = variable_input_nodes[0]

    # A cheap check is that the variable input node must have the biggest size, i.e. have the most
    # elements, meaning all constants will broadcast to its shape. This is because the
    # ArbitraryFunction input and output must have the same shape so that it can be applied to each
    # of the input tensor cells.
    # There *may* be a way to manage the other case by simulating the broadcast of the smaller input
    # array and then concatenating/stacking the results. This is not currently doable as we don't
    # have a concatenate operator on the compiler side.
    # TODO: #587 https://github.com/zama-ai/concretefhe-internal/issues/587

    variable_input_node_output = cast(TensorValue, variable_input_node.outputs[0])
    variable_input_node_output_size, variable_input_node_output_shape = (
        variable_input_node_output.size,
        variable_input_node_output.shape,
    )
    max_inputs_size = max(
        cast(TensorValue, input_node.outputs[0]).size
        for input_node in itertools.chain(subgraph_all_nodes, baked_constants_ir_nodes)
    )

    if variable_input_node_output_size < max_inputs_size:
        return False

    # Now that we know the variable input node has the biggest size, we can check that shapes are
    # consistent throughout the subgraph: outputs of IR nodes that are not constant must all be equal.

    non_constant_nodes = (node for node in subgraph_all_nodes if not isinstance(node, Constant))

    return all(
        all(
            isinstance(output, TensorValue) and output.shape == variable_input_node_output_shape
            for output in node.outputs
        )
        for node in non_constant_nodes
    )
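A stand-alone sketch of the two checks above, using hypothetical shape tuples rather than the library's IR classes: the variable input must have the largest size so every baked constant broadcasts up to it, and every non-constant node in the subgraph must output exactly that shape.

import math


def values_allow_fusing_sketch(variable_input_shape, constant_shapes, non_constant_output_shapes):
    """Toy version of the check: True when the float subgraph could be fused per cell."""
    variable_size = math.prod(variable_input_shape)

    # Cheap check: the variable input must have the biggest size in the subgraph,
    # so every constant broadcasts up to its shape (and not the other way around).
    max_size = max([variable_size] + [math.prod(shape) for shape in constant_shapes])
    if variable_size < max_size:
        return False

    # Shape consistency: every non-constant node must output the variable input's shape.
    return all(shape == variable_input_shape for shape in non_constant_output_shapes)


# A (3,) constant broadcasts to the (2, 3) variable input: fusing is possible.
assert values_allow_fusing_sketch((2, 3), [(3,)], [(2, 3), (2, 3)])
# A (4, 2, 3) constant is bigger than the input: rejected (would need concatenate).
assert not values_allow_fusing_sketch((2, 3), [(4, 2, 3)], [(4, 2, 3)])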
@@ -202,9 +202,10 @@ class ArbitraryFunction(IntermediateNode):
    # The arbitrary_func is not optional but mypy has a long standing bug and is not able to
    # understand this properly. See https://github.com/python/mypy/issues/708#issuecomment-605636623
    arbitrary_func: Optional[Callable]
    op_name: str
    op_args: Tuple[Any, ...]
    op_kwargs: Dict[str, Any]
    op_name: str
    op_attributes: Dict[str, Any]
    _n_in: int = 1

    def __init__(
@@ -215,12 +216,14 @@ class ArbitraryFunction(IntermediateNode):
        op_name: Optional[str] = None,
        op_args: Optional[Tuple[Any, ...]] = None,
        op_kwargs: Optional[Dict[str, Any]] = None,
        op_attributes: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__([input_base_value])
        custom_assert(len(self.inputs) == 1)
        self.arbitrary_func = arbitrary_func
        self.op_args = op_args if op_args is not None else ()
        self.op_kwargs = op_kwargs if op_kwargs is not None else {}
        self.op_attributes = op_attributes if op_attributes is not None else {}

        output = deepcopy(input_base_value)
        output.dtype = output_dtype
@@ -208,6 +208,15 @@ class NPTracer(BaseTracer):
        op_kwargs = deepcopy(kwargs)
        op_kwargs["baked_constant"] = baked_constant
        # Store info on the operation being treated
        # Currently: the base value and type corresponding to the baked constant and which input idx
        # it was feeding
        op_attributes = {
            "baked_constant_ir_node": deepcopy(
                input_tracers[in_which_input_is_constant].traced_computation
            ),
            "in_which_input_is_constant": in_which_input_is_constant,
        }

        traced_computation = ArbitraryFunction(
            input_base_value=input_tracers[in_which_input_is_variable].output,
@@ -215,6 +224,7 @@ class NPTracer(BaseTracer):
            output_dtype=common_output_dtypes[0],
            op_kwargs=op_kwargs,
            op_name=binary_operator_string,
            op_attributes=op_attributes,
        )
        output_tracer = cls(
            (input_tracers[in_which_input_is_variable],),
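A plain-Python sketch (hypothetical node objects, not the real IR classes) of how the attribute recorded by the tracer above is consumed later in topological.py: nodes whose op_attributes carry a baked constant contribute that constant's IR node to the max-size computation.

class FakeNode:
    """Hypothetical stand-in for an IR node that may carry op_attributes."""

    def __init__(self, op_attributes=None):
        self.op_attributes = op_attributes if op_attributes is not None else {}


subgraph_all_nodes = [
    FakeNode(),  # a node without any baked constant
    FakeNode({"baked_constant_ir_node": "ones((3,)) constant", "in_which_input_is_constant": 1}),
]

# Mirrors the comprehension in subgraph_values_allow_fusing.
baked_constants_ir_nodes = [
    baked
    for node in subgraph_all_nodes
    if (baked := node.op_attributes.get("baked_constant_ir_node", None)) is not None
]
assert baked_constants_ir_nodes == ["ones((3,)) constant"]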
@@ -7,6 +7,7 @@ import numpy
import pytest

from concrete.common.data_types.integers import Integer
from concrete.common.debugging.custom_assert import assert_not_reached
from concrete.common.optimization.topological import fuse_float_operations
from concrete.common.values import EncryptedScalar, EncryptedTensor
from concrete.numpy import tracing
@@ -134,17 +135,27 @@ def test_fuse_float_operations(function_to_trace, fused, input_):
    assert function_to_trace(*inputs) == op_graph(*inputs)


# TODO: #199 To be removed when doing tensor management
def test_tensor_no_fuse():
def subtest_tensor_no_fuse(fun, tensor_shape):
    """Test case to verify float fusing is only applied on functions on scalars."""

    ndim = random.randint(1, 3)
    tensor_shape = tuple(random.randint(1, 10) for _ in range(ndim + 1))
    if tensor_shape == ():
        # We want tensors
        return

    if fun in LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES:
        # We need at least one input of the bivariate function to be float
        return

    # Float fusing currently cannot work if the constant in a bivariate operator is bigger than the
    # variable input.
    # Make a broadcastable shape but with the constant being bigger
    variable_tensor_shape = (1,) + tensor_shape
    constant_bigger_shape = (random.randint(2, 10),) + tensor_shape

    def tensor_no_fuse(x):
        intermediate = x.astype(numpy.float64)
        intermediate = intermediate.astype(numpy.int32)
        return intermediate + numpy.ones(tensor_shape)
        intermediate = fun(intermediate, numpy.ones(constant_bigger_shape))
        return intermediate.astype(numpy.int32)

    function_to_trace = tensor_no_fuse
    params_names = signature(function_to_trace).parameters.keys()
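A quick numpy illustration (toy shapes, independent of the test's random ones) of why the case exercised above cannot be fused: with the constant strictly bigger than the variable input, broadcasting grows the result beyond the input's shape, so no per-cell ArbitraryFunction over the input can reproduce it.

import numpy

tensor_shape = (2, 3)
variable_tensor_shape = (1,) + tensor_shape  # (1, 2, 3), the traced input
constant_bigger_shape = (4,) + tensor_shape  # (4, 2, 3), the baked constant

x = numpy.zeros(variable_tensor_shape, dtype=numpy.int32)
result = x.astype(numpy.float64) + numpy.ones(constant_bigger_shape)

# The output no longer matches the input's shape: fusing would need concatenation.
assert result.shape == (4, 2, 3)
assert result.shape != x.shape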
@@ -152,7 +163,7 @@ def test_tensor_no_fuse():
    op_graph = trace_numpy_function(
        function_to_trace,
        {
            param_name: EncryptedTensor(Integer(32, True), shape=tensor_shape)
            param_name: EncryptedTensor(Integer(32, True), shape=variable_tensor_shape)
            for param_name in params_names
        },
    )
@@ -163,7 +174,24 @@ def test_tensor_no_fuse():
    assert orig_num_nodes == fused_num_nodes


def subtest_fuse_float_unary_operations_correctness(fun):
def check_results_are_equal(function_result, op_graph_result):
    """Check the output of function execution and OPGraph evaluation are equal."""

    if isinstance(function_result, tuple) and isinstance(op_graph_result, tuple):
        assert len(function_result) == len(op_graph_result)
        are_equal = (
            function_output == op_graph_output
            for function_output, op_graph_output in zip(function_result, op_graph_result)
        )
    elif not isinstance(function_result, tuple) and not isinstance(op_graph_result, tuple):
        are_equal = (function_result == op_graph_result,)
    else:
        assert_not_reached(f"Incompatible outputs: {function_result}, {op_graph_result}")

    return all(value.all() if isinstance(value, numpy.ndarray) else value for value in are_equal)


def subtest_fuse_float_unary_operations_correctness(fun, tensor_shape):
    """Test a unary function with fuse_float_operations."""

    # Some manipulation to avoid issues with domain of definitions of functions
@@ -193,7 +221,10 @@ def subtest_fuse_float_unary_operations_correctness(fun):
    op_graph = trace_numpy_function(
        function_to_trace,
        {param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names},
        {
            param_name: EncryptedTensor(Integer(32, True), tensor_shape)
            for param_name in params_names
        },
    )
    orig_num_nodes = len(op_graph.graph)
    fuse_float_operations(op_graph)
@@ -201,12 +232,20 @@ def subtest_fuse_float_unary_operations_correctness(fun):
    assert fused_num_nodes < orig_num_nodes

    input_ = numpy.int32(input_)
    ones_input = (
        numpy.ones(tensor_shape, dtype=numpy.dtype(type(input_)))
        if tensor_shape != ()
        else 1
    )
    input_ = numpy.int32(input_ * ones_input)

    num_params = len(params_names)
    inputs = (input_,) * num_params

    assert function_to_trace(*inputs) == op_graph(*inputs)
    function_result = function_to_trace(*inputs)
    op_graph_result = op_graph(*inputs)

    assert check_results_are_equal(function_result, op_graph_result)


LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = {
@@ -227,7 +266,7 @@ LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = {
}


def subtest_fuse_float_binary_operations_correctness(fun):
def subtest_fuse_float_binary_operations_correctness(fun, tensor_shape):
    """Test a binary function with fuse_float_operations, with a constant as a source."""

    for i in range(4):
@@ -248,23 +287,37 @@ def subtest_fuse_float_binary_operations_correctness(fun):
        # For bivariate functions: fix one of the inputs
        if i == 0:
            # With an integer in first position
            ones_0 = numpy.ones(tensor_shape, dtype=numpy.int64) if tensor_shape != () else 1

            def get_function_to_trace():
                return lambda x, y: fun(3, x + y).astype(numpy.float64).astype(numpy.int32)
                return lambda x, y: fun(3 * ones_0, x + y).astype(numpy.float64).astype(numpy.int32)

        elif i == 1:
            # With a float in first position
            ones_1 = numpy.ones(tensor_shape, dtype=numpy.float64) if tensor_shape != () else 1

            def get_function_to_trace():
                return lambda x, y: fun(2.3, x + y).astype(numpy.float64).astype(numpy.int32)
                return (
                    lambda x, y: fun(2.3 * ones_1, x + y).astype(numpy.float64).astype(numpy.int32)
                )

        elif i == 2:
            # With an integer in second position
            ones_2 = numpy.ones(tensor_shape, dtype=numpy.int64) if tensor_shape != () else 1

            def get_function_to_trace():
                return lambda x, y: fun(x + y, 4).astype(numpy.float64).astype(numpy.int32)
                return lambda x, y: fun(x + y, 4 * ones_2).astype(numpy.float64).astype(numpy.int32)

        else:
            # With a float in second position
            ones_else = numpy.ones(tensor_shape, dtype=numpy.float64) if tensor_shape != () else 1

            def get_function_to_trace():
                return lambda x, y: fun(x + y, 5.7).astype(numpy.float64).astype(numpy.int32)
                return (
                    lambda x, y: fun(x + y, 5.7 * ones_else)
                    .astype(numpy.float64)
                    .astype(numpy.int32)
                )

        input_list = [0, 2, 42, 44]
@@ -273,6 +326,12 @@ def subtest_fuse_float_binary_operations_correctness(fun):
            input_list = [2, 42, 44]

        for input_ in input_list:
            ones_input = (
                numpy.ones(tensor_shape, dtype=numpy.dtype(type(input_)))
                if tensor_shape != ()
                else 1
            )
            input_ = input_ * ones_input

            function_to_trace = get_function_to_trace()
@@ -280,7 +339,10 @@ def subtest_fuse_float_binary_operations_correctness(fun):
            op_graph = trace_numpy_function(
                function_to_trace,
                {param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names},
                {
                    param_name: EncryptedTensor(Integer(32, True), tensor_shape)
                    for param_name in params_names
                },
            )
            orig_num_nodes = len(op_graph.graph)
            fuse_float_operations(op_graph)
@@ -293,10 +355,13 @@ def subtest_fuse_float_binary_operations_correctness(fun):
            num_params = len(params_names)
            inputs = (input_,) * num_params

            assert function_to_trace(*inputs) == op_graph(*inputs)
            function_result = function_to_trace(*inputs)
            op_graph_result = op_graph(*inputs)

            assert check_results_are_equal(function_result, op_graph_result)


def subtest_fuse_float_binary_operations_dont_support_two_variables(fun):
def subtest_fuse_float_binary_operations_dont_support_two_variables(fun, tensor_shape):
    """Test a binary function with fuse_float_operations, with no constant as
    a source."""
@@ -310,18 +375,23 @@ def subtest_fuse_float_binary_operations_dont_support_two_variables(fun):
    with pytest.raises(NotImplementedError, match=r"Can't manage binary operator"):
        trace_numpy_function(
            function_to_trace,
            {param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names},
            {
                param_name: EncryptedTensor(Integer(32, True), tensor_shape)
                for param_name in params_names
            },
        )


@pytest.mark.parametrize("fun", tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC)
def test_ufunc_operations(fun):
@pytest.mark.parametrize("tensor_shape", [(), (3, 1, 2)])
def test_ufunc_operations(fun, tensor_shape):
    """Test functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC."""

    if fun.nin == 1:
        subtest_fuse_float_unary_operations_correctness(fun)
        subtest_fuse_float_unary_operations_correctness(fun, tensor_shape)
    elif fun.nin == 2:
        subtest_fuse_float_binary_operations_correctness(fun)
        subtest_fuse_float_binary_operations_dont_support_two_variables(fun)
        subtest_fuse_float_binary_operations_correctness(fun, tensor_shape)
        subtest_fuse_float_binary_operations_dont_support_two_variables(fun, tensor_shape)
        subtest_tensor_no_fuse(fun, tensor_shape)
    else:
        raise NotImplementedError("Only unary and binary functions are tested for now")