mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(float_fusing): restrict to scalars before supporting tensors
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""Test file for float subgraph fusing"""
|
||||
|
||||
import random
|
||||
from inspect import signature
|
||||
|
||||
import numpy
|
||||
@@ -7,7 +8,7 @@ import pytest
|
||||
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.optimization.topological import fuse_float_operations
|
||||
from concrete.common.values import EncryptedScalar
|
||||
from concrete.common.values import EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import tracing
|
||||
from concrete.numpy.tracing import trace_numpy_function
|
||||
|
||||
@@ -116,6 +117,35 @@ def test_fuse_float_operations(function_to_trace, fused, input_):
|
||||
assert function_to_trace(*inputs) == op_graph(*inputs)
|
||||
|
||||
|
||||
# TODO: #199 To be removed when doing tensor management
|
||||
def test_tensor_no_fuse():
|
||||
"""Test case to verify float fusing is only applied on functions on scalars."""
|
||||
|
||||
ndim = random.randint(1, 3)
|
||||
tensor_shape = tuple(random.randint(1, 10) for _ in range(ndim + 1))
|
||||
|
||||
def tensor_no_fuse(x):
|
||||
intermediate = x.astype(numpy.float64)
|
||||
intermediate = intermediate.astype(numpy.int32)
|
||||
return intermediate + numpy.ones(tensor_shape)
|
||||
|
||||
function_to_trace = tensor_no_fuse
|
||||
params_names = signature(function_to_trace).parameters.keys()
|
||||
|
||||
op_graph = trace_numpy_function(
|
||||
function_to_trace,
|
||||
{
|
||||
param_name: EncryptedTensor(Integer(32, True), shape=tensor_shape)
|
||||
for param_name in params_names
|
||||
},
|
||||
)
|
||||
orig_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(op_graph)
|
||||
fused_num_nodes = len(op_graph.graph)
|
||||
|
||||
assert orig_num_nodes == fused_num_nodes
|
||||
|
||||
|
||||
def test_fuse_float_operations_correctness():
|
||||
"""Test functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC
|
||||
with fuse_float_operations."""
|
||||
|
||||
Reference in New Issue
Block a user