diff --git a/tests/common/optimization/test_float_fusing.py b/tests/common/optimization/test_float_fusing.py index 40826ddeb..6118fc178 100644 --- a/tests/common/optimization/test_float_fusing.py +++ b/tests/common/optimization/test_float_fusing.py @@ -196,15 +196,23 @@ def subtest_fuse_float_unary_operations_correctness(fun, tensor_shape): # Some manipulation to avoid issues with domain of definitions of functions if fun == numpy.arccosh: + # 0 is not in the domain of definition input_list = [1, 2, 42, 44] super_fun_list = [mix_x_and_y_and_call_f] elif fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]: + # Needs values between 0 and 1 input_list = [0, 0.1, 0.2] super_fun_list = [mix_x_and_y_and_call_f] + elif fun in [numpy.cosh, numpy.sinh, numpy.exp, numpy.exp2, numpy.expm1]: + # Not too large values to avoid overflows + input_list = [1, 2, 5, 11] + super_fun_list = [mix_x_and_y_and_call_f, mix_x_and_y_intricately_and_call_f] elif fun == numpy.invert: + # 0 is not in the domain of definition + expect integer inputs input_list = [1, 2, 42, 44] super_fun_list = [mix_x_and_y_into_integer_and_call_f] else: + # Regular case input_list = [0, 2, 42, 44] super_fun_list = [mix_x_and_y_and_call_f, mix_x_and_y_intricately_and_call_f] @@ -232,20 +240,53 @@ def subtest_fuse_float_unary_operations_correctness(fun, tensor_shape): assert fused_num_nodes < orig_num_nodes - ones_input = ( - numpy.ones(tensor_shape, dtype=numpy.dtype(type(input_))) + # Check that the call to the function or to the op_graph evaluation give the same + # result + tensor_diversifier = ( + # The following +1 in the range is to avoid to have 0's which is not in the + # domain definition of some of our functions + numpy.arange(1, numpy.product(tensor_shape) + 1, dtype=numpy.int32).reshape( + tensor_shape + ) if tensor_shape != () else 1 ) - input_ = numpy.int32(input_ * ones_input) + + if fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]: + # Domain of definition for these functions + tensor_diversifier = ( + numpy.ones(tensor_shape, dtype=numpy.int32) if tensor_shape != () else 1 + ) + + input_ = numpy.int32(input_ * tensor_diversifier) num_params = len(params_names) - inputs = (input_,) * num_params + assert num_params == 2 - function_result = function_to_trace(*inputs) - op_graph_result = op_graph(*inputs) + # Create inputs which are either of the form [x, x] or [x, y] + for j in range(4): - assert check_results_are_equal(function_result, op_graph_result) + if fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]: + if j > 0: + # Domain of definition for these functions + break + + input_a = input_ + input_b = input_ + j + + if tensor_shape != (): + numpy.random.shuffle(input_a) + numpy.random.shuffle(input_b) + + if random.randint(0, 1) == 0: + inputs = (input_a, input_b) + else: + inputs = (input_b, input_a) + + function_result = function_to_trace(*inputs) + op_graph_result = op_graph(*inputs) + + assert check_results_are_equal(function_result, op_graph_result) LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = { @@ -287,7 +328,7 @@ def subtest_fuse_float_binary_operations_correctness(fun, tensor_shape): # For bivariate functions: fix one of the inputs if i == 0: # With an integer in first position - ones_0 = numpy.ones(tensor_shape, dtype=numpy.int64) if tensor_shape != () else 1 + ones_0 = numpy.ones(tensor_shape, dtype=numpy.int32) if tensor_shape != () else 1 def get_function_to_trace(): return lambda x, y: fun(3 * ones_0, x + y).astype(numpy.float64).astype(numpy.int32) @@ -303,7 +344,7 @@ def subtest_fuse_float_binary_operations_correctness(fun, tensor_shape): elif i == 2: # With an integer in second position - ones_2 = numpy.ones(tensor_shape, dtype=numpy.int64) if tensor_shape != () else 1 + ones_2 = numpy.ones(tensor_shape, dtype=numpy.int32) if tensor_shape != () else 1 def get_function_to_trace(): return lambda x, y: fun(x + y, 4 * ones_2).astype(numpy.float64).astype(numpy.int32) @@ -326,13 +367,6 @@ def subtest_fuse_float_binary_operations_correctness(fun, tensor_shape): input_list = [2, 42, 44] for input_ in input_list: - ones_input = ( - numpy.ones(tensor_shape, dtype=numpy.dtype(type(input_))) - if tensor_shape != () - else 1 - ) - input_ = input_ * ones_input - function_to_trace = get_function_to_trace() params_names = signature(function_to_trace).parameters.keys() @@ -350,15 +384,30 @@ def subtest_fuse_float_binary_operations_correctness(fun, tensor_shape): assert fused_num_nodes < orig_num_nodes - input_ = numpy.int32(input_) + # Check that the call to the function or to the op_graph evaluation give the same + # result + tensor_diversifier = ( + # The following +1 in the range is to avoid to have 0's which is not in the + # domain definition of some of our functions + numpy.arange(1, numpy.product(tensor_shape) + 1, dtype=numpy.int32).reshape( + tensor_shape + ) + if tensor_shape != () + else 1 + ) + input_ = input_ * tensor_diversifier num_params = len(params_names) - inputs = (input_,) * num_params + assert num_params == 2 - function_result = function_to_trace(*inputs) - op_graph_result = op_graph(*inputs) + # Create inputs which are either of the form [x, x] or [x, y] + for j in range(4): + inputs = (input_, input_ + j) - assert check_results_are_equal(function_result, op_graph_result) + function_result = function_to_trace(*inputs) + op_graph_result = op_graph(*inputs) + + assert check_results_are_equal(function_result, op_graph_result) def subtest_fuse_float_binary_operations_dont_support_two_variables(fun, tensor_shape): @@ -383,7 +432,9 @@ def subtest_fuse_float_binary_operations_dont_support_two_variables(fun, tensor_ @pytest.mark.parametrize("fun", tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC) -@pytest.mark.parametrize("tensor_shape", [(), (3, 1, 2)]) +@pytest.mark.parametrize( + "tensor_shape", [pytest.param((), id="scalar"), pytest.param((3, 1, 2), id="tensor")] +) def test_ufunc_operations(fun, tensor_shape): """Test functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC."""