test: change a bit the way we test

to prepare modifications for more ufunc
refs #126
This commit is contained in:
Benoit Chevallier-Mames
2021-09-29 11:27:21 +02:00
committed by Benoit Chevallier
parent bd95714c23
commit 0a758ed672

View File

@@ -231,38 +231,125 @@ def test_tracing_astype(
@pytest.mark.parametrize(
"inputs,expected_output_node,expected_output_value",
"inputs,expected_output_node",
[
pytest.param(
{"x": EncryptedScalar(Integer(7, is_signed=False))},
ir.ArbitraryFunction,
EncryptedScalar(Float(64)),
),
pytest.param(
{"x": EncryptedScalar(Integer(32, is_signed=True))},
ir.ArbitraryFunction,
EncryptedScalar(Float(64)),
),
pytest.param(
{"x": EncryptedScalar(Integer(64, is_signed=True))},
ir.ArbitraryFunction,
EncryptedScalar(Float(64)),
),
pytest.param(
{"x": EncryptedScalar(Integer(128, is_signed=True))},
ir.ArbitraryFunction,
None,
marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
),
pytest.param(
{"x": EncryptedScalar(Float(64))},
ir.ArbitraryFunction,
EncryptedScalar(Float(64)),
),
],
)
def test_trace_numpy_supported_ufuncs(inputs, expected_output_node, expected_output_value):
def test_trace_numpy_supported_ufuncs(inputs, expected_output_node):
"""Function to trace supported numpy ufuncs"""
LIST_OF_UFUNC_WHOSE_OUTPUT_IS_FLOAT64: List[numpy.ufunc] = [
# The commented functions are functions which don't work for the moment, often
# if not always because they require more than a single argument
# numpy.absolute,
# numpy.add,
numpy.arccos,
numpy.arccosh,
numpy.arcsin,
numpy.arcsinh,
numpy.arctan,
# numpy.arctan2,
numpy.arctanh,
# numpy.bitwise_and,
# numpy.bitwise_or,
# numpy.bitwise_xor,
numpy.cbrt,
numpy.ceil,
# numpy.conjugate,
# numpy.copysign,
numpy.cos,
numpy.cosh,
numpy.deg2rad,
numpy.degrees,
# numpy.divmod,
# numpy.equal,
numpy.exp,
numpy.exp2,
numpy.expm1,
numpy.fabs,
# numpy.float_power,
numpy.floor,
# numpy.floor_divide,
# numpy.fmax,
# numpy.fmin,
# numpy.fmod,
# numpy.frexp,
# numpy.gcd,
# numpy.greater,
# numpy.greater_equal,
# numpy.heaviside,
# numpy.hypot,
# numpy.invert,
# numpy.isfinite,
# numpy.isinf,
# numpy.isnan,
# numpy.isnat,
# numpy.lcm,
# numpy.ldexp,
# numpy.left_shift,
# numpy.less,
# numpy.less_equal,
numpy.log,
numpy.log10,
numpy.log1p,
numpy.log2,
# numpy.logaddexp,
# numpy.logaddexp2,
# numpy.logical_and,
# numpy.logical_not,
# numpy.logical_or,
# numpy.logical_xor,
# numpy.matmul,
# numpy.maximum,
# numpy.minimum,
# numpy.modf,
# numpy.multiply,
# numpy.negative,
# numpy.nextafter,
# numpy.not_equal,
# numpy.positive,
# numpy.power,
numpy.rad2deg,
numpy.radians,
# numpy.reciprocal,
# numpy.remainder,
# numpy.right_shift,
numpy.rint,
# numpy.sign,
# numpy.signbit,
numpy.sin,
numpy.sinh,
numpy.spacing,
numpy.sqrt,
# numpy.square,
# numpy.subtract,
numpy.tan,
numpy.tanh,
# numpy.true_divide,
numpy.trunc,
]
for function_to_trace_def in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC:
# We really need a lambda (because numpy functions are not playing
@@ -277,7 +364,11 @@ def test_trace_numpy_supported_ufuncs(inputs, expected_output_node, expected_out
assert len(op_graph.output_nodes) == 1
assert isinstance(op_graph.output_nodes[0], expected_output_node)
assert len(op_graph.output_nodes[0].outputs) == 1
assert op_graph.output_nodes[0].outputs[0] == expected_output_value
if function_to_trace_def in LIST_OF_UFUNC_WHOSE_OUTPUT_IS_FLOAT64:
assert op_graph.output_nodes[0].outputs[0] == EncryptedScalar(Float(64))
else:
assert op_graph.output_nodes[0].outputs[0] == "to be done"
def test_trace_numpy_ufuncs_not_supported():