From 4a77d0515a92c03ba36a778610d1e858dce8533e Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Thu, 7 Oct 2021 10:41:44 +0200 Subject: [PATCH] feat(tracing): add test for extra args passed to ufuncs - add comment to explain why sanitizing all args is safe for ufuncs --- concrete/numpy/tracing.py | 12 +++++++----- tests/numpy/test_tracing.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index fbdc330a9..4af5fd446 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -34,7 +34,7 @@ class NPTracer(BaseTracer): _mix_values_func: Callable[..., BaseValue] = mix_values_determine_holding_dtype - def __array_ufunc__(self, ufunc, method, *input_tracers, **kwargs): + def __array_ufunc__(self, ufunc: numpy.ufunc, method, *args, **kwargs): """Catch calls to numpy ufunc and routes them to tracing functions if supported. Read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch @@ -43,11 +43,13 @@ class NPTracer(BaseTracer): tracing_func = self.get_tracing_func_for_np_function(ufunc) custom_assert( (len(kwargs) == 0), - f"**kwargs are currently not supported for numpy ufuncs, ufunc: {ufunc}", + f"**kwargs are currently not supported for numpy ufuncs, ufunc: {ufunc.__name__}", ) - # Create constant tracers when needed - sanitized_input_tracers = [self._sanitize(inp) for inp in input_tracers] - return tracing_func(*sanitized_input_tracers, **kwargs) + + # Create constant tracers for args, numpy only passes ufunc.nin args so we can + # sanitize all of them without issues + sanitized_args = [self._sanitize(arg) for arg in args] + return tracing_func(*sanitized_args, **kwargs) raise NotImplementedError("Only __call__ method is supported currently") def __array_function__(self, func, _types, args, kwargs): diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index 6a9998d95..dffa405fc 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -462,6 +462,40 @@ def test_trace_numpy_ufuncs_not_supported(): assert "Only __call__ method is supported currently" in str(excinfo.value) +def test_trace_numpy_ufuncs_no_kwargs_no_extra_args(): + """Test a case where kwargs are not allowed and too many inputs are passed""" + inputs = { + "x": EncryptedScalar(Integer(32, is_signed=True)), + "y": EncryptedScalar(Integer(32, is_signed=True)), + "z": EncryptedScalar(Integer(32, is_signed=True)), + } + + # We really need a lambda (because numpy functions are not playing + # nice with inspect.signature), but pylint and flake8 are not happy + # with it + # pylint: disable=unnecessary-lambda + function_to_trace = lambda x, y, z: numpy.add(x, y, z) # noqa: E731 + # pylint: enable=unnecessary-lambda + + with pytest.raises(AssertionError) as excinfo: + tracing.trace_numpy_function(function_to_trace, inputs) + + # numpy only passes ufunc.nin tracers so the extra arguments are passed as kwargs + assert "**kwargs are currently not supported for numpy ufuncs, ufunc: add" in str(excinfo.value) + + # We really need a lambda (because numpy functions are not playing + # nice with inspect.signature), but pylint and flake8 are not happy + # with it + # pylint: disable=unnecessary-lambda + function_to_trace = lambda x, y, z: numpy.add(x, y, out=z) # noqa: E731 + # pylint: enable=unnecessary-lambda + + with pytest.raises(AssertionError) as excinfo: + tracing.trace_numpy_function(function_to_trace, inputs) + + assert "**kwargs are currently not supported for numpy ufuncs, ufunc: add" in str(excinfo.value) + + @pytest.mark.parametrize( "function_to_trace,inputs,expected_output_node,expected_output_value", [