mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat: add management of boolean binary operators with a const scalar
refs #126 refs #529
This commit is contained in:
committed by
Benoit Chevallier
parent
f443b41cef
commit
e8114cc470
@@ -38,8 +38,10 @@ def simple_fuse_output(x):
|
||||
return x.astype(numpy.float64).astype(numpy.int32)
|
||||
|
||||
|
||||
def complex_fuse_indirect_input(function, x, y):
|
||||
"""Complex fuse"""
|
||||
def mix_x_and_y_intricately_and_call_f(function, x, y):
|
||||
"""Mix x and y in an intricated way, that can't be simplified by
|
||||
an optimizer eg, and then call function
|
||||
"""
|
||||
intermediate = x + y
|
||||
intermediate = intermediate + 2
|
||||
intermediate = intermediate.astype(numpy.float32)
|
||||
@@ -57,8 +59,8 @@ def complex_fuse_indirect_input(function, x, y):
|
||||
)
|
||||
|
||||
|
||||
def complex_fuse_direct_input(function, x, y):
|
||||
"""Complex fuse"""
|
||||
def mix_x_and_y_and_call_f(function, x, y):
|
||||
"""Mix x and y and then call function"""
|
||||
x_p_1 = x + 0.1
|
||||
x_p_2 = x + 0.2
|
||||
x_p_3 = function(x_p_1 + x_p_2)
|
||||
@@ -72,6 +74,21 @@ def complex_fuse_direct_input(function, x, y):
|
||||
)
|
||||
|
||||
|
||||
def mix_x_and_y_into_integer_and_call_f(function, x, y):
|
||||
"""Mix x and y but keep the entry to function as an integer"""
|
||||
x_p_1 = x + 1
|
||||
x_p_2 = x + 2
|
||||
x_p_3 = function(x_p_1 + x_p_2)
|
||||
return (
|
||||
x_p_3.astype(numpy.int32),
|
||||
x_p_2.astype(numpy.int32),
|
||||
(x_p_2 + 3).astype(numpy.int32),
|
||||
x_p_3.astype(numpy.int32) + 67,
|
||||
y,
|
||||
(y + 4.7).astype(numpy.int32) + 3,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,fused",
|
||||
[
|
||||
@@ -80,14 +97,14 @@ def complex_fuse_direct_input(function, x, y):
|
||||
pytest.param(simple_fuse_not_output, True, id="no_fuse"),
|
||||
pytest.param(simple_fuse_output, True, id="no_fuse"),
|
||||
pytest.param(
|
||||
lambda x, y: complex_fuse_indirect_input(numpy.rint, x, y),
|
||||
lambda x, y: mix_x_and_y_intricately_and_call_f(numpy.rint, x, y),
|
||||
True,
|
||||
id="complex_fuse_indirect_input_with_rint",
|
||||
id="mix_x_and_y_intricately_and_call_f_with_rint",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: complex_fuse_direct_input(numpy.rint, x, y),
|
||||
lambda x, y: mix_x_and_y_and_call_f(numpy.rint, x, y),
|
||||
True,
|
||||
id="complex_fuse_direct_input_with_rint",
|
||||
id="mix_x_and_y_and_call_f_with_rint",
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -152,13 +169,16 @@ def subtest_fuse_float_unary_operations_correctness(fun):
|
||||
# Some manipulation to avoid issues with domain of definitions of functions
|
||||
if fun == numpy.arccosh:
|
||||
input_list = [1, 2, 42, 44]
|
||||
super_fun_list = [complex_fuse_direct_input]
|
||||
super_fun_list = [mix_x_and_y_and_call_f]
|
||||
elif fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]:
|
||||
input_list = [0, 0.1, 0.2]
|
||||
super_fun_list = [complex_fuse_direct_input]
|
||||
super_fun_list = [mix_x_and_y_and_call_f]
|
||||
elif fun == numpy.invert:
|
||||
input_list = [1, 2, 42, 44]
|
||||
super_fun_list = [mix_x_and_y_into_integer_and_call_f]
|
||||
else:
|
||||
input_list = [0, 2, 42, 44]
|
||||
super_fun_list = [complex_fuse_direct_input, complex_fuse_indirect_input]
|
||||
super_fun_list = [mix_x_and_y_and_call_f, mix_x_and_y_intricately_and_call_f]
|
||||
|
||||
for super_fun in super_fun_list:
|
||||
|
||||
@@ -194,6 +214,7 @@ LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = {
|
||||
numpy.bitwise_or,
|
||||
numpy.bitwise_xor,
|
||||
numpy.gcd,
|
||||
numpy.invert,
|
||||
numpy.lcm,
|
||||
numpy.ldexp,
|
||||
numpy.left_shift,
|
||||
@@ -220,6 +241,10 @@ def subtest_fuse_float_binary_operations_correctness(fun):
|
||||
# a float output even for functions which return an integer (eg, XOR), such
|
||||
# that our frontend always try to fuse them
|
||||
|
||||
# The .astype(numpy.float64) that we have in cases 1 and 3 is here to force
|
||||
# a float output even for functions which return a bool (eg, EQUAL), such
|
||||
# that our frontend always try to fuse them
|
||||
|
||||
# For bivariate functions: fix one of the inputs
|
||||
if i == 0:
|
||||
# With an integer in first position
|
||||
@@ -229,7 +254,7 @@ def subtest_fuse_float_binary_operations_correctness(fun):
|
||||
elif i == 1:
|
||||
# With a float in first position
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(2.3, x + y).astype(numpy.int32)
|
||||
return lambda x, y: fun(2.3, x + y).astype(numpy.float64).astype(numpy.int32)
|
||||
|
||||
elif i == 2:
|
||||
# With an integer in second position
|
||||
@@ -239,7 +264,7 @@ def subtest_fuse_float_binary_operations_correctness(fun):
|
||||
else:
|
||||
# With a float in second position
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(x + y, 5.7).astype(numpy.int32)
|
||||
return lambda x, y: fun(x + y, 5.7).astype(numpy.float64).astype(numpy.int32)
|
||||
|
||||
input_list = [0, 2, 42, 44]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user