diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index b855820f1..b554c30ee 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -231,38 +231,125 @@ def test_tracing_astype( @pytest.mark.parametrize( - "inputs,expected_output_node,expected_output_value", + "inputs,expected_output_node", [ pytest.param( {"x": EncryptedScalar(Integer(7, is_signed=False))}, ir.ArbitraryFunction, - EncryptedScalar(Float(64)), ), pytest.param( {"x": EncryptedScalar(Integer(32, is_signed=True))}, ir.ArbitraryFunction, - EncryptedScalar(Float(64)), ), pytest.param( {"x": EncryptedScalar(Integer(64, is_signed=True))}, ir.ArbitraryFunction, - EncryptedScalar(Float(64)), ), pytest.param( {"x": EncryptedScalar(Integer(128, is_signed=True))}, ir.ArbitraryFunction, - None, marks=pytest.mark.xfail(strict=True, raises=NotImplementedError), ), pytest.param( {"x": EncryptedScalar(Float(64))}, ir.ArbitraryFunction, - EncryptedScalar(Float(64)), ), ], ) -def test_trace_numpy_supported_ufuncs(inputs, expected_output_node, expected_output_value): +def test_trace_numpy_supported_ufuncs(inputs, expected_output_node): """Function to trace supported numpy ufuncs""" + + LIST_OF_UFUNC_WHOSE_OUTPUT_IS_FLOAT64: List[numpy.ufunc] = [ + # The commented functions are functions which don't work for the moment, often + # if not always because they require more than a single argument + # numpy.absolute, + # numpy.add, + numpy.arccos, + numpy.arccosh, + numpy.arcsin, + numpy.arcsinh, + numpy.arctan, + # numpy.arctan2, + numpy.arctanh, + # numpy.bitwise_and, + # numpy.bitwise_or, + # numpy.bitwise_xor, + numpy.cbrt, + numpy.ceil, + # numpy.conjugate, + # numpy.copysign, + numpy.cos, + numpy.cosh, + numpy.deg2rad, + numpy.degrees, + # numpy.divmod, + # numpy.equal, + numpy.exp, + numpy.exp2, + numpy.expm1, + numpy.fabs, + # numpy.float_power, + numpy.floor, + # numpy.floor_divide, + # numpy.fmax, + # numpy.fmin, + # numpy.fmod, + # numpy.frexp, + # numpy.gcd, + # numpy.greater, + # numpy.greater_equal, + # numpy.heaviside, + # numpy.hypot, + # numpy.invert, + # numpy.isfinite, + # numpy.isinf, + # numpy.isnan, + # numpy.isnat, + # numpy.lcm, + # numpy.ldexp, + # numpy.left_shift, + # numpy.less, + # numpy.less_equal, + numpy.log, + numpy.log10, + numpy.log1p, + numpy.log2, + # numpy.logaddexp, + # numpy.logaddexp2, + # numpy.logical_and, + # numpy.logical_not, + # numpy.logical_or, + # numpy.logical_xor, + # numpy.matmul, + # numpy.maximum, + # numpy.minimum, + # numpy.modf, + # numpy.multiply, + # numpy.negative, + # numpy.nextafter, + # numpy.not_equal, + # numpy.positive, + # numpy.power, + numpy.rad2deg, + numpy.radians, + # numpy.reciprocal, + # numpy.remainder, + # numpy.right_shift, + numpy.rint, + # numpy.sign, + # numpy.signbit, + numpy.sin, + numpy.sinh, + numpy.spacing, + numpy.sqrt, + # numpy.square, + # numpy.subtract, + numpy.tan, + numpy.tanh, + # numpy.true_divide, + numpy.trunc, + ] + for function_to_trace_def in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC: # We really need a lambda (because numpy functions are not playing @@ -277,7 +364,11 @@ def test_trace_numpy_supported_ufuncs(inputs, expected_output_node, expected_out assert len(op_graph.output_nodes) == 1 assert isinstance(op_graph.output_nodes[0], expected_output_node) assert len(op_graph.output_nodes[0].outputs) == 1 - assert op_graph.output_nodes[0].outputs[0] == expected_output_value + + if function_to_trace_def in LIST_OF_UFUNC_WHOSE_OUTPUT_IS_FLOAT64: + assert op_graph.output_nodes[0].outputs[0] == EncryptedScalar(Float(64)) + else: + assert op_graph.output_nodes[0].outputs[0] == "to be done" def test_trace_numpy_ufuncs_not_supported():