From 7bf2f096155062ac7920041a802c507a48da481e Mon Sep 17 00:00:00 2001 From: Benoit Chevallier-Mames Date: Mon, 18 Oct 2021 11:53:02 +0200 Subject: [PATCH] feat: remove support for np.invert remove support for np.invert and propose to the user to use bitwise_xor instead, because of impossibilities with float fusing closes #658 --- concrete/numpy/tracing.py | 13 +++++++- .../tutorial/WORKING_WITH_FLOATING_POINTS.md | 1 - .../common/optimization/test_float_fusing.py | 5 --- tests/numpy/test_compile.py | 9 ----- tests/numpy/test_tracing.py | 33 ++++++++++++++++--- 5 files changed, 41 insertions(+), 20 deletions(-) diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index 00e1edfa6..f8316dc89 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -110,6 +110,15 @@ class NPTracer(BaseTracer): Callable: the tracing function that needs to be called to trace func """ tracing_func: Optional[Callable] + + # numpy.invert is not great in term of types it supports, so we've decided not to support it + # and to propose to the user to use numpy.bitwise_not + if func == numpy.invert: + raise RuntimeError( + f"NPTracer does not manage the following func: {func.__name__}. Please replace by " + f"calls to bitwise_xor with appropriate mask" + ) + if isinstance(func, numpy.ufunc): tracing_func = NPTracer.UFUNC_ROUTING.get(func, None) else: @@ -266,6 +275,9 @@ class NPTracer(BaseTracer): # numpy.isnat is not there since it is about timings # # numpy.divmod, numpy.modf and numpy.frexp are not there since output two values + # + # numpy.invert (as known as numpy.bitwise_not) is not here, because it has strange input type. + # We ask the user to replace bitwise_xor instead LIST_OF_SUPPORTED_UFUNC: List[numpy.ufunc] = [ numpy.absolute, numpy.arccos, @@ -301,7 +313,6 @@ class NPTracer(BaseTracer): numpy.greater_equal, numpy.heaviside, numpy.hypot, - numpy.invert, numpy.isfinite, numpy.isinf, numpy.isnan, diff --git a/docs/user/tutorial/WORKING_WITH_FLOATING_POINTS.md b/docs/user/tutorial/WORKING_WITH_FLOATING_POINTS.md index afaaeb4d4..32b54d454 100644 --- a/docs/user/tutorial/WORKING_WITH_FLOATING_POINTS.md +++ b/docs/user/tutorial/WORKING_WITH_FLOATING_POINTS.md @@ -48,7 +48,6 @@ List of supported unary functions: - expm1 - fabs - floor -- invert - isfinite - isinf - isnan diff --git a/tests/common/optimization/test_float_fusing.py b/tests/common/optimization/test_float_fusing.py index 0b715eef6..e9f13c3eb 100644 --- a/tests/common/optimization/test_float_fusing.py +++ b/tests/common/optimization/test_float_fusing.py @@ -224,10 +224,6 @@ def subtest_fuse_float_unary_operations_correctness(fun, tensor_shape): # Not too large values to avoid overflows input_list = [1, 2, 5, 11] super_fun_list = [mix_x_and_y_and_call_f, mix_x_and_y_intricately_and_call_f] - elif fun == numpy.invert: - # 0 is not in the domain of definition + expect integer inputs - input_list = [1, 2, 42, 44] - super_fun_list = [mix_x_and_y_into_integer_and_call_f] else: # Regular case input_list = [0, 2, 42, 44] @@ -307,7 +303,6 @@ LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = { numpy.bitwise_or, numpy.bitwise_xor, numpy.gcd, - numpy.invert, numpy.lcm, numpy.ldexp, numpy.left_shift, diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 8bc025f86..51d96f64a 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -393,15 +393,6 @@ def test_unary_ufunc_operations(ufunc, default_compilation_configuration): ((0, 5), (0, 5)), default_compilation_configuration, ) - elif ufunc in [ - numpy.invert, - ]: - # Can't make it work, to have a fusable function - # TODO: fixme - pass - # subtest_compile_and_run_unary_ufunc_correctness( - # ufunc, mix_x_and_y_and_call_f_with_integer_inputs, ((0, 5), (0, 5)) - # ) elif ufunc in [ numpy.arccosh, numpy.log, diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index 7c953a44c..3006f5d89 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -381,6 +381,35 @@ def test_tracing_astype( assert expected_output == evaluated_output +@pytest.mark.parametrize( + "inputs", + [ + pytest.param( + {"x": EncryptedScalar(Integer(32, is_signed=True))}, + ), + ], +) +@pytest.mark.parametrize( + "function_to_trace", + # We really need a lambda (because numpy functions are not playing + # nice with inspect.signature), but pylint is not happy + # with it + # pylint: disable=unnecessary-lambda + [lambda x: numpy.invert(x), lambda x: numpy.bitwise_not(x)], + # pylint: enable=unnecessary-lambda +) +def test_trace_numpy_fails_for_invert(inputs, function_to_trace): + """Check we catch calls to numpy.invert and tell user to change their code""" + + with pytest.raises(RuntimeError) as excinfo: + tracing.trace_numpy_function(function_to_trace, inputs) + + assert ( + "NPTracer does not manage the following func: invert. Please replace by calls to " + "bitwise_xor with appropriate mask" in str(excinfo.value) + ) + + @pytest.mark.parametrize( "inputs,expected_output_node", [ @@ -414,10 +443,6 @@ def test_tracing_astype( def test_trace_numpy_supported_unary_ufuncs(inputs, expected_output_node, function_to_trace_def): """Function to trace supported numpy ufuncs""" - # numpy.invert is expecting inputs which are integer only - if function_to_trace_def == numpy.invert and not isinstance(inputs["x"].dtype, Integer): - return - # We really need a lambda (because numpy functions are not playing # nice with inspect.signature), but pylint and flake8 are not happy # with it