feat(float_fusing): restrict to scalars before supporting tensors

This commit is contained in:
Arthur Meyre
2021-09-30 10:00:37 +02:00
parent 42d5b66b69
commit a9d44f4758
2 changed files with 53 additions and 1 deletions

View File

@@ -1,5 +1,6 @@
"""Test file for float subgraph fusing"""
import random
from inspect import signature
import numpy
@@ -7,7 +8,7 @@ import pytest
from concrete.common.data_types.integers import Integer
from concrete.common.optimization.topological import fuse_float_operations
from concrete.common.values import EncryptedScalar
from concrete.common.values import EncryptedScalar, EncryptedTensor
from concrete.numpy import tracing
from concrete.numpy.tracing import trace_numpy_function
@@ -116,6 +117,35 @@ def test_fuse_float_operations(function_to_trace, fused, input_):
assert function_to_trace(*inputs) == op_graph(*inputs)
# TODO: #199 To be removed when doing tensor management
def test_tensor_no_fuse():
"""Test case to verify float fusing is only applied on functions on scalars."""
ndim = random.randint(1, 3)
tensor_shape = tuple(random.randint(1, 10) for _ in range(ndim + 1))
def tensor_no_fuse(x):
intermediate = x.astype(numpy.float64)
intermediate = intermediate.astype(numpy.int32)
return intermediate + numpy.ones(tensor_shape)
function_to_trace = tensor_no_fuse
params_names = signature(function_to_trace).parameters.keys()
op_graph = trace_numpy_function(
function_to_trace,
{
param_name: EncryptedTensor(Integer(32, True), shape=tensor_shape)
for param_name in params_names
},
)
orig_num_nodes = len(op_graph.graph)
fuse_float_operations(op_graph)
fused_num_nodes = len(op_graph.graph)
assert orig_num_nodes == fused_num_nodes
def test_fuse_float_operations_correctness():
"""Test functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC
with fuse_float_operations."""