feat(float_fusing): restrict to scalars before supporting tensors
@@ -10,6 +10,7 @@ from ..data_types.integers import Integer
 from ..debugging.custom_assert import custom_assert
 from ..operator_graph import OPGraph
 from ..representation.intermediate import ArbitraryFunction, Constant, Input, IntermediateNode
+from ..values import TensorValue


 def fuse_float_operations(
@@ -38,6 +39,10 @@ def fuse_float_operations(
         float_subgraph_start_nodes, terminal_node, subgraph_all_nodes = float_subgraph_search_result
         processed_terminal_nodes.add(terminal_node)

+        # TODO: #199 To be removed when doing tensor management
+        if not subgraph_is_scalar_only(subgraph_all_nodes):
+            continue
+
         subgraph_conversion_result = convert_float_subgraph_to_fused_node(
             op_graph,
             float_subgraph_start_nodes,
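Aside (not part of the diff): the new guard hinges on values knowing their shape. A minimal sketch of what the is_scalar check distinguishes, assuming EncryptedScalar and EncryptedTensor (imported in the test hunk below) build TensorValue instances whose is_scalar means "the shape is empty" -- the assertions are illustrative only:

    # Illustrative only: assumes EncryptedScalar/EncryptedTensor produce
    # TensorValue instances and that is_scalar means shape == ().
    from concrete.common.data_types.integers import Integer
    from concrete.common.values import EncryptedScalar, EncryptedTensor

    scalar = EncryptedScalar(Integer(32, True))
    tensor = EncryptedTensor(Integer(32, True), shape=(2, 3))

    assert scalar.is_scalar      # scalar-only subgraphs keep getting fused
    assert not tensor.is_scalar  # one such value makes the pass skip the subgraph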
@@ -239,6 +244,23 @@ def find_float_subgraph_with_unique_terminal_node(
     return float_subgraph_start_nodes, terminal_node, subgraph_all_nodes


+# TODO: #199 To be removed when doing tensor management
+def subgraph_is_scalar_only(subgraph_all_nodes: Set[IntermediateNode]) -> bool:
+    """Check subgraph only processes scalars.
+
+    Args:
+        subgraph_all_nodes (Set[IntermediateNode]): The nodes of the float subgraph.
+
+    Returns:
+        bool: True if all inputs and outputs of the nodes in the subgraph are scalars.
+    """
+    return all(
+        all(isinstance(input_, TensorValue) and input_.is_scalar for input_ in node.inputs)
+        and all(isinstance(output, TensorValue) and output.is_scalar for output in node.outputs)
+        for node in subgraph_all_nodes
+    )
+
+
 def subgraph_has_unique_variable_input(
     float_subgraph_start_nodes: Set[IntermediateNode],
 ) -> bool:
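The predicate above is a triple all(): every node passes only if every one of its inputs and every one of its outputs is a scalar TensorValue. A self-contained sketch of that logic with stand-in classes (not the real IntermediateNode/TensorValue; the real check additionally requires isinstance(value, TensorValue)):

    # Stand-ins for illustration only; hashable so they can live in a set,
    # like the real subgraph_all_nodes argument.
    from dataclasses import dataclass
    from typing import Tuple

    @dataclass(frozen=True)
    class Value:
        shape: Tuple[int, ...] = ()

        @property
        def is_scalar(self) -> bool:
            return self.shape == ()

    @dataclass(frozen=True)
    class Node:
        inputs: Tuple[Value, ...] = ()
        outputs: Tuple[Value, ...] = ()

    def scalar_only(nodes) -> bool:
        # Same shape of logic as subgraph_is_scalar_only, minus the isinstance test.
        return all(
            all(value.is_scalar for value in node.inputs)
            and all(value.is_scalar for value in node.outputs)
            for node in nodes
        )

    assert scalar_only({Node((Value(),), (Value(),))})
    assert not scalar_only({Node((Value((2, 3)),), (Value((2, 3)),))})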
@@ -1,5 +1,6 @@
 """Test file for float subgraph fusing"""

+import random
 from inspect import signature

 import numpy
@@ -7,7 +8,7 @@ import pytest

 from concrete.common.data_types.integers import Integer
 from concrete.common.optimization.topological import fuse_float_operations
-from concrete.common.values import EncryptedScalar
+from concrete.common.values import EncryptedScalar, EncryptedTensor
 from concrete.numpy import tracing
 from concrete.numpy.tracing import trace_numpy_function

@@ -116,6 +117,35 @@ def test_fuse_float_operations(function_to_trace, fused, input_):
     assert function_to_trace(*inputs) == op_graph(*inputs)


+# TODO: #199 To be removed when doing tensor management
+def test_tensor_no_fuse():
+    """Test case to verify float fusing is only applied on functions on scalars."""
+
+    ndim = random.randint(1, 3)
+    tensor_shape = tuple(random.randint(1, 10) for _ in range(ndim + 1))
+
+    def tensor_no_fuse(x):
+        intermediate = x.astype(numpy.float64)
+        intermediate = intermediate.astype(numpy.int32)
+        return intermediate + numpy.ones(tensor_shape)
+
+    function_to_trace = tensor_no_fuse
+    params_names = signature(function_to_trace).parameters.keys()
+
+    op_graph = trace_numpy_function(
+        function_to_trace,
+        {
+            param_name: EncryptedTensor(Integer(32, True), shape=tensor_shape)
+            for param_name in params_names
+        },
+    )
+    orig_num_nodes = len(op_graph.graph)
+    fuse_float_operations(op_graph)
+    fused_num_nodes = len(op_graph.graph)
+
+    assert orig_num_nodes == fused_num_nodes
+
+
 def test_fuse_float_operations_correctness():
     """Test functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC
     with fuse_float_operations."""
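A hedged counterpart to test_tensor_no_fuse (not part of this commit): the same float round-trip traced over a scalar input should still be fused, so the node count shrinks rather than staying equal. The function body and the expected node-count drop are assumptions, reusing only names the test file already imports:

    # Hypothetical scalar counterpart: same round-trip through float64,
    # but on an EncryptedScalar, so subgraph_is_scalar_only returns True.
    def scalar_fuse(x):
        intermediate = x.astype(numpy.float64)
        intermediate = intermediate.astype(numpy.int32)
        return intermediate + 1

    op_graph = trace_numpy_function(
        scalar_fuse,
        {"x": EncryptedScalar(Integer(32, True))},
    )
    orig_num_nodes = len(op_graph.graph)
    fuse_float_operations(op_graph)

    # Expect the float section to collapse into a fused node, reducing the
    # total node count (an assumption; the diff itself does not assert this).
    assert len(op_graph.graph) < orig_num_nodes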