mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
test(ci): more random inputs
more random-looking inputs in subtest_fuse_float_binary_operations_correctness and subtest_fuse_float_unary_operations_correctness closes #547
This commit is contained in:
committed by
Benoit Chevallier
parent
17704da169
commit
6e79c0baf5
@@ -196,15 +196,23 @@ def subtest_fuse_float_unary_operations_correctness(fun, tensor_shape):
|
||||
|
||||
# Some manipulation to avoid issues with domain of definitions of functions
|
||||
if fun == numpy.arccosh:
|
||||
# 0 is not in the domain of definition
|
||||
input_list = [1, 2, 42, 44]
|
||||
super_fun_list = [mix_x_and_y_and_call_f]
|
||||
elif fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]:
|
||||
# Needs values between 0 and 1
|
||||
input_list = [0, 0.1, 0.2]
|
||||
super_fun_list = [mix_x_and_y_and_call_f]
|
||||
elif fun in [numpy.cosh, numpy.sinh, numpy.exp, numpy.exp2, numpy.expm1]:
|
||||
# Not too large values to avoid overflows
|
||||
input_list = [1, 2, 5, 11]
|
||||
super_fun_list = [mix_x_and_y_and_call_f, mix_x_and_y_intricately_and_call_f]
|
||||
elif fun == numpy.invert:
|
||||
# 0 is not in the domain of definition + expect integer inputs
|
||||
input_list = [1, 2, 42, 44]
|
||||
super_fun_list = [mix_x_and_y_into_integer_and_call_f]
|
||||
else:
|
||||
# Regular case
|
||||
input_list = [0, 2, 42, 44]
|
||||
super_fun_list = [mix_x_and_y_and_call_f, mix_x_and_y_intricately_and_call_f]
|
||||
|
||||
@@ -232,20 +240,53 @@ def subtest_fuse_float_unary_operations_correctness(fun, tensor_shape):
|
||||
|
||||
assert fused_num_nodes < orig_num_nodes
|
||||
|
||||
ones_input = (
|
||||
numpy.ones(tensor_shape, dtype=numpy.dtype(type(input_)))
|
||||
# Check that the call to the function or to the op_graph evaluation give the same
|
||||
# result
|
||||
tensor_diversifier = (
|
||||
# The following +1 in the range is to avoid to have 0's which is not in the
|
||||
# domain definition of some of our functions
|
||||
numpy.arange(1, numpy.product(tensor_shape) + 1, dtype=numpy.int32).reshape(
|
||||
tensor_shape
|
||||
)
|
||||
if tensor_shape != ()
|
||||
else 1
|
||||
)
|
||||
input_ = numpy.int32(input_ * ones_input)
|
||||
|
||||
if fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]:
|
||||
# Domain of definition for these functions
|
||||
tensor_diversifier = (
|
||||
numpy.ones(tensor_shape, dtype=numpy.int32) if tensor_shape != () else 1
|
||||
)
|
||||
|
||||
input_ = numpy.int32(input_ * tensor_diversifier)
|
||||
|
||||
num_params = len(params_names)
|
||||
inputs = (input_,) * num_params
|
||||
assert num_params == 2
|
||||
|
||||
function_result = function_to_trace(*inputs)
|
||||
op_graph_result = op_graph(*inputs)
|
||||
# Create inputs which are either of the form [x, x] or [x, y]
|
||||
for j in range(4):
|
||||
|
||||
assert check_results_are_equal(function_result, op_graph_result)
|
||||
if fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]:
|
||||
if j > 0:
|
||||
# Domain of definition for these functions
|
||||
break
|
||||
|
||||
input_a = input_
|
||||
input_b = input_ + j
|
||||
|
||||
if tensor_shape != ():
|
||||
numpy.random.shuffle(input_a)
|
||||
numpy.random.shuffle(input_b)
|
||||
|
||||
if random.randint(0, 1) == 0:
|
||||
inputs = (input_a, input_b)
|
||||
else:
|
||||
inputs = (input_b, input_a)
|
||||
|
||||
function_result = function_to_trace(*inputs)
|
||||
op_graph_result = op_graph(*inputs)
|
||||
|
||||
assert check_results_are_equal(function_result, op_graph_result)
|
||||
|
||||
|
||||
LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = {
|
||||
@@ -287,7 +328,7 @@ def subtest_fuse_float_binary_operations_correctness(fun, tensor_shape):
|
||||
# For bivariate functions: fix one of the inputs
|
||||
if i == 0:
|
||||
# With an integer in first position
|
||||
ones_0 = numpy.ones(tensor_shape, dtype=numpy.int64) if tensor_shape != () else 1
|
||||
ones_0 = numpy.ones(tensor_shape, dtype=numpy.int32) if tensor_shape != () else 1
|
||||
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(3 * ones_0, x + y).astype(numpy.float64).astype(numpy.int32)
|
||||
@@ -303,7 +344,7 @@ def subtest_fuse_float_binary_operations_correctness(fun, tensor_shape):
|
||||
|
||||
elif i == 2:
|
||||
# With an integer in second position
|
||||
ones_2 = numpy.ones(tensor_shape, dtype=numpy.int64) if tensor_shape != () else 1
|
||||
ones_2 = numpy.ones(tensor_shape, dtype=numpy.int32) if tensor_shape != () else 1
|
||||
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(x + y, 4 * ones_2).astype(numpy.float64).astype(numpy.int32)
|
||||
@@ -326,13 +367,6 @@ def subtest_fuse_float_binary_operations_correctness(fun, tensor_shape):
|
||||
input_list = [2, 42, 44]
|
||||
|
||||
for input_ in input_list:
|
||||
ones_input = (
|
||||
numpy.ones(tensor_shape, dtype=numpy.dtype(type(input_)))
|
||||
if tensor_shape != ()
|
||||
else 1
|
||||
)
|
||||
input_ = input_ * ones_input
|
||||
|
||||
function_to_trace = get_function_to_trace()
|
||||
|
||||
params_names = signature(function_to_trace).parameters.keys()
|
||||
@@ -350,15 +384,30 @@ def subtest_fuse_float_binary_operations_correctness(fun, tensor_shape):
|
||||
|
||||
assert fused_num_nodes < orig_num_nodes
|
||||
|
||||
input_ = numpy.int32(input_)
|
||||
# Check that the call to the function or to the op_graph evaluation give the same
|
||||
# result
|
||||
tensor_diversifier = (
|
||||
# The following +1 in the range is to avoid to have 0's which is not in the
|
||||
# domain definition of some of our functions
|
||||
numpy.arange(1, numpy.product(tensor_shape) + 1, dtype=numpy.int32).reshape(
|
||||
tensor_shape
|
||||
)
|
||||
if tensor_shape != ()
|
||||
else 1
|
||||
)
|
||||
input_ = input_ * tensor_diversifier
|
||||
|
||||
num_params = len(params_names)
|
||||
inputs = (input_,) * num_params
|
||||
assert num_params == 2
|
||||
|
||||
function_result = function_to_trace(*inputs)
|
||||
op_graph_result = op_graph(*inputs)
|
||||
# Create inputs which are either of the form [x, x] or [x, y]
|
||||
for j in range(4):
|
||||
inputs = (input_, input_ + j)
|
||||
|
||||
assert check_results_are_equal(function_result, op_graph_result)
|
||||
function_result = function_to_trace(*inputs)
|
||||
op_graph_result = op_graph(*inputs)
|
||||
|
||||
assert check_results_are_equal(function_result, op_graph_result)
|
||||
|
||||
|
||||
def subtest_fuse_float_binary_operations_dont_support_two_variables(fun, tensor_shape):
|
||||
@@ -383,7 +432,9 @@ def subtest_fuse_float_binary_operations_dont_support_two_variables(fun, tensor_
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fun", tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC)
|
||||
@pytest.mark.parametrize("tensor_shape", [(), (3, 1, 2)])
|
||||
@pytest.mark.parametrize(
|
||||
"tensor_shape", [pytest.param((), id="scalar"), pytest.param((3, 1, 2), id="tensor")]
|
||||
)
|
||||
def test_ufunc_operations(fun, tensor_shape):
|
||||
"""Test functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user