mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
test: change a bit the way we test
to prepare modifications for more ufunc refs #126
This commit is contained in:
committed by
Benoit Chevallier
parent
bd95714c23
commit
0a758ed672
@@ -231,38 +231,125 @@ def test_tracing_astype(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"inputs,expected_output_node,expected_output_value",
|
||||
"inputs,expected_output_node",
|
||||
[
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(7, is_signed=False))},
|
||||
ir.ArbitraryFunction,
|
||||
EncryptedScalar(Float(64)),
|
||||
),
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(32, is_signed=True))},
|
||||
ir.ArbitraryFunction,
|
||||
EncryptedScalar(Float(64)),
|
||||
),
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(64, is_signed=True))},
|
||||
ir.ArbitraryFunction,
|
||||
EncryptedScalar(Float(64)),
|
||||
),
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Integer(128, is_signed=True))},
|
||||
ir.ArbitraryFunction,
|
||||
None,
|
||||
marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
|
||||
),
|
||||
pytest.param(
|
||||
{"x": EncryptedScalar(Float(64))},
|
||||
ir.ArbitraryFunction,
|
||||
EncryptedScalar(Float(64)),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_trace_numpy_supported_ufuncs(inputs, expected_output_node, expected_output_value):
|
||||
def test_trace_numpy_supported_ufuncs(inputs, expected_output_node):
|
||||
"""Function to trace supported numpy ufuncs"""
|
||||
|
||||
LIST_OF_UFUNC_WHOSE_OUTPUT_IS_FLOAT64: List[numpy.ufunc] = [
|
||||
# The commented functions are functions which don't work for the moment, often
|
||||
# if not always because they require more than a single argument
|
||||
# numpy.absolute,
|
||||
# numpy.add,
|
||||
numpy.arccos,
|
||||
numpy.arccosh,
|
||||
numpy.arcsin,
|
||||
numpy.arcsinh,
|
||||
numpy.arctan,
|
||||
# numpy.arctan2,
|
||||
numpy.arctanh,
|
||||
# numpy.bitwise_and,
|
||||
# numpy.bitwise_or,
|
||||
# numpy.bitwise_xor,
|
||||
numpy.cbrt,
|
||||
numpy.ceil,
|
||||
# numpy.conjugate,
|
||||
# numpy.copysign,
|
||||
numpy.cos,
|
||||
numpy.cosh,
|
||||
numpy.deg2rad,
|
||||
numpy.degrees,
|
||||
# numpy.divmod,
|
||||
# numpy.equal,
|
||||
numpy.exp,
|
||||
numpy.exp2,
|
||||
numpy.expm1,
|
||||
numpy.fabs,
|
||||
# numpy.float_power,
|
||||
numpy.floor,
|
||||
# numpy.floor_divide,
|
||||
# numpy.fmax,
|
||||
# numpy.fmin,
|
||||
# numpy.fmod,
|
||||
# numpy.frexp,
|
||||
# numpy.gcd,
|
||||
# numpy.greater,
|
||||
# numpy.greater_equal,
|
||||
# numpy.heaviside,
|
||||
# numpy.hypot,
|
||||
# numpy.invert,
|
||||
# numpy.isfinite,
|
||||
# numpy.isinf,
|
||||
# numpy.isnan,
|
||||
# numpy.isnat,
|
||||
# numpy.lcm,
|
||||
# numpy.ldexp,
|
||||
# numpy.left_shift,
|
||||
# numpy.less,
|
||||
# numpy.less_equal,
|
||||
numpy.log,
|
||||
numpy.log10,
|
||||
numpy.log1p,
|
||||
numpy.log2,
|
||||
# numpy.logaddexp,
|
||||
# numpy.logaddexp2,
|
||||
# numpy.logical_and,
|
||||
# numpy.logical_not,
|
||||
# numpy.logical_or,
|
||||
# numpy.logical_xor,
|
||||
# numpy.matmul,
|
||||
# numpy.maximum,
|
||||
# numpy.minimum,
|
||||
# numpy.modf,
|
||||
# numpy.multiply,
|
||||
# numpy.negative,
|
||||
# numpy.nextafter,
|
||||
# numpy.not_equal,
|
||||
# numpy.positive,
|
||||
# numpy.power,
|
||||
numpy.rad2deg,
|
||||
numpy.radians,
|
||||
# numpy.reciprocal,
|
||||
# numpy.remainder,
|
||||
# numpy.right_shift,
|
||||
numpy.rint,
|
||||
# numpy.sign,
|
||||
# numpy.signbit,
|
||||
numpy.sin,
|
||||
numpy.sinh,
|
||||
numpy.spacing,
|
||||
numpy.sqrt,
|
||||
# numpy.square,
|
||||
# numpy.subtract,
|
||||
numpy.tan,
|
||||
numpy.tanh,
|
||||
# numpy.true_divide,
|
||||
numpy.trunc,
|
||||
]
|
||||
|
||||
for function_to_trace_def in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC:
|
||||
|
||||
# We really need a lambda (because numpy functions are not playing
|
||||
@@ -277,7 +364,11 @@ def test_trace_numpy_supported_ufuncs(inputs, expected_output_node, expected_out
|
||||
assert len(op_graph.output_nodes) == 1
|
||||
assert isinstance(op_graph.output_nodes[0], expected_output_node)
|
||||
assert len(op_graph.output_nodes[0].outputs) == 1
|
||||
assert op_graph.output_nodes[0].outputs[0] == expected_output_value
|
||||
|
||||
if function_to_trace_def in LIST_OF_UFUNC_WHOSE_OUTPUT_IS_FLOAT64:
|
||||
assert op_graph.output_nodes[0].outputs[0] == EncryptedScalar(Float(64))
|
||||
else:
|
||||
assert op_graph.output_nodes[0].outputs[0] == "to be done"
|
||||
|
||||
|
||||
def test_trace_numpy_ufuncs_not_supported():
|
||||
|
||||
Reference in New Issue
Block a user