feat: add management of boolean binary operators with a const scalar

refs #126
refs #529
This commit is contained in:
Benoit Chevallier-Mames
2021-10-07 16:38:25 +02:00
committed by Benoit Chevallier
parent f443b41cef
commit e8114cc470
3 changed files with 60 additions and 31 deletions

View File

@@ -249,9 +249,17 @@ class NPTracer(BaseTracer):
# Supported functions are either univariate or bivariate for which one of the two
# sources is a constant
#
# numpy.add, numpy.multiply and numpy.subtract are not there since already managed
# by leveled operations
#
# numpy.conjugate is not there since working on complex numbers
#
# numpy.isnat is not there since it is about timings
#
# numpy.divmod, numpy.modf and numpy.frexp are not there since output two values
LIST_OF_SUPPORTED_UFUNC: List[numpy.ufunc] = [
numpy.absolute,
# numpy.add,
numpy.arccos,
numpy.arccosh,
numpy.arcsin,
@@ -264,14 +272,12 @@ class NPTracer(BaseTracer):
numpy.bitwise_xor,
numpy.cbrt,
numpy.ceil,
# numpy.conjugate,
numpy.copysign,
numpy.cos,
numpy.cosh,
numpy.deg2rad,
numpy.degrees,
# numpy.divmod,
# numpy.equal,
numpy.equal,
numpy.exp,
numpy.exp2,
numpy.expm1,
@@ -282,22 +288,20 @@ class NPTracer(BaseTracer):
numpy.fmax,
numpy.fmin,
numpy.fmod,
# numpy.frexp,
numpy.gcd,
# numpy.greater,
# numpy.greater_equal,
numpy.greater,
numpy.greater_equal,
numpy.heaviside,
numpy.hypot,
# numpy.invert,
numpy.invert,
numpy.isfinite,
numpy.isinf,
numpy.isnan,
# numpy.isnat,
numpy.lcm,
numpy.ldexp,
numpy.left_shift,
# numpy.less,
# numpy.less_equal,
numpy.less,
numpy.less_equal,
numpy.log,
numpy.log10,
numpy.log1p,
@@ -311,11 +315,9 @@ class NPTracer(BaseTracer):
# numpy.matmul,
numpy.maximum,
numpy.minimum,
# numpy.modf,
# numpy.multiply,
numpy.negative,
numpy.nextafter,
# numpy.not_equal,
numpy.not_equal,
numpy.positive,
numpy.power,
numpy.rad2deg,
@@ -331,7 +333,6 @@ class NPTracer(BaseTracer):
numpy.spacing,
numpy.sqrt,
numpy.square,
# numpy.subtract,
numpy.tan,
numpy.tanh,
numpy.true_divide,

View File

@@ -38,8 +38,10 @@ def simple_fuse_output(x):
return x.astype(numpy.float64).astype(numpy.int32)
def complex_fuse_indirect_input(function, x, y):
"""Complex fuse"""
def mix_x_and_y_intricately_and_call_f(function, x, y):
"""Mix x and y in an intricated way, that can't be simplified by
an optimizer eg, and then call function
"""
intermediate = x + y
intermediate = intermediate + 2
intermediate = intermediate.astype(numpy.float32)
@@ -57,8 +59,8 @@ def complex_fuse_indirect_input(function, x, y):
)
def complex_fuse_direct_input(function, x, y):
"""Complex fuse"""
def mix_x_and_y_and_call_f(function, x, y):
"""Mix x and y and then call function"""
x_p_1 = x + 0.1
x_p_2 = x + 0.2
x_p_3 = function(x_p_1 + x_p_2)
@@ -72,6 +74,21 @@ def complex_fuse_direct_input(function, x, y):
)
def mix_x_and_y_into_integer_and_call_f(function, x, y):
"""Mix x and y but keep the entry to function as an integer"""
x_p_1 = x + 1
x_p_2 = x + 2
x_p_3 = function(x_p_1 + x_p_2)
return (
x_p_3.astype(numpy.int32),
x_p_2.astype(numpy.int32),
(x_p_2 + 3).astype(numpy.int32),
x_p_3.astype(numpy.int32) + 67,
y,
(y + 4.7).astype(numpy.int32) + 3,
)
@pytest.mark.parametrize(
"function_to_trace,fused",
[
@@ -80,14 +97,14 @@ def complex_fuse_direct_input(function, x, y):
pytest.param(simple_fuse_not_output, True, id="no_fuse"),
pytest.param(simple_fuse_output, True, id="no_fuse"),
pytest.param(
lambda x, y: complex_fuse_indirect_input(numpy.rint, x, y),
lambda x, y: mix_x_and_y_intricately_and_call_f(numpy.rint, x, y),
True,
id="complex_fuse_indirect_input_with_rint",
id="mix_x_and_y_intricately_and_call_f_with_rint",
),
pytest.param(
lambda x, y: complex_fuse_direct_input(numpy.rint, x, y),
lambda x, y: mix_x_and_y_and_call_f(numpy.rint, x, y),
True,
id="complex_fuse_direct_input_with_rint",
id="mix_x_and_y_and_call_f_with_rint",
),
],
)
@@ -152,13 +169,16 @@ def subtest_fuse_float_unary_operations_correctness(fun):
# Some manipulation to avoid issues with domain of definitions of functions
if fun == numpy.arccosh:
input_list = [1, 2, 42, 44]
super_fun_list = [complex_fuse_direct_input]
super_fun_list = [mix_x_and_y_and_call_f]
elif fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]:
input_list = [0, 0.1, 0.2]
super_fun_list = [complex_fuse_direct_input]
super_fun_list = [mix_x_and_y_and_call_f]
elif fun == numpy.invert:
input_list = [1, 2, 42, 44]
super_fun_list = [mix_x_and_y_into_integer_and_call_f]
else:
input_list = [0, 2, 42, 44]
super_fun_list = [complex_fuse_direct_input, complex_fuse_indirect_input]
super_fun_list = [mix_x_and_y_and_call_f, mix_x_and_y_intricately_and_call_f]
for super_fun in super_fun_list:
@@ -194,6 +214,7 @@ LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = {
numpy.bitwise_or,
numpy.bitwise_xor,
numpy.gcd,
numpy.invert,
numpy.lcm,
numpy.ldexp,
numpy.left_shift,
@@ -220,6 +241,10 @@ def subtest_fuse_float_binary_operations_correctness(fun):
# a float output even for functions which return an integer (eg, XOR), such
# that our frontend always try to fuse them
# The .astype(numpy.float64) that we have in cases 1 and 3 is here to force
# a float output even for functions which return a bool (eg, EQUAL), such
# that our frontend always try to fuse them
# For bivariate functions: fix one of the inputs
if i == 0:
# With an integer in first position
@@ -229,7 +254,7 @@ def subtest_fuse_float_binary_operations_correctness(fun):
elif i == 1:
# With a float in first position
def get_function_to_trace():
return lambda x, y: fun(2.3, x + y).astype(numpy.int32)
return lambda x, y: fun(2.3, x + y).astype(numpy.float64).astype(numpy.int32)
elif i == 2:
# With an integer in second position
@@ -239,7 +264,7 @@ def subtest_fuse_float_binary_operations_correctness(fun):
else:
# With a float in second position
def get_function_to_trace():
return lambda x, y: fun(x + y, 5.7).astype(numpy.int32)
return lambda x, y: fun(x + y, 5.7).astype(numpy.float64).astype(numpy.int32)
input_list = [0, 2, 42, 44]

View File

@@ -62,6 +62,7 @@ LIST_OF_UFUNC_WHOSE_OUTPUT_IS_BOOL = set(
numpy.isinf,
numpy.isnan,
numpy.signbit,
numpy.logical_not,
]
)
@@ -406,15 +407,17 @@ def test_tracing_astype(
),
],
)
# numpy.logical_not is removed from the following test since it is expecting inputs which are
# integer only, as opposed to other functions in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC
@pytest.mark.parametrize(
"function_to_trace_def",
[f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1 if f != numpy.logical_not],
[f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1],
)
def test_trace_numpy_supported_unary_ufuncs(inputs, expected_output_node, function_to_trace_def):
"""Function to trace supported numpy ufuncs"""
# numpy.invert is expecting inputs which are integer only
if function_to_trace_def == numpy.invert and not isinstance(inputs["x"].dtype, Integer):
return
# We really need a lambda (because numpy functions are not playing
# nice with inspect.signature), but pylint and flake8 are not happy
# with it