feat: add more bivariate operations

which are supported if one of the operands is a constant
refs #126
This commit is contained in:
Benoit Chevallier-Mames
2021-10-06 18:02:35 +02:00
committed by Benoit Chevallier
parent ab1f0f3c4a
commit 2da3895f1a
3 changed files with 74 additions and 32 deletions

View File

@@ -259,13 +259,13 @@ class NPTracer(BaseTracer):
numpy.arctan,
numpy.arctan2,
numpy.arctanh,
# numpy.bitwise_and,
# numpy.bitwise_or,
# numpy.bitwise_xor,
numpy.bitwise_and,
numpy.bitwise_or,
numpy.bitwise_xor,
numpy.cbrt,
numpy.ceil,
# numpy.conjugate,
# numpy.copysign,
numpy.copysign,
numpy.cos,
numpy.cosh,
numpy.deg2rad,
@@ -278,51 +278,51 @@ class NPTracer(BaseTracer):
numpy.fabs,
numpy.float_power,
numpy.floor,
# numpy.floor_divide,
# numpy.fmax,
# numpy.fmin,
# numpy.fmod,
numpy.floor_divide,
numpy.fmax,
numpy.fmin,
numpy.fmod,
# numpy.frexp,
# numpy.gcd,
numpy.gcd,
# numpy.greater,
# numpy.greater_equal,
# numpy.heaviside,
# numpy.hypot,
numpy.heaviside,
numpy.hypot,
# numpy.invert,
numpy.isfinite,
numpy.isinf,
numpy.isnan,
# numpy.isnat,
# numpy.lcm,
# numpy.ldexp,
# numpy.left_shift,
numpy.lcm,
numpy.ldexp,
numpy.left_shift,
# numpy.less,
# numpy.less_equal,
numpy.log,
numpy.log10,
numpy.log1p,
numpy.log2,
# numpy.logaddexp,
# numpy.logaddexp2,
# numpy.logical_and,
# numpy.logical_not,
# numpy.logical_or,
# numpy.logical_xor,
numpy.logaddexp,
numpy.logaddexp2,
numpy.logical_and,
numpy.logical_not,
numpy.logical_or,
numpy.logical_xor,
# numpy.matmul,
# numpy.maximum,
# numpy.minimum,
numpy.maximum,
numpy.minimum,
# numpy.modf,
# numpy.multiply,
numpy.negative,
# numpy.nextafter,
numpy.nextafter,
# numpy.not_equal,
numpy.positive,
# numpy.power,
numpy.power,
numpy.rad2deg,
numpy.radians,
numpy.reciprocal,
# numpy.remainder,
# numpy.right_shift,
numpy.remainder,
numpy.right_shift,
numpy.rint,
numpy.sign,
numpy.signbit,
@@ -334,7 +334,7 @@ class NPTracer(BaseTracer):
# numpy.subtract,
numpy.tan,
numpy.tanh,
# numpy.true_divide,
numpy.true_divide,
numpy.trunc,
]
@@ -378,9 +378,18 @@ NPTracer.UFUNC_ROUTING = {
fun: _get_unary_fun(fun) for fun in NPTracer.LIST_OF_SUPPORTED_UFUNC if fun.nin == 1
}
NPTracer.UFUNC_ROUTING[numpy.arctan2] = _get_binary_fun(numpy.arctan2)
NPTracer.UFUNC_ROUTING[numpy.float_power] = _get_binary_fun(numpy.float_power)
NPTracer.UFUNC_ROUTING.update(
{fun: _get_binary_fun(fun) for fun in NPTracer.LIST_OF_SUPPORTED_UFUNC if fun.nin == 2}
)
list_of_not_supported = [
(ufunc.__name__, ufunc.nin)
for ufunc in NPTracer.LIST_OF_SUPPORTED_UFUNC
if ufunc.nin not in [1, 2]
]
custom_assert(len(list_of_not_supported) == 0, f"Not supported nin's, {list_of_not_supported}")
del list_of_not_supported
# We are adding initial support for `np.array(...)` +,-,* `BaseTracer`
# (note that this is not the proper complete handling of these functions)

View File

@@ -189,16 +189,42 @@ def subtest_fuse_float_unary_operations_correctness(fun):
assert function_to_trace(*inputs) == op_graph(*inputs)
LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = {
numpy.bitwise_and,
numpy.bitwise_or,
numpy.bitwise_xor,
numpy.gcd,
numpy.lcm,
numpy.ldexp,
numpy.left_shift,
numpy.logical_and,
numpy.logical_not,
numpy.logical_or,
numpy.logical_xor,
numpy.remainder,
numpy.right_shift,
}
def subtest_fuse_float_binary_operations_correctness(fun):
"""Test a binary functions with fuse_float_operations, with a constant as a source."""
for i in range(4):
# Know if the function is defined for integer inputs
if fun in LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES:
if i not in [0, 2]:
continue
# The .astype(numpy.float64) that we have in cases 0 and 2 is here to force
# a float output even for functions which return an integer (eg, XOR), such
# that our frontend always try to fuse them
# For bivariate functions: fix one of the inputs
if i == 0:
# With an integer in first position
def get_function_to_trace():
return lambda x, y: fun(3, x + y).astype(numpy.int32)
return lambda x, y: fun(3, x + y).astype(numpy.float64).astype(numpy.int32)
elif i == 1:
# With a float in first position
@@ -208,7 +234,7 @@ def subtest_fuse_float_binary_operations_correctness(fun):
elif i == 2:
# With an integer in second position
def get_function_to_trace():
return lambda x, y: fun(x + y, 4).astype(numpy.int32)
return lambda x, y: fun(x + y, 4).astype(numpy.float64).astype(numpy.int32)
else:
# With a float in second position
@@ -217,6 +243,10 @@ def subtest_fuse_float_binary_operations_correctness(fun):
input_list = [0, 2, 42, 44]
# Domain of definition
if fun in [numpy.true_divide, numpy.remainder, numpy.floor_divide, numpy.fmod]:
input_list = [2, 42, 44]
for input_ in input_list:
function_to_trace = get_function_to_trace()

View File

@@ -406,8 +406,11 @@ def test_tracing_astype(
),
],
)
# numpy.logical_not is removed from the following test since it is expecting inputs which are
# integer only, as opposed to other functions in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC
@pytest.mark.parametrize(
"function_to_trace_def", [f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1]
"function_to_trace_def",
[f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1 if f != numpy.logical_not],
)
def test_trace_numpy_supported_unary_ufuncs(inputs, expected_output_node, function_to_trace_def):
"""Function to trace supported numpy ufuncs"""