mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
185 lines
6.3 KiB
Python
185 lines
6.3 KiB
Python
"""Test file for numpy tracing"""
|
|
|
|
import inspect
|
|
|
|
import numpy
|
|
import pytest
|
|
|
|
from concrete.common.data_types.integers import Integer
|
|
from concrete.common.representation import intermediate as ir
|
|
from concrete.common.values import ClearScalar, EncryptedScalar, EncryptedTensor
|
|
from concrete.numpy import tracing
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"inputs",
|
|
[
|
|
pytest.param(
|
|
{"x": EncryptedScalar(Integer(32, is_signed=True))},
|
|
),
|
|
],
|
|
)
|
|
@pytest.mark.parametrize(
|
|
"function_to_trace",
|
|
# We really need a lambda (because numpy functions are not playing
|
|
# nice with inspect.signature), but pylint is not happy
|
|
# with it
|
|
[lambda x: numpy.invert(x), lambda x: numpy.bitwise_not(x)],
|
|
)
|
|
def test_trace_numpy_fails_for_invert(inputs, function_to_trace):
|
|
"""Check we catch calls to numpy.invert and tell user to change their code"""
|
|
|
|
with pytest.raises(RuntimeError) as excinfo:
|
|
tracing.trace_numpy_function(function_to_trace, inputs)
|
|
|
|
assert (
|
|
"NPTracer does not manage the following func: invert. Please replace by calls to "
|
|
"bitwise_xor with appropriate mask" in str(excinfo.value)
|
|
)
|
|
|
|
|
|
def test_trace_numpy_ufuncs_not_supported():
|
|
"""Testing a failure case of trace_numpy_function"""
|
|
inputs = {"x": EncryptedScalar(Integer(128, is_signed=True))}
|
|
|
|
# We really need a lambda (because numpy functions are not playing
|
|
# nice with inspect.signature), but pylint and flake8 are not happy
|
|
# with it
|
|
function_to_trace = lambda x: numpy.add.reduce(x) # noqa: E731
|
|
|
|
with pytest.raises(NotImplementedError) as excinfo:
|
|
tracing.trace_numpy_function(function_to_trace, inputs)
|
|
|
|
assert "Only __call__ method is supported currently" in str(excinfo.value)
|
|
|
|
|
|
def test_trace_numpy_ufuncs_no_kwargs_no_extra_args():
|
|
"""Test a case where kwargs are not allowed and too many inputs are passed"""
|
|
inputs = {
|
|
"x": EncryptedScalar(Integer(32, is_signed=True)),
|
|
"y": EncryptedScalar(Integer(32, is_signed=True)),
|
|
"z": EncryptedScalar(Integer(32, is_signed=True)),
|
|
}
|
|
|
|
# We really need a lambda (because numpy functions are not playing
|
|
# nice with inspect.signature), but pylint and flake8 are not happy
|
|
# with it
|
|
function_to_trace = lambda x, y, z: numpy.add(x, y, z) # noqa: E731
|
|
|
|
with pytest.raises(AssertionError) as excinfo:
|
|
tracing.trace_numpy_function(function_to_trace, inputs)
|
|
|
|
# numpy only passes ufunc.nin tracers so the extra arguments are passed as kwargs
|
|
assert "**kwargs are currently not supported for numpy ufuncs, ufunc: add" in str(excinfo.value)
|
|
|
|
# We really need a lambda (because numpy functions are not playing
|
|
# nice with inspect.signature), but pylint and flake8 are not happy
|
|
# with it
|
|
function_to_trace = lambda x, y, z: numpy.add(x, y, out=z) # noqa: E731
|
|
|
|
with pytest.raises(AssertionError) as excinfo:
|
|
tracing.trace_numpy_function(function_to_trace, inputs)
|
|
|
|
assert "**kwargs are currently not supported for numpy ufuncs, ufunc: add" in str(excinfo.value)
|
|
|
|
|
|
def test_nptracer_get_tracing_func_for_np_functions_not_implemented():
|
|
"""Check NPTracer in case of not-implemented function"""
|
|
with pytest.raises(NotImplementedError) as excinfo:
|
|
tracing.NPTracer.get_tracing_func_for_np_function(numpy.conjugate)
|
|
|
|
assert "NPTracer does not yet manage the following func: conjugate" in str(excinfo.value)
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"operation,exception_type,match",
|
|
[
|
|
pytest.param(
|
|
lambda x: x + "fail",
|
|
TypeError,
|
|
"unsupported operand type(s) for +: 'NPTracer' and 'str'",
|
|
),
|
|
pytest.param(
|
|
lambda x: "fail" + x,
|
|
TypeError,
|
|
'can only concatenate str (not "NPTracer") to str',
|
|
),
|
|
pytest.param(
|
|
lambda x: x - "fail",
|
|
TypeError,
|
|
"unsupported operand type(s) for -: 'NPTracer' and 'str'",
|
|
),
|
|
pytest.param(
|
|
lambda x: "fail" - x,
|
|
TypeError,
|
|
"unsupported operand type(s) for -: 'str' and 'NPTracer'",
|
|
),
|
|
pytest.param(
|
|
lambda x: x * "fail",
|
|
TypeError,
|
|
"can't multiply sequence by non-int of type 'NPTracer'",
|
|
),
|
|
pytest.param(
|
|
lambda x: "fail" * x,
|
|
TypeError,
|
|
"can't multiply sequence by non-int of type 'NPTracer'",
|
|
),
|
|
pytest.param(
|
|
lambda x: x / "fail",
|
|
TypeError,
|
|
"unsupported operand type(s) for /: 'NPTracer' and 'str'",
|
|
),
|
|
pytest.param(
|
|
lambda x: "fail" / x,
|
|
TypeError,
|
|
"unsupported operand type(s) for /: 'str' and 'NPTracer'",
|
|
),
|
|
pytest.param(
|
|
lambda x: x // "fail",
|
|
TypeError,
|
|
"unsupported operand type(s) for //: 'NPTracer' and 'str'",
|
|
),
|
|
pytest.param(
|
|
lambda x: "fail" // x,
|
|
TypeError,
|
|
"unsupported operand type(s) for //: 'str' and 'NPTracer'",
|
|
),
|
|
pytest.param(
|
|
lambda x, y: x / y, NotImplementedError, "Can't manage binary operator truediv"
|
|
),
|
|
pytest.param(
|
|
lambda x, y: x // y, NotImplementedError, "Can't manage binary operator floordiv"
|
|
),
|
|
],
|
|
)
|
|
def test_nptracer_unsupported_operands(operation, exception_type, match):
|
|
"""Test cases where NPTracer cannot be used with other operands."""
|
|
tracers = [
|
|
tracing.NPTracer([], ir.Input(ClearScalar(Integer(32, True)), param_name, idx), 0)
|
|
for idx, param_name in enumerate(inspect.signature(operation).parameters.keys())
|
|
]
|
|
|
|
with pytest.raises(exception_type) as exc_info:
|
|
_ = operation(*tracers)
|
|
|
|
assert match in str(exc_info)
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"lambda_f,params",
|
|
[
|
|
(
|
|
lambda x: numpy.reshape(x, (5, 3)),
|
|
{
|
|
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(7, 5)),
|
|
},
|
|
),
|
|
],
|
|
)
|
|
def test_errors_with_generic_function(lambda_f, params):
|
|
"Test some errors with generic function"
|
|
with pytest.raises(ValueError) as excinfo:
|
|
tracing.trace_numpy_function(lambda_f, params)
|
|
|
|
assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value)
|