feat(tracing): add test for extra args passed to ufuncs

- add comment to explain why sanitizing all args is safe for ufuncs
This commit is contained in:
Arthur Meyre
2021-10-07 10:41:44 +02:00
parent 05e1227269
commit 4a77d0515a
2 changed files with 41 additions and 5 deletions

View File

@@ -34,7 +34,7 @@ class NPTracer(BaseTracer):
_mix_values_func: Callable[..., BaseValue] = mix_values_determine_holding_dtype
def __array_ufunc__(self, ufunc, method, *input_tracers, **kwargs):
def __array_ufunc__(self, ufunc: numpy.ufunc, method, *args, **kwargs):
"""Catch calls to numpy ufunc and routes them to tracing functions if supported.
Read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch
@@ -43,11 +43,13 @@ class NPTracer(BaseTracer):
tracing_func = self.get_tracing_func_for_np_function(ufunc)
custom_assert(
(len(kwargs) == 0),
f"**kwargs are currently not supported for numpy ufuncs, ufunc: {ufunc}",
f"**kwargs are currently not supported for numpy ufuncs, ufunc: {ufunc.__name__}",
)
# Create constant tracers when needed
sanitized_input_tracers = [self._sanitize(inp) for inp in input_tracers]
return tracing_func(*sanitized_input_tracers, **kwargs)
# Create constant tracers for args, numpy only passes ufunc.nin args so we can
# sanitize all of them without issues
sanitized_args = [self._sanitize(arg) for arg in args]
return tracing_func(*sanitized_args, **kwargs)
raise NotImplementedError("Only __call__ method is supported currently")
def __array_function__(self, func, _types, args, kwargs):

View File

@@ -462,6 +462,40 @@ def test_trace_numpy_ufuncs_not_supported():
assert "Only __call__ method is supported currently" in str(excinfo.value)
def test_trace_numpy_ufuncs_no_kwargs_no_extra_args():
"""Test a case where kwargs are not allowed and too many inputs are passed"""
inputs = {
"x": EncryptedScalar(Integer(32, is_signed=True)),
"y": EncryptedScalar(Integer(32, is_signed=True)),
"z": EncryptedScalar(Integer(32, is_signed=True)),
}
# We really need a lambda (because numpy functions are not playing
# nice with inspect.signature), but pylint and flake8 are not happy
# with it
# pylint: disable=unnecessary-lambda
function_to_trace = lambda x, y, z: numpy.add(x, y, z) # noqa: E731
# pylint: enable=unnecessary-lambda
with pytest.raises(AssertionError) as excinfo:
tracing.trace_numpy_function(function_to_trace, inputs)
# numpy only passes ufunc.nin tracers so the extra arguments are passed as kwargs
assert "**kwargs are currently not supported for numpy ufuncs, ufunc: add" in str(excinfo.value)
# We really need a lambda (because numpy functions are not playing
# nice with inspect.signature), but pylint and flake8 are not happy
# with it
# pylint: disable=unnecessary-lambda
function_to_trace = lambda x, y, z: numpy.add(x, y, out=z) # noqa: E731
# pylint: enable=unnecessary-lambda
with pytest.raises(AssertionError) as excinfo:
tracing.trace_numpy_function(function_to_trace, inputs)
assert "**kwargs are currently not supported for numpy ufuncs, ufunc: add" in str(excinfo.value)
@pytest.mark.parametrize(
"function_to_trace,inputs,expected_output_node,expected_output_value",
[