feat: remove support for np.invert

remove support for np.invert and propose to the user to use bitwise_xor instead, because of impossibilities with float fusing
closes #658
This commit is contained in:
Benoit Chevallier-Mames
2021-10-18 11:53:02 +02:00
committed by Benoit Chevallier
parent fb0564eea2
commit 7bf2f09615
5 changed files with 41 additions and 20 deletions

View File

@@ -393,15 +393,6 @@ def test_unary_ufunc_operations(ufunc, default_compilation_configuration):
((0, 5), (0, 5)),
default_compilation_configuration,
)
elif ufunc in [
numpy.invert,
]:
# Can't make it work, to have a fusable function
# TODO: fixme
pass
# subtest_compile_and_run_unary_ufunc_correctness(
# ufunc, mix_x_and_y_and_call_f_with_integer_inputs, ((0, 5), (0, 5))
# )
elif ufunc in [
numpy.arccosh,
numpy.log,

View File

@@ -381,6 +381,35 @@ def test_tracing_astype(
assert expected_output == evaluated_output
@pytest.mark.parametrize(
"inputs",
[
pytest.param(
{"x": EncryptedScalar(Integer(32, is_signed=True))},
),
],
)
@pytest.mark.parametrize(
"function_to_trace",
# We really need a lambda (because numpy functions are not playing
# nice with inspect.signature), but pylint is not happy
# with it
# pylint: disable=unnecessary-lambda
[lambda x: numpy.invert(x), lambda x: numpy.bitwise_not(x)],
# pylint: enable=unnecessary-lambda
)
def test_trace_numpy_fails_for_invert(inputs, function_to_trace):
"""Check we catch calls to numpy.invert and tell user to change their code"""
with pytest.raises(RuntimeError) as excinfo:
tracing.trace_numpy_function(function_to_trace, inputs)
assert (
"NPTracer does not manage the following func: invert. Please replace by calls to "
"bitwise_xor with appropriate mask" in str(excinfo.value)
)
@pytest.mark.parametrize(
"inputs,expected_output_node",
[
@@ -414,10 +443,6 @@ def test_tracing_astype(
def test_trace_numpy_supported_unary_ufuncs(inputs, expected_output_node, function_to_trace_def):
"""Function to trace supported numpy ufuncs"""
# numpy.invert is expecting inputs which are integer only
if function_to_trace_def == numpy.invert and not isinstance(inputs["x"].dtype, Integer):
return
# We really need a lambda (because numpy functions are not playing
# nice with inspect.signature), but pylint and flake8 are not happy
# with it