mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: add management of boolean binary operators with a const scalar
refs #126 refs #529
This commit is contained in:
committed by
Benoit Chevallier
parent
f443b41cef
commit
e8114cc470
@@ -249,9 +249,17 @@ class NPTracer(BaseTracer):
|
||||
|
||||
# Supported functions are either univariate or bivariate for which one of the two
|
||||
# sources is a constant
|
||||
#
|
||||
# numpy.add, numpy.multiply and numpy.subtract are not there since already managed
|
||||
# by leveled operations
|
||||
#
|
||||
# numpy.conjugate is not there since working on complex numbers
|
||||
#
|
||||
# numpy.isnat is not there since it is about timings
|
||||
#
|
||||
# numpy.divmod, numpy.modf and numpy.frexp are not there since output two values
|
||||
LIST_OF_SUPPORTED_UFUNC: List[numpy.ufunc] = [
|
||||
numpy.absolute,
|
||||
# numpy.add,
|
||||
numpy.arccos,
|
||||
numpy.arccosh,
|
||||
numpy.arcsin,
|
||||
@@ -264,14 +272,12 @@ class NPTracer(BaseTracer):
|
||||
numpy.bitwise_xor,
|
||||
numpy.cbrt,
|
||||
numpy.ceil,
|
||||
# numpy.conjugate,
|
||||
numpy.copysign,
|
||||
numpy.cos,
|
||||
numpy.cosh,
|
||||
numpy.deg2rad,
|
||||
numpy.degrees,
|
||||
# numpy.divmod,
|
||||
# numpy.equal,
|
||||
numpy.equal,
|
||||
numpy.exp,
|
||||
numpy.exp2,
|
||||
numpy.expm1,
|
||||
@@ -282,22 +288,20 @@ class NPTracer(BaseTracer):
|
||||
numpy.fmax,
|
||||
numpy.fmin,
|
||||
numpy.fmod,
|
||||
# numpy.frexp,
|
||||
numpy.gcd,
|
||||
# numpy.greater,
|
||||
# numpy.greater_equal,
|
||||
numpy.greater,
|
||||
numpy.greater_equal,
|
||||
numpy.heaviside,
|
||||
numpy.hypot,
|
||||
# numpy.invert,
|
||||
numpy.invert,
|
||||
numpy.isfinite,
|
||||
numpy.isinf,
|
||||
numpy.isnan,
|
||||
# numpy.isnat,
|
||||
numpy.lcm,
|
||||
numpy.ldexp,
|
||||
numpy.left_shift,
|
||||
# numpy.less,
|
||||
# numpy.less_equal,
|
||||
numpy.less,
|
||||
numpy.less_equal,
|
||||
numpy.log,
|
||||
numpy.log10,
|
||||
numpy.log1p,
|
||||
@@ -311,11 +315,9 @@ class NPTracer(BaseTracer):
|
||||
# numpy.matmul,
|
||||
numpy.maximum,
|
||||
numpy.minimum,
|
||||
# numpy.modf,
|
||||
# numpy.multiply,
|
||||
numpy.negative,
|
||||
numpy.nextafter,
|
||||
# numpy.not_equal,
|
||||
numpy.not_equal,
|
||||
numpy.positive,
|
||||
numpy.power,
|
||||
numpy.rad2deg,
|
||||
@@ -331,7 +333,6 @@ class NPTracer(BaseTracer):
|
||||
numpy.spacing,
|
||||
numpy.sqrt,
|
||||
numpy.square,
|
||||
# numpy.subtract,
|
||||
numpy.tan,
|
||||
numpy.tanh,
|
||||
numpy.true_divide,
|
||||
|
||||
@@ -38,8 +38,10 @@ def simple_fuse_output(x):
|
||||
return x.astype(numpy.float64).astype(numpy.int32)
|
||||
|
||||
|
||||
def complex_fuse_indirect_input(function, x, y):
|
||||
"""Complex fuse"""
|
||||
def mix_x_and_y_intricately_and_call_f(function, x, y):
|
||||
"""Mix x and y in an intricated way, that can't be simplified by
|
||||
an optimizer eg, and then call function
|
||||
"""
|
||||
intermediate = x + y
|
||||
intermediate = intermediate + 2
|
||||
intermediate = intermediate.astype(numpy.float32)
|
||||
@@ -57,8 +59,8 @@ def complex_fuse_indirect_input(function, x, y):
|
||||
)
|
||||
|
||||
|
||||
def complex_fuse_direct_input(function, x, y):
|
||||
"""Complex fuse"""
|
||||
def mix_x_and_y_and_call_f(function, x, y):
|
||||
"""Mix x and y and then call function"""
|
||||
x_p_1 = x + 0.1
|
||||
x_p_2 = x + 0.2
|
||||
x_p_3 = function(x_p_1 + x_p_2)
|
||||
@@ -72,6 +74,21 @@ def complex_fuse_direct_input(function, x, y):
|
||||
)
|
||||
|
||||
|
||||
def mix_x_and_y_into_integer_and_call_f(function, x, y):
|
||||
"""Mix x and y but keep the entry to function as an integer"""
|
||||
x_p_1 = x + 1
|
||||
x_p_2 = x + 2
|
||||
x_p_3 = function(x_p_1 + x_p_2)
|
||||
return (
|
||||
x_p_3.astype(numpy.int32),
|
||||
x_p_2.astype(numpy.int32),
|
||||
(x_p_2 + 3).astype(numpy.int32),
|
||||
x_p_3.astype(numpy.int32) + 67,
|
||||
y,
|
||||
(y + 4.7).astype(numpy.int32) + 3,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,fused",
|
||||
[
|
||||
@@ -80,14 +97,14 @@ def complex_fuse_direct_input(function, x, y):
|
||||
pytest.param(simple_fuse_not_output, True, id="no_fuse"),
|
||||
pytest.param(simple_fuse_output, True, id="no_fuse"),
|
||||
pytest.param(
|
||||
lambda x, y: complex_fuse_indirect_input(numpy.rint, x, y),
|
||||
lambda x, y: mix_x_and_y_intricately_and_call_f(numpy.rint, x, y),
|
||||
True,
|
||||
id="complex_fuse_indirect_input_with_rint",
|
||||
id="mix_x_and_y_intricately_and_call_f_with_rint",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: complex_fuse_direct_input(numpy.rint, x, y),
|
||||
lambda x, y: mix_x_and_y_and_call_f(numpy.rint, x, y),
|
||||
True,
|
||||
id="complex_fuse_direct_input_with_rint",
|
||||
id="mix_x_and_y_and_call_f_with_rint",
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -152,13 +169,16 @@ def subtest_fuse_float_unary_operations_correctness(fun):
|
||||
# Some manipulation to avoid issues with domain of definitions of functions
|
||||
if fun == numpy.arccosh:
|
||||
input_list = [1, 2, 42, 44]
|
||||
super_fun_list = [complex_fuse_direct_input]
|
||||
super_fun_list = [mix_x_and_y_and_call_f]
|
||||
elif fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]:
|
||||
input_list = [0, 0.1, 0.2]
|
||||
super_fun_list = [complex_fuse_direct_input]
|
||||
super_fun_list = [mix_x_and_y_and_call_f]
|
||||
elif fun == numpy.invert:
|
||||
input_list = [1, 2, 42, 44]
|
||||
super_fun_list = [mix_x_and_y_into_integer_and_call_f]
|
||||
else:
|
||||
input_list = [0, 2, 42, 44]
|
||||
super_fun_list = [complex_fuse_direct_input, complex_fuse_indirect_input]
|
||||
super_fun_list = [mix_x_and_y_and_call_f, mix_x_and_y_intricately_and_call_f]
|
||||
|
||||
for super_fun in super_fun_list:
|
||||
|
||||
@@ -194,6 +214,7 @@ LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = {
|
||||
numpy.bitwise_or,
|
||||
numpy.bitwise_xor,
|
||||
numpy.gcd,
|
||||
numpy.invert,
|
||||
numpy.lcm,
|
||||
numpy.ldexp,
|
||||
numpy.left_shift,
|
||||
@@ -220,6 +241,10 @@ def subtest_fuse_float_binary_operations_correctness(fun):
|
||||
# a float output even for functions which return an integer (eg, XOR), such
|
||||
# that our frontend always try to fuse them
|
||||
|
||||
# The .astype(numpy.float64) that we have in cases 1 and 3 is here to force
|
||||
# a float output even for functions which return a bool (eg, EQUAL), such
|
||||
# that our frontend always try to fuse them
|
||||
|
||||
# For bivariate functions: fix one of the inputs
|
||||
if i == 0:
|
||||
# With an integer in first position
|
||||
@@ -229,7 +254,7 @@ def subtest_fuse_float_binary_operations_correctness(fun):
|
||||
elif i == 1:
|
||||
# With a float in first position
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(2.3, x + y).astype(numpy.int32)
|
||||
return lambda x, y: fun(2.3, x + y).astype(numpy.float64).astype(numpy.int32)
|
||||
|
||||
elif i == 2:
|
||||
# With an integer in second position
|
||||
@@ -239,7 +264,7 @@ def subtest_fuse_float_binary_operations_correctness(fun):
|
||||
else:
|
||||
# With a float in second position
|
||||
def get_function_to_trace():
|
||||
return lambda x, y: fun(x + y, 5.7).astype(numpy.int32)
|
||||
return lambda x, y: fun(x + y, 5.7).astype(numpy.float64).astype(numpy.int32)
|
||||
|
||||
input_list = [0, 2, 42, 44]
|
||||
|
||||
|
||||
@@ -62,6 +62,7 @@ LIST_OF_UFUNC_WHOSE_OUTPUT_IS_BOOL = set(
|
||||
numpy.isinf,
|
||||
numpy.isnan,
|
||||
numpy.signbit,
|
||||
numpy.logical_not,
|
||||
]
|
||||
)
|
||||
|
||||
@@ -406,15 +407,17 @@ def test_tracing_astype(
|
||||
),
|
||||
],
|
||||
)
|
||||
# numpy.logical_not is removed from the following test since it is expecting inputs which are
|
||||
# integer only, as opposed to other functions in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace_def",
|
||||
[f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1 if f != numpy.logical_not],
|
||||
[f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1],
|
||||
)
|
||||
def test_trace_numpy_supported_unary_ufuncs(inputs, expected_output_node, function_to_trace_def):
|
||||
"""Function to trace supported numpy ufuncs"""
|
||||
|
||||
# numpy.invert is expecting inputs which are integer only
|
||||
if function_to_trace_def == numpy.invert and not isinstance(inputs["x"].dtype, Integer):
|
||||
return
|
||||
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint and flake8 are not happy
|
||||
# with it
|
||||
|
||||
Reference in New Issue
Block a user