mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(tracing): add test for extra args passed to ufuncs
- add comment to explain why sanitizing all args is safe for ufuncs
This commit is contained in:
@@ -34,7 +34,7 @@ class NPTracer(BaseTracer):
|
||||
|
||||
_mix_values_func: Callable[..., BaseValue] = mix_values_determine_holding_dtype
|
||||
|
||||
def __array_ufunc__(self, ufunc, method, *input_tracers, **kwargs):
|
||||
def __array_ufunc__(self, ufunc: numpy.ufunc, method, *args, **kwargs):
|
||||
"""Catch calls to numpy ufunc and routes them to tracing functions if supported.
|
||||
|
||||
Read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch
|
||||
@@ -43,11 +43,13 @@ class NPTracer(BaseTracer):
|
||||
tracing_func = self.get_tracing_func_for_np_function(ufunc)
|
||||
custom_assert(
|
||||
(len(kwargs) == 0),
|
||||
f"**kwargs are currently not supported for numpy ufuncs, ufunc: {ufunc}",
|
||||
f"**kwargs are currently not supported for numpy ufuncs, ufunc: {ufunc.__name__}",
|
||||
)
|
||||
# Create constant tracers when needed
|
||||
sanitized_input_tracers = [self._sanitize(inp) for inp in input_tracers]
|
||||
return tracing_func(*sanitized_input_tracers, **kwargs)
|
||||
|
||||
# Create constant tracers for args, numpy only passes ufunc.nin args so we can
|
||||
# sanitize all of them without issues
|
||||
sanitized_args = [self._sanitize(arg) for arg in args]
|
||||
return tracing_func(*sanitized_args, **kwargs)
|
||||
raise NotImplementedError("Only __call__ method is supported currently")
|
||||
|
||||
def __array_function__(self, func, _types, args, kwargs):
|
||||
|
||||
@@ -462,6 +462,40 @@ def test_trace_numpy_ufuncs_not_supported():
|
||||
assert "Only __call__ method is supported currently" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_trace_numpy_ufuncs_no_kwargs_no_extra_args():
|
||||
"""Test a case where kwargs are not allowed and too many inputs are passed"""
|
||||
inputs = {
|
||||
"x": EncryptedScalar(Integer(32, is_signed=True)),
|
||||
"y": EncryptedScalar(Integer(32, is_signed=True)),
|
||||
"z": EncryptedScalar(Integer(32, is_signed=True)),
|
||||
}
|
||||
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint and flake8 are not happy
|
||||
# with it
|
||||
# pylint: disable=unnecessary-lambda
|
||||
function_to_trace = lambda x, y, z: numpy.add(x, y, z) # noqa: E731
|
||||
# pylint: enable=unnecessary-lambda
|
||||
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
# numpy only passes ufunc.nin tracers so the extra arguments are passed as kwargs
|
||||
assert "**kwargs are currently not supported for numpy ufuncs, ufunc: add" in str(excinfo.value)
|
||||
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint and flake8 are not happy
|
||||
# with it
|
||||
# pylint: disable=unnecessary-lambda
|
||||
function_to_trace = lambda x, y, z: numpy.add(x, y, out=z) # noqa: E731
|
||||
# pylint: enable=unnecessary-lambda
|
||||
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
assert "**kwargs are currently not supported for numpy ufuncs, ufunc: add" in str(excinfo.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,inputs,expected_output_node,expected_output_value",
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user