mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: add more bivariate operations
which are supported if one of the operands is a constant refs #126
This commit is contained in:
committed by
Benoit Chevallier
parent
ab1f0f3c4a
commit
2da3895f1a
@@ -189,16 +189,42 @@ def subtest_fuse_float_unary_operations_correctness(fun):
|
||||
assert function_to_trace(*inputs) == op_graph(*inputs)
|
||||
|
||||
|
||||
LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = {
|
||||
numpy.bitwise_and,
|
||||
numpy.bitwise_or,
|
||||
numpy.bitwise_xor,
|
||||
numpy.gcd,
|
||||
numpy.lcm,
|
||||
numpy.ldexp,
|
||||
numpy.left_shift,
|
||||
numpy.logical_and,
|
||||
numpy.logical_not,
|
||||
numpy.logical_or,
|
||||
numpy.logical_xor,
|
||||
numpy.remainder,
|
||||
numpy.right_shift,
|
||||
}
|
||||
|
||||
|
||||
def subtest_fuse_float_binary_operations_correctness(fun):
|
||||
"""Test a binary functions with fuse_float_operations, with a constant as a source."""
|
||||
|
||||
for i in range(4):
|
||||
|
||||
# Know if the function is defined for integer inputs
|
||||
if fun in LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES:
|
||||
if i not in [0, 2]:
|
||||
continue
|
||||
|
||||
# The .astype(numpy.float64) that we have in cases 0 and 2 is here to force
|
||||
# a float output even for functions which return an integer (eg, XOR), such
|
||||
# that our frontend always try to fuse them
|
||||
|
||||
# For bivariate functions: fix one of the inputs
|
||||
if i == 0:
|
||||
# With an integer in first position
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(3, x + y).astype(numpy.int32)
|
||||
return lambda x, y: fun(3, x + y).astype(numpy.float64).astype(numpy.int32)
|
||||
|
||||
elif i == 1:
|
||||
# With a float in first position
|
||||
@@ -208,7 +234,7 @@ def subtest_fuse_float_binary_operations_correctness(fun):
|
||||
elif i == 2:
|
||||
# With an integer in second position
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(x + y, 4).astype(numpy.int32)
|
||||
return lambda x, y: fun(x + y, 4).astype(numpy.float64).astype(numpy.int32)
|
||||
|
||||
else:
|
||||
# With a float in second position
|
||||
@@ -217,6 +243,10 @@ def subtest_fuse_float_binary_operations_correctness(fun):
|
||||
|
||||
input_list = [0, 2, 42, 44]
|
||||
|
||||
# Domain of definition
|
||||
if fun in [numpy.true_divide, numpy.remainder, numpy.floor_divide, numpy.fmod]:
|
||||
input_list = [2, 42, 44]
|
||||
|
||||
for input_ in input_list:
|
||||
|
||||
function_to_trace = get_function_to_trace()
|
||||
|
||||
Reference in New Issue
Block a user