mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
committed by
Benoit Chevallier
parent
83ea485fe1
commit
56e0ed4a11
@@ -146,45 +146,127 @@ def test_tensor_no_fuse():
|
||||
assert orig_num_nodes == fused_num_nodes
|
||||
|
||||
|
||||
def test_fuse_float_operations_correctness():
|
||||
"""Test functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC
|
||||
with fuse_float_operations."""
|
||||
def subtest_fuse_float_unary_operations_correctness(fun):
|
||||
"""Test a unary function with fuse_float_operations."""
|
||||
|
||||
for fun in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC:
|
||||
# Some manipulation to avoid issues with domain of definitions of functions
|
||||
if fun == numpy.arccosh:
|
||||
input_list = [1, 2, 42, 44]
|
||||
super_fun_list = [complex_fuse_direct_input]
|
||||
elif fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]:
|
||||
input_list = [0, 0.1, 0.2]
|
||||
super_fun_list = [complex_fuse_direct_input]
|
||||
else:
|
||||
input_list = [0, 2, 42, 44]
|
||||
super_fun_list = [complex_fuse_direct_input, complex_fuse_indirect_input]
|
||||
|
||||
for super_fun in super_fun_list:
|
||||
|
||||
for input_ in input_list:
|
||||
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: super_fun(fun, x, y)
|
||||
|
||||
function_to_trace = get_function_to_trace()
|
||||
|
||||
params_names = signature(function_to_trace).parameters.keys()
|
||||
|
||||
op_graph = trace_numpy_function(
|
||||
function_to_trace,
|
||||
{param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names},
|
||||
)
|
||||
orig_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(op_graph)
|
||||
fused_num_nodes = len(op_graph.graph)
|
||||
|
||||
assert fused_num_nodes < orig_num_nodes
|
||||
|
||||
input_ = numpy.int32(input_)
|
||||
|
||||
num_params = len(params_names)
|
||||
inputs = (input_,) * num_params
|
||||
|
||||
assert function_to_trace(*inputs) == op_graph(*inputs)
|
||||
|
||||
|
||||
def subtest_fuse_float_binary_operations_correctness(fun):
|
||||
"""Test a binary functions with fuse_float_operations, with a constant as a source."""
|
||||
|
||||
for i in range(4):
|
||||
|
||||
# For bivariate functions: fix one of the inputs
|
||||
if i == 0:
|
||||
# With an integer in first position
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(3, x + y).astype(numpy.int32)
|
||||
|
||||
elif i == 1:
|
||||
# With a float in first position
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(2.3, x + y).astype(numpy.int32)
|
||||
|
||||
elif i == 2:
|
||||
# With an integer in second position
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(x + y, 4).astype(numpy.int32)
|
||||
|
||||
if fun == numpy.arccosh:
|
||||
input_list = [1, 2, 42, 44]
|
||||
super_fun_list = [complex_fuse_direct_input]
|
||||
elif fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]:
|
||||
input_list = [0, 0.1, 0.2]
|
||||
super_fun_list = [complex_fuse_direct_input]
|
||||
else:
|
||||
input_list = [0, 2, 42, 44]
|
||||
super_fun_list = [complex_fuse_direct_input, complex_fuse_indirect_input]
|
||||
# With a float in second position
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(x + y, 5.7).astype(numpy.int32)
|
||||
|
||||
for super_fun in super_fun_list:
|
||||
input_list = [0, 2, 42, 44]
|
||||
|
||||
for input_ in input_list:
|
||||
for input_ in input_list:
|
||||
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: super_fun(fun, x, y)
|
||||
function_to_trace = get_function_to_trace()
|
||||
|
||||
function_to_trace = get_function_to_trace()
|
||||
params_names = signature(function_to_trace).parameters.keys()
|
||||
|
||||
params_names = signature(function_to_trace).parameters.keys()
|
||||
op_graph = trace_numpy_function(
|
||||
function_to_trace,
|
||||
{param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names},
|
||||
)
|
||||
orig_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(op_graph)
|
||||
fused_num_nodes = len(op_graph.graph)
|
||||
|
||||
op_graph = trace_numpy_function(
|
||||
function_to_trace,
|
||||
{param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names},
|
||||
)
|
||||
orig_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(op_graph)
|
||||
fused_num_nodes = len(op_graph.graph)
|
||||
assert fused_num_nodes < orig_num_nodes
|
||||
|
||||
assert fused_num_nodes < orig_num_nodes
|
||||
input_ = numpy.int32(input_)
|
||||
|
||||
input_ = numpy.int32(input_)
|
||||
num_params = len(params_names)
|
||||
inputs = (input_,) * num_params
|
||||
|
||||
num_params = len(params_names)
|
||||
inputs = (input_,) * num_params
|
||||
assert function_to_trace(*inputs) == op_graph(*inputs)
|
||||
assert function_to_trace(*inputs) == op_graph(*inputs)
|
||||
|
||||
|
||||
def subtest_fuse_float_binary_operations_dont_support_two_variables(fun):
|
||||
"""Test a binary function with fuse_float_operations, with no constant as
|
||||
a source."""
|
||||
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(x, y).astype(numpy.int32)
|
||||
|
||||
function_to_trace = get_function_to_trace()
|
||||
|
||||
params_names = signature(function_to_trace).parameters.keys()
|
||||
|
||||
with pytest.raises(NotImplementedError, match=r"Can't manage binary operator"):
|
||||
trace_numpy_function(
|
||||
function_to_trace,
|
||||
{param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fun", tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC)
|
||||
def test_ufunc_operations(fun):
|
||||
"""Test functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC."""
|
||||
|
||||
if fun.nin == 1:
|
||||
subtest_fuse_float_unary_operations_correctness(fun)
|
||||
elif fun.nin == 2:
|
||||
subtest_fuse_float_binary_operations_correctness(fun)
|
||||
subtest_fuse_float_binary_operations_dont_support_two_variables(fun)
|
||||
else:
|
||||
raise NotImplementedError("Only unary and binary functions are tested for now")
|
||||
|
||||
Reference in New Issue
Block a user