feat: add more bivariate operations

which are supported if one of the operands is a constant
refs #126
This commit is contained in:
Benoit Chevallier-Mames
2021-10-06 18:02:35 +02:00
committed by Benoit Chevallier
parent ab1f0f3c4a
commit 2da3895f1a
3 changed files with 74 additions and 32 deletions

View File

@@ -189,16 +189,42 @@ def subtest_fuse_float_unary_operations_correctness(fun):
assert function_to_trace(*inputs) == op_graph(*inputs)
LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = {
numpy.bitwise_and,
numpy.bitwise_or,
numpy.bitwise_xor,
numpy.gcd,
numpy.lcm,
numpy.ldexp,
numpy.left_shift,
numpy.logical_and,
numpy.logical_not,
numpy.logical_or,
numpy.logical_xor,
numpy.remainder,
numpy.right_shift,
}
def subtest_fuse_float_binary_operations_correctness(fun):
"""Test a binary functions with fuse_float_operations, with a constant as a source."""
for i in range(4):
# Know if the function is defined for integer inputs
if fun in LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES:
if i not in [0, 2]:
continue
# The .astype(numpy.float64) that we have in cases 0 and 2 is here to force
# a float output even for functions which return an integer (eg, XOR), such
# that our frontend always try to fuse them
# For bivariate functions: fix one of the inputs
if i == 0:
# With an integer in first position
def get_function_to_trace():
return lambda x, y: fun(3, x + y).astype(numpy.int32)
return lambda x, y: fun(3, x + y).astype(numpy.float64).astype(numpy.int32)
elif i == 1:
# With a float in first position
@@ -208,7 +234,7 @@ def subtest_fuse_float_binary_operations_correctness(fun):
elif i == 2:
# With an integer in second position
def get_function_to_trace():
return lambda x, y: fun(x + y, 4).astype(numpy.int32)
return lambda x, y: fun(x + y, 4).astype(numpy.float64).astype(numpy.int32)
else:
# With a float in second position
@@ -217,6 +243,10 @@ def subtest_fuse_float_binary_operations_correctness(fun):
input_list = [0, 2, 42, 44]
# Domain of definition
if fun in [numpy.true_divide, numpy.remainder, numpy.floor_divide, numpy.fmod]:
input_list = [2, 42, 44]
for input_ in input_list:
function_to_trace = get_function_to_trace()