feat: remove support for np.invert

remove support for np.invert and propose to the user to use bitwise_xor instead, because of impossibilities with float fusing
closes #658
This commit is contained in:
Benoit Chevallier-Mames
2021-10-18 11:53:02 +02:00
committed by Benoit Chevallier
parent fb0564eea2
commit 7bf2f09615
5 changed files with 41 additions and 20 deletions

View File

@@ -110,6 +110,15 @@ class NPTracer(BaseTracer):
Callable: the tracing function that needs to be called to trace func
"""
tracing_func: Optional[Callable]
# numpy.invert is not great in term of types it supports, so we've decided not to support it
# and to propose to the user to use numpy.bitwise_not
if func == numpy.invert:
raise RuntimeError(
f"NPTracer does not manage the following func: {func.__name__}. Please replace by "
f"calls to bitwise_xor with appropriate mask"
)
if isinstance(func, numpy.ufunc):
tracing_func = NPTracer.UFUNC_ROUTING.get(func, None)
else:
@@ -266,6 +275,9 @@ class NPTracer(BaseTracer):
# numpy.isnat is not there since it is about timings
#
# numpy.divmod, numpy.modf and numpy.frexp are not there since output two values
#
# numpy.invert (as known as numpy.bitwise_not) is not here, because it has strange input type.
# We ask the user to replace bitwise_xor instead
LIST_OF_SUPPORTED_UFUNC: List[numpy.ufunc] = [
numpy.absolute,
numpy.arccos,
@@ -301,7 +313,6 @@ class NPTracer(BaseTracer):
numpy.greater_equal,
numpy.heaviside,
numpy.hypot,
numpy.invert,
numpy.isfinite,
numpy.isinf,
numpy.isnan,

View File

@@ -48,7 +48,6 @@ List of supported unary functions:
- expm1
- fabs
- floor
- invert
- isfinite
- isinf
- isnan

View File

@@ -224,10 +224,6 @@ def subtest_fuse_float_unary_operations_correctness(fun, tensor_shape):
# Not too large values to avoid overflows
input_list = [1, 2, 5, 11]
super_fun_list = [mix_x_and_y_and_call_f, mix_x_and_y_intricately_and_call_f]
elif fun == numpy.invert:
# 0 is not in the domain of definition + expect integer inputs
input_list = [1, 2, 42, 44]
super_fun_list = [mix_x_and_y_into_integer_and_call_f]
else:
# Regular case
input_list = [0, 2, 42, 44]
@@ -307,7 +303,6 @@ LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = {
numpy.bitwise_or,
numpy.bitwise_xor,
numpy.gcd,
numpy.invert,
numpy.lcm,
numpy.ldexp,
numpy.left_shift,

View File

@@ -393,15 +393,6 @@ def test_unary_ufunc_operations(ufunc, default_compilation_configuration):
((0, 5), (0, 5)),
default_compilation_configuration,
)
elif ufunc in [
numpy.invert,
]:
# Can't make it work, to have a fusable function
# TODO: fixme
pass
# subtest_compile_and_run_unary_ufunc_correctness(
# ufunc, mix_x_and_y_and_call_f_with_integer_inputs, ((0, 5), (0, 5))
# )
elif ufunc in [
numpy.arccosh,
numpy.log,

View File

@@ -381,6 +381,35 @@ def test_tracing_astype(
assert expected_output == evaluated_output
@pytest.mark.parametrize(
"inputs",
[
pytest.param(
{"x": EncryptedScalar(Integer(32, is_signed=True))},
),
],
)
@pytest.mark.parametrize(
"function_to_trace",
# We really need a lambda (because numpy functions are not playing
# nice with inspect.signature), but pylint is not happy
# with it
# pylint: disable=unnecessary-lambda
[lambda x: numpy.invert(x), lambda x: numpy.bitwise_not(x)],
# pylint: enable=unnecessary-lambda
)
def test_trace_numpy_fails_for_invert(inputs, function_to_trace):
"""Check we catch calls to numpy.invert and tell user to change their code"""
with pytest.raises(RuntimeError) as excinfo:
tracing.trace_numpy_function(function_to_trace, inputs)
assert (
"NPTracer does not manage the following func: invert. Please replace by calls to "
"bitwise_xor with appropriate mask" in str(excinfo.value)
)
@pytest.mark.parametrize(
"inputs,expected_output_node",
[
@@ -414,10 +443,6 @@ def test_tracing_astype(
def test_trace_numpy_supported_unary_ufuncs(inputs, expected_output_node, function_to_trace_def):
"""Function to trace supported numpy ufuncs"""
# numpy.invert is expecting inputs which are integer only
if function_to_trace_def == numpy.invert and not isinstance(inputs["x"].dtype, Integer):
return
# We really need a lambda (because numpy functions are not playing
# nice with inspect.signature), but pylint and flake8 are not happy
# with it