mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
committed by
Benoit Chevallier
parent
55a39d4c26
commit
5448669e83
@@ -1,9 +1,6 @@
|
||||
"""Test file for numpy tracing"""
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
import inspect
|
||||
from copy import deepcopy
|
||||
|
||||
import networkx as nx
|
||||
import numpy
|
||||
@@ -392,144 +389,6 @@ def test_tracing_astype_single_element_array_corner_case():
|
||||
assert numpy.array_equal(numpy.array([1], dtype=numpy.int32), eval_result)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"inputs",
|
||||
[
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(32, is_signed=True))},
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace",
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint is not happy
|
||||
# with it
|
||||
[lambda x: numpy.invert(x), lambda x: numpy.bitwise_not(x)],
|
||||
)
|
||||
def test_trace_numpy_fails_for_invert(inputs, function_to_trace):
|
||||
"""Check we catch calls to numpy.invert and tell user to change their code"""
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
assert (
|
||||
"NPTracer does not manage the following func: invert. Please replace by calls to "
|
||||
"bitwise_xor with appropriate mask" in str(excinfo.value)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"inputs,expected_output_node",
|
||||
[
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(7, is_signed=False))},
|
||||
ir.GenericFunction,
|
||||
),
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(32, is_signed=True))},
|
||||
ir.GenericFunction,
|
||||
),
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(64, is_signed=True))},
|
||||
ir.GenericFunction,
|
||||
),
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(128, is_signed=True))},
|
||||
ir.GenericFunction,
|
||||
marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
|
||||
),
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Float(64))},
|
||||
ir.GenericFunction,
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace_def",
|
||||
[f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1],
|
||||
)
|
||||
def test_trace_numpy_supported_unary_ufuncs(inputs, expected_output_node, function_to_trace_def):
|
||||
"""Function to trace supported numpy ufuncs"""
|
||||
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint and flake8 are not happy
|
||||
# with it
|
||||
# pylint: disable=cell-var-from-loop
|
||||
function_to_trace = lambda x: function_to_trace_def(x) # noqa: E731
|
||||
# pylint: enable=cell-var-from-loop
|
||||
|
||||
op_graph = tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
assert len(op_graph.output_nodes) == 1
|
||||
assert isinstance(op_graph.output_nodes[0], expected_output_node)
|
||||
assert len(op_graph.output_nodes[0].outputs) == 1
|
||||
|
||||
if function_to_trace_def in LIST_OF_UFUNC_WHOSE_OUTPUT_IS_FLOAT64:
|
||||
assert op_graph.output_nodes[0].outputs[0] == EncryptedScalar(Float(64))
|
||||
elif function_to_trace_def in LIST_OF_UFUNC_WHOSE_OUTPUT_IS_BOOL:
|
||||
|
||||
# Boolean function
|
||||
assert op_graph.output_nodes[0].outputs[0] == EncryptedScalar(Integer(8, is_signed=False))
|
||||
else:
|
||||
|
||||
# Function keeping more or less input type
|
||||
input_node_type = inputs["x"]
|
||||
|
||||
expected_output_node_type = deepcopy(input_node_type)
|
||||
|
||||
expected_output_node_type.dtype.bit_width = max(
|
||||
expected_output_node_type.dtype.bit_width, 32
|
||||
)
|
||||
|
||||
assert op_graph.output_nodes[0].outputs[0] == expected_output_node_type
|
||||
|
||||
|
||||
def test_trace_numpy_ufuncs_not_supported():
|
||||
"""Testing a failure case of trace_numpy_function"""
|
||||
inputs = {"x": EncryptedScalar(Integer(128, is_signed=True))}
|
||||
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint and flake8 are not happy
|
||||
# with it
|
||||
function_to_trace = lambda x: numpy.add.reduce(x) # noqa: E731
|
||||
|
||||
with pytest.raises(NotImplementedError) as excinfo:
|
||||
tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
assert "Only __call__ method is supported currently" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_trace_numpy_ufuncs_no_kwargs_no_extra_args():
|
||||
"""Test a case where kwargs are not allowed and too many inputs are passed"""
|
||||
inputs = {
|
||||
"x": EncryptedScalar(Integer(32, is_signed=True)),
|
||||
"y": EncryptedScalar(Integer(32, is_signed=True)),
|
||||
"z": EncryptedScalar(Integer(32, is_signed=True)),
|
||||
}
|
||||
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint and flake8 are not happy
|
||||
# with it
|
||||
function_to_trace = lambda x, y, z: numpy.add(x, y, z) # noqa: E731
|
||||
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
# numpy only passes ufunc.nin tracers so the extra arguments are passed as kwargs
|
||||
assert "**kwargs are currently not supported for numpy ufuncs, ufunc: add" in str(excinfo.value)
|
||||
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint and flake8 are not happy
|
||||
# with it
|
||||
function_to_trace = lambda x, y, z: numpy.add(x, y, out=z) # noqa: E731
|
||||
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
assert "**kwargs are currently not supported for numpy ufuncs, ufunc: add" in str(excinfo.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,inputs,expected_output_node,expected_output_value",
|
||||
[
|
||||
@@ -934,6 +793,3 @@ def test_errors_with_generic_function(lambda_f, params):
|
||||
tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value)
|
||||
|
||||
|
||||
# pylint: enable=too-many-lines
|
||||
|
||||
302
tests/numpy/test_tracing_calls.py
Normal file
302
tests/numpy/test_tracing_calls.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""Test file for numpy tracing"""
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types.floats import Float
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.representation import intermediate as ir
|
||||
from concrete.common.values import EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import tracing
|
||||
|
||||
OPERATIONS_TO_TEST = [ir.Add, ir.Sub, ir.Mul]
|
||||
|
||||
# Functions from tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC, whose output
|
||||
# is a float64, whatever the input type
|
||||
LIST_OF_UFUNC_WHOSE_OUTPUT_IS_FLOAT64 = set(
|
||||
[
|
||||
numpy.arccos,
|
||||
numpy.arccosh,
|
||||
numpy.arcsin,
|
||||
numpy.arcsinh,
|
||||
numpy.arctan,
|
||||
numpy.arctanh,
|
||||
numpy.cbrt,
|
||||
numpy.ceil,
|
||||
numpy.cos,
|
||||
numpy.cosh,
|
||||
numpy.deg2rad,
|
||||
numpy.degrees,
|
||||
numpy.exp,
|
||||
numpy.exp2,
|
||||
numpy.expm1,
|
||||
numpy.fabs,
|
||||
numpy.floor,
|
||||
numpy.log,
|
||||
numpy.log10,
|
||||
numpy.log1p,
|
||||
numpy.log2,
|
||||
numpy.rad2deg,
|
||||
numpy.radians,
|
||||
numpy.rint,
|
||||
numpy.sin,
|
||||
numpy.sinh,
|
||||
numpy.spacing,
|
||||
numpy.sqrt,
|
||||
numpy.tan,
|
||||
numpy.tanh,
|
||||
numpy.trunc,
|
||||
]
|
||||
)
|
||||
|
||||
# Functions from tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC, whose output
|
||||
# is a boolean, whatever the input type
|
||||
LIST_OF_UFUNC_WHOSE_OUTPUT_IS_BOOL = set(
|
||||
[
|
||||
numpy.isfinite,
|
||||
numpy.isinf,
|
||||
numpy.isnan,
|
||||
numpy.signbit,
|
||||
numpy.logical_not,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"inputs,expected_output_node",
|
||||
[
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(7, is_signed=False))},
|
||||
ir.GenericFunction,
|
||||
),
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(32, is_signed=True))},
|
||||
ir.GenericFunction,
|
||||
),
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(64, is_signed=True))},
|
||||
ir.GenericFunction,
|
||||
),
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(128, is_signed=True))},
|
||||
ir.GenericFunction,
|
||||
marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
|
||||
),
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Float(64))},
|
||||
ir.GenericFunction,
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace_def",
|
||||
[f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1],
|
||||
)
|
||||
def test_trace_numpy_supported_unary_ufuncs(inputs, expected_output_node, function_to_trace_def):
|
||||
"""Function to trace supported numpy ufuncs"""
|
||||
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint and flake8 are not happy
|
||||
# with it
|
||||
# pylint: disable=cell-var-from-loop
|
||||
function_to_trace = lambda x: function_to_trace_def(x) # noqa: E731
|
||||
# pylint: enable=cell-var-from-loop
|
||||
|
||||
op_graph = tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
assert len(op_graph.output_nodes) == 1
|
||||
assert isinstance(op_graph.output_nodes[0], expected_output_node)
|
||||
assert len(op_graph.output_nodes[0].outputs) == 1
|
||||
|
||||
if function_to_trace_def in LIST_OF_UFUNC_WHOSE_OUTPUT_IS_FLOAT64:
|
||||
assert op_graph.output_nodes[0].outputs[0] == EncryptedScalar(Float(64))
|
||||
elif function_to_trace_def in LIST_OF_UFUNC_WHOSE_OUTPUT_IS_BOOL:
|
||||
|
||||
# Boolean function
|
||||
assert op_graph.output_nodes[0].outputs[0] == EncryptedScalar(Integer(8, is_signed=False))
|
||||
else:
|
||||
|
||||
# Function keeping more or less input type
|
||||
input_node_type = inputs["x"]
|
||||
|
||||
expected_output_node_type = deepcopy(input_node_type)
|
||||
|
||||
expected_output_node_type.dtype.bit_width = max(
|
||||
expected_output_node_type.dtype.bit_width, 32
|
||||
)
|
||||
|
||||
assert op_graph.output_nodes[0].outputs[0] == expected_output_node_type
|
||||
|
||||
|
||||
@pytest.mark.parametrize("np_function", tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC)
|
||||
def test_nptracer_get_tracing_func_for_np_functions(np_function):
|
||||
"""Test NPTracer get_tracing_func_for_np_function"""
|
||||
|
||||
expected_tracing_func = tracing.NPTracer.UFUNC_ROUTING[np_function]
|
||||
|
||||
assert tracing.NPTracer.get_tracing_func_for_np_function(np_function) == expected_tracing_func
|
||||
|
||||
|
||||
def subtest_tracing_calls(
|
||||
function_to_trace,
|
||||
input_value_input_and_expected_output_tuples,
|
||||
):
|
||||
"""Test memory function managed by GenericFunction node of the form numpy.something"""
|
||||
for input_value, input_, expected_output in input_value_input_and_expected_output_tuples:
|
||||
|
||||
op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value})
|
||||
output_node = op_graph.output_nodes[0]
|
||||
|
||||
node_results = op_graph.evaluate({0: input_})
|
||||
evaluated_output = node_results[output_node]
|
||||
assert isinstance(evaluated_output, type(expected_output)), type(evaluated_output)
|
||||
assert numpy.array_equal(expected_output, evaluated_output)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,input_value_input_and_expected_output_tuples",
|
||||
[
|
||||
(
|
||||
lambda x: numpy.transpose(x),
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
|
||||
numpy.arange(4).reshape(2, 2),
|
||||
numpy.array([[0, 2], [1, 3]]),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
|
||||
numpy.arange(4, 8).reshape(2, 2),
|
||||
numpy.array([[4, 6], [5, 7]]),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(6, is_signed=False), shape=()),
|
||||
numpy.int64(42),
|
||||
numpy.int64(42),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: numpy.transpose(x) + 42,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(42, 57).reshape(3, 5).transpose(),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(6, is_signed=False), shape=()),
|
||||
numpy.int64(42),
|
||||
numpy.int64(84),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: numpy.ravel(x),
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
|
||||
numpy.arange(4),
|
||||
numpy.array([0, 1, 2, 3]),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
|
||||
numpy.arange(4).reshape(2, 2),
|
||||
numpy.array([0, 1, 2, 3]),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(6, is_signed=False), shape=()),
|
||||
numpy.int64(42),
|
||||
numpy.array([42], dtype=numpy.int64),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: numpy.reshape(x, (5, 3)) + 42,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(42, 57).reshape(5, 3),
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_tracing_numpy_calls(
|
||||
function_to_trace,
|
||||
input_value_input_and_expected_output_tuples,
|
||||
):
|
||||
"""Test memory function managed by GenericFunction node of the form numpy.something"""
|
||||
subtest_tracing_calls(function_to_trace, input_value_input_and_expected_output_tuples)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,input_value_input_and_expected_output_tuples",
|
||||
[
|
||||
(
|
||||
lambda x: x.transpose() + 42,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(42, 57).reshape(3, 5).transpose(),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(6, is_signed=False), shape=()),
|
||||
numpy.int64(42),
|
||||
numpy.int64(84),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x.ravel(),
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
|
||||
numpy.arange(4),
|
||||
numpy.array([0, 1, 2, 3]),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)),
|
||||
numpy.arange(4).reshape(2, 2),
|
||||
numpy.array([0, 1, 2, 3]),
|
||||
),
|
||||
(
|
||||
EncryptedTensor(Integer(6, is_signed=False), shape=()),
|
||||
numpy.int64(42),
|
||||
numpy.array([42], dtype=numpy.int64),
|
||||
),
|
||||
],
|
||||
),
|
||||
(
|
||||
lambda x: x.reshape((5, 3)) + 42,
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
|
||||
numpy.arange(15).reshape(3, 5),
|
||||
numpy.arange(42, 57).reshape(5, 3),
|
||||
),
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((5, 3)),
|
||||
[
|
||||
(
|
||||
EncryptedTensor(Integer(6, is_signed=False), shape=()),
|
||||
numpy.int64(42),
|
||||
None,
|
||||
)
|
||||
],
|
||||
marks=pytest.mark.xfail(strict=True, raises=AssertionError),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_tracing_ndarray_calls(
|
||||
function_to_trace,
|
||||
input_value_input_and_expected_output_tuples,
|
||||
):
|
||||
"""Test memory function managed by GenericFunction node of the form ndarray.something"""
|
||||
subtest_tracing_calls(function_to_trace, input_value_input_and_expected_output_tuples)
|
||||
236
tests/numpy/test_tracing_failures.py
Normal file
236
tests/numpy/test_tracing_failures.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""Test file for numpy tracing"""
|
||||
|
||||
import inspect
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.representation import intermediate as ir
|
||||
from concrete.common.values import ClearScalar, EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import tracing
|
||||
|
||||
OPERATIONS_TO_TEST = [ir.Add, ir.Sub, ir.Mul]
|
||||
|
||||
# Functions from tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC, whose output
|
||||
# is a float64, whatever the input type
|
||||
LIST_OF_UFUNC_WHOSE_OUTPUT_IS_FLOAT64 = set(
|
||||
[
|
||||
numpy.arccos,
|
||||
numpy.arccosh,
|
||||
numpy.arcsin,
|
||||
numpy.arcsinh,
|
||||
numpy.arctan,
|
||||
numpy.arctanh,
|
||||
numpy.cbrt,
|
||||
numpy.ceil,
|
||||
numpy.cos,
|
||||
numpy.cosh,
|
||||
numpy.deg2rad,
|
||||
numpy.degrees,
|
||||
numpy.exp,
|
||||
numpy.exp2,
|
||||
numpy.expm1,
|
||||
numpy.fabs,
|
||||
numpy.floor,
|
||||
numpy.log,
|
||||
numpy.log10,
|
||||
numpy.log1p,
|
||||
numpy.log2,
|
||||
numpy.rad2deg,
|
||||
numpy.radians,
|
||||
numpy.rint,
|
||||
numpy.sin,
|
||||
numpy.sinh,
|
||||
numpy.spacing,
|
||||
numpy.sqrt,
|
||||
numpy.tan,
|
||||
numpy.tanh,
|
||||
numpy.trunc,
|
||||
]
|
||||
)
|
||||
|
||||
# Functions from tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC, whose output
|
||||
# is a boolean, whatever the input type
|
||||
LIST_OF_UFUNC_WHOSE_OUTPUT_IS_BOOL = set(
|
||||
[
|
||||
numpy.isfinite,
|
||||
numpy.isinf,
|
||||
numpy.isnan,
|
||||
numpy.signbit,
|
||||
numpy.logical_not,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"inputs",
|
||||
[
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(32, is_signed=True))},
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace",
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint is not happy
|
||||
# with it
|
||||
[lambda x: numpy.invert(x), lambda x: numpy.bitwise_not(x)],
|
||||
)
|
||||
def test_trace_numpy_fails_for_invert(inputs, function_to_trace):
|
||||
"""Check we catch calls to numpy.invert and tell user to change their code"""
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
assert (
|
||||
"NPTracer does not manage the following func: invert. Please replace by calls to "
|
||||
"bitwise_xor with appropriate mask" in str(excinfo.value)
|
||||
)
|
||||
|
||||
|
||||
def test_trace_numpy_ufuncs_not_supported():
|
||||
"""Testing a failure case of trace_numpy_function"""
|
||||
inputs = {"x": EncryptedScalar(Integer(128, is_signed=True))}
|
||||
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint and flake8 are not happy
|
||||
# with it
|
||||
function_to_trace = lambda x: numpy.add.reduce(x) # noqa: E731
|
||||
|
||||
with pytest.raises(NotImplementedError) as excinfo:
|
||||
tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
assert "Only __call__ method is supported currently" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_trace_numpy_ufuncs_no_kwargs_no_extra_args():
|
||||
"""Test a case where kwargs are not allowed and too many inputs are passed"""
|
||||
inputs = {
|
||||
"x": EncryptedScalar(Integer(32, is_signed=True)),
|
||||
"y": EncryptedScalar(Integer(32, is_signed=True)),
|
||||
"z": EncryptedScalar(Integer(32, is_signed=True)),
|
||||
}
|
||||
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint and flake8 are not happy
|
||||
# with it
|
||||
function_to_trace = lambda x, y, z: numpy.add(x, y, z) # noqa: E731
|
||||
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
# numpy only passes ufunc.nin tracers so the extra arguments are passed as kwargs
|
||||
assert "**kwargs are currently not supported for numpy ufuncs, ufunc: add" in str(excinfo.value)
|
||||
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
# nice with inspect.signature), but pylint and flake8 are not happy
|
||||
# with it
|
||||
function_to_trace = lambda x, y, z: numpy.add(x, y, out=z) # noqa: E731
|
||||
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
tracing.trace_numpy_function(function_to_trace, inputs)
|
||||
|
||||
assert "**kwargs are currently not supported for numpy ufuncs, ufunc: add" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_nptracer_get_tracing_func_for_np_functions_not_implemented():
|
||||
"""Check NPTracer in case of not-implemented function"""
|
||||
with pytest.raises(NotImplementedError) as excinfo:
|
||||
tracing.NPTracer.get_tracing_func_for_np_function(numpy.conjugate)
|
||||
|
||||
assert "NPTracer does not yet manage the following func: conjugate" in str(excinfo.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"operation,exception_type,match",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: x + "fail",
|
||||
TypeError,
|
||||
"unsupported operand type(s) for +: 'NPTracer' and 'str'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: "fail" + x,
|
||||
TypeError,
|
||||
'can only concatenate str (not "NPTracer") to str',
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x - "fail",
|
||||
TypeError,
|
||||
"unsupported operand type(s) for -: 'NPTracer' and 'str'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: "fail" - x,
|
||||
TypeError,
|
||||
"unsupported operand type(s) for -: 'str' and 'NPTracer'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x * "fail",
|
||||
TypeError,
|
||||
"can't multiply sequence by non-int of type 'NPTracer'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: "fail" * x,
|
||||
TypeError,
|
||||
"can't multiply sequence by non-int of type 'NPTracer'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x / "fail",
|
||||
TypeError,
|
||||
"unsupported operand type(s) for /: 'NPTracer' and 'str'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: "fail" / x,
|
||||
TypeError,
|
||||
"unsupported operand type(s) for /: 'str' and 'NPTracer'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x // "fail",
|
||||
TypeError,
|
||||
"unsupported operand type(s) for //: 'NPTracer' and 'str'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: "fail" // x,
|
||||
TypeError,
|
||||
"unsupported operand type(s) for //: 'str' and 'NPTracer'",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x / y, NotImplementedError, "Can't manage binary operator truediv"
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x // y, NotImplementedError, "Can't manage binary operator floordiv"
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_nptracer_unsupported_operands(operation, exception_type, match):
|
||||
"""Test cases where NPTracer cannot be used with other operands."""
|
||||
tracers = [
|
||||
tracing.NPTracer([], ir.Input(ClearScalar(Integer(32, True)), param_name, idx), 0)
|
||||
for idx, param_name in enumerate(inspect.signature(operation).parameters.keys())
|
||||
]
|
||||
|
||||
with pytest.raises(exception_type) as exc_info:
|
||||
_ = operation(*tracers)
|
||||
|
||||
assert match in str(exc_info)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lambda_f,params",
|
||||
[
|
||||
(
|
||||
lambda x: numpy.reshape(x, (5, 3)),
|
||||
{
|
||||
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(7, 5)),
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_errors_with_generic_function(lambda_f, params):
|
||||
"Test some errors with generic function"
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value)
|
||||
Reference in New Issue
Block a user