feat: add management of boolean binary operators with a const scalar

refs #126
refs #529
This commit is contained in:
Benoit Chevallier-Mames
2021-10-07 16:38:25 +02:00
committed by Benoit Chevallier
parent f443b41cef
commit e8114cc470
3 changed files with 60 additions and 31 deletions

View File

@@ -38,8 +38,10 @@ def simple_fuse_output(x):
return x.astype(numpy.float64).astype(numpy.int32)
def complex_fuse_indirect_input(function, x, y):
"""Complex fuse"""
def mix_x_and_y_intricately_and_call_f(function, x, y):
"""Mix x and y in an intricated way, that can't be simplified by
an optimizer eg, and then call function
"""
intermediate = x + y
intermediate = intermediate + 2
intermediate = intermediate.astype(numpy.float32)
@@ -57,8 +59,8 @@ def complex_fuse_indirect_input(function, x, y):
)
def complex_fuse_direct_input(function, x, y):
"""Complex fuse"""
def mix_x_and_y_and_call_f(function, x, y):
"""Mix x and y and then call function"""
x_p_1 = x + 0.1
x_p_2 = x + 0.2
x_p_3 = function(x_p_1 + x_p_2)
@@ -72,6 +74,21 @@ def complex_fuse_direct_input(function, x, y):
)
def mix_x_and_y_into_integer_and_call_f(function, x, y):
"""Mix x and y but keep the entry to function as an integer"""
x_p_1 = x + 1
x_p_2 = x + 2
x_p_3 = function(x_p_1 + x_p_2)
return (
x_p_3.astype(numpy.int32),
x_p_2.astype(numpy.int32),
(x_p_2 + 3).astype(numpy.int32),
x_p_3.astype(numpy.int32) + 67,
y,
(y + 4.7).astype(numpy.int32) + 3,
)
@pytest.mark.parametrize(
"function_to_trace,fused",
[
@@ -80,14 +97,14 @@ def complex_fuse_direct_input(function, x, y):
pytest.param(simple_fuse_not_output, True, id="no_fuse"),
pytest.param(simple_fuse_output, True, id="no_fuse"),
pytest.param(
lambda x, y: complex_fuse_indirect_input(numpy.rint, x, y),
lambda x, y: mix_x_and_y_intricately_and_call_f(numpy.rint, x, y),
True,
id="complex_fuse_indirect_input_with_rint",
id="mix_x_and_y_intricately_and_call_f_with_rint",
),
pytest.param(
lambda x, y: complex_fuse_direct_input(numpy.rint, x, y),
lambda x, y: mix_x_and_y_and_call_f(numpy.rint, x, y),
True,
id="complex_fuse_direct_input_with_rint",
id="mix_x_and_y_and_call_f_with_rint",
),
],
)
@@ -152,13 +169,16 @@ def subtest_fuse_float_unary_operations_correctness(fun):
# Some manipulation to avoid issues with domain of definitions of functions
if fun == numpy.arccosh:
input_list = [1, 2, 42, 44]
super_fun_list = [complex_fuse_direct_input]
super_fun_list = [mix_x_and_y_and_call_f]
elif fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]:
input_list = [0, 0.1, 0.2]
super_fun_list = [complex_fuse_direct_input]
super_fun_list = [mix_x_and_y_and_call_f]
elif fun == numpy.invert:
input_list = [1, 2, 42, 44]
super_fun_list = [mix_x_and_y_into_integer_and_call_f]
else:
input_list = [0, 2, 42, 44]
super_fun_list = [complex_fuse_direct_input, complex_fuse_indirect_input]
super_fun_list = [mix_x_and_y_and_call_f, mix_x_and_y_intricately_and_call_f]
for super_fun in super_fun_list:
@@ -194,6 +214,7 @@ LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = {
numpy.bitwise_or,
numpy.bitwise_xor,
numpy.gcd,
numpy.invert,
numpy.lcm,
numpy.ldexp,
numpy.left_shift,
@@ -220,6 +241,10 @@ def subtest_fuse_float_binary_operations_correctness(fun):
# a float output even for functions which return an integer (eg, XOR), such
# that our frontend always try to fuse them
# The .astype(numpy.float64) that we have in cases 1 and 3 is here to force
# a float output even for functions which return a bool (eg, EQUAL), such
# that our frontend always try to fuse them
# For bivariate functions: fix one of the inputs
if i == 0:
# With an integer in first position
@@ -229,7 +254,7 @@ def subtest_fuse_float_binary_operations_correctness(fun):
elif i == 1:
# With a float in first position
def get_function_to_trace():
return lambda x, y: fun(2.3, x + y).astype(numpy.int32)
return lambda x, y: fun(2.3, x + y).astype(numpy.float64).astype(numpy.int32)
elif i == 2:
# With an integer in second position
@@ -239,7 +264,7 @@ def subtest_fuse_float_binary_operations_correctness(fun):
else:
# With a float in second position
def get_function_to_trace():
return lambda x, y: fun(x + y, 5.7).astype(numpy.int32)
return lambda x, y: fun(x + y, 5.7).astype(numpy.float64).astype(numpy.int32)
input_list = [0, 2, 42, 44]