mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: remove support for np.invert
remove support for np.invert and propose to the user to use bitwise_xor instead, because of impossibilities with float fusing closes #658
This commit is contained in:
committed by
Benoit Chevallier
parent
fb0564eea2
commit
7bf2f09615
@@ -110,6 +110,15 @@ class NPTracer(BaseTracer):
|
||||
Callable: the tracing function that needs to be called to trace func
|
||||
"""
|
||||
tracing_func: Optional[Callable]
|
||||
|
||||
# numpy.invert is not great in term of types it supports, so we've decided not to support it
|
||||
# and to propose to the user to use numpy.bitwise_not
|
||||
if func == numpy.invert:
|
||||
raise RuntimeError(
|
||||
f"NPTracer does not manage the following func: {func.__name__}. Please replace by "
|
||||
f"calls to bitwise_xor with appropriate mask"
|
||||
)
|
||||
|
||||
if isinstance(func, numpy.ufunc):
|
||||
tracing_func = NPTracer.UFUNC_ROUTING.get(func, None)
|
||||
else:
|
||||
@@ -266,6 +275,9 @@ class NPTracer(BaseTracer):
|
||||
# numpy.isnat is not there since it is about timings
|
||||
#
|
||||
# numpy.divmod, numpy.modf and numpy.frexp are not there since output two values
|
||||
#
|
||||
# numpy.invert (as known as numpy.bitwise_not) is not here, because it has strange input type.
|
||||
# We ask the user to replace bitwise_xor instead
|
||||
LIST_OF_SUPPORTED_UFUNC: List[numpy.ufunc] = [
|
||||
numpy.absolute,
|
||||
numpy.arccos,
|
||||
@@ -301,7 +313,6 @@ class NPTracer(BaseTracer):
|
||||
numpy.greater_equal,
|
||||
numpy.heaviside,
|
||||
numpy.hypot,
|
||||
numpy.invert,
|
||||
numpy.isfinite,
|
||||
numpy.isinf,
|
||||
numpy.isnan,
|
||||
|
||||
@@ -48,7 +48,6 @@ List of supported unary functions:
|
||||
- expm1
|
||||
- fabs
|
||||
- floor
|
||||
- invert
|
||||
- isfinite
|
||||
- isinf
|
||||
- isnan
|
||||
|
||||
@@ -224,10 +224,6 @@ def subtest_fuse_float_unary_operations_correctness(fun, tensor_shape):
|
||||
# Not too large values to avoid overflows
|
||||
input_list = [1, 2, 5, 11]
|
||||
super_fun_list = [mix_x_and_y_and_call_f, mix_x_and_y_intricately_and_call_f]
|
||||
elif fun == numpy.invert:
|
||||
# 0 is not in the domain of definition + expect integer inputs
|
||||
input_list = [1, 2, 42, 44]
|
||||
super_fun_list = [mix_x_and_y_into_integer_and_call_f]
|
||||
else:
|
||||
# Regular case
|
||||
input_list = [0, 2, 42, 44]
|
||||
@@ -307,7 +303,6 @@ LIST_OF_UFUNC_WHICH_HAVE_INTEGER_ONLY_SOURCES = {
|
||||
numpy.bitwise_or,
|
||||
numpy.bitwise_xor,
|
||||
numpy.gcd,
|
||||
numpy.invert,
|
||||
numpy.lcm,
|
||||
numpy.ldexp,
|
||||
numpy.left_shift,
|
||||
|
||||
@@ -393,15 +393,6 @@ def test_unary_ufunc_operations(ufunc, default_compilation_configuration):
|
||||
((0, 5), (0, 5)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
elif ufunc in [
|
||||
numpy.invert,
|
||||
]:
|
||||
# Can't make it work, to have a fusable function
|
||||
# TODO: fixme
|
||||
pass
|
||||
# subtest_compile_and_run_unary_ufunc_correctness(
|
||||
# ufunc, mix_x_and_y_and_call_f_with_integer_inputs, ((0, 5), (0, 5))
|
||||
# )
|
||||
elif ufunc in [
|
||||
numpy.arccosh,
|
||||
numpy.log,
|
||||
|
||||
@@ -381,6 +381,35 @@ def test_tracing_astype(
|
||||
assert expected_output == evaluated_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"inputs",
|
||||
[
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(32, is_signed=True))},
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace",
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint is not happy
|
||||
# with it
|
||||
# pylint: disable=unnecessary-lambda
|
||||
[lambda x: numpy.invert(x), lambda x: numpy.bitwise_not(x)],
|
||||
# pylint: enable=unnecessary-lambda
|
||||
)
|
||||
def test_trace_numpy_fails_for_invert(inputs, function_to_trace):
|
||||
"""Check we catch calls to numpy.invert and tell user to change their code"""
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
assert (
|
||||
"NPTracer does not manage the following func: invert. Please replace by calls to "
|
||||
"bitwise_xor with appropriate mask" in str(excinfo.value)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"inputs,expected_output_node",
|
||||
[
|
||||
@@ -414,10 +443,6 @@ def test_tracing_astype(
|
||||
def test_trace_numpy_supported_unary_ufuncs(inputs, expected_output_node, function_to_trace_def):
|
||||
"""Function to trace supported numpy ufuncs"""
|
||||
|
||||
# numpy.invert is expecting inputs which are integer only
|
||||
if function_to_trace_def == numpy.invert and not isinstance(inputs["x"].dtype, Integer):
|
||||
return
|
||||
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint and flake8 are not happy
|
||||
# with it
|
||||
|
||||
Reference in New Issue
Block a user