diff --git a/hdk/hnumpy/tracing.py b/hdk/hnumpy/tracing.py index bb31aa746..4776278fb 100644 --- a/hdk/hnumpy/tracing.py +++ b/hdk/hnumpy/tracing.py @@ -1,7 +1,7 @@ """hnumpy tracing utilities.""" from copy import deepcopy from functools import partial -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import numpy from numpy.typing import DTypeLike @@ -43,7 +43,7 @@ class NPTracer(BaseTracer): assert ( len(kwargs) == 0 ), f"hnumpy does not support **kwargs currently for numpy ufuncs, ufunc: {ufunc}" - return tracing_func(self, *input_tracers, **kwargs) + return tracing_func(*input_tracers, **kwargs) raise NotImplementedError("Only __call__ method is supported currently") def __array_function__(self, func, _types, args, kwargs): @@ -130,8 +130,9 @@ class NPTracer(BaseTracer): ] return common_output_dtypes + @classmethod def _unary_operator( - self, unary_operator, unary_operator_string, *input_tracers: "NPTracer", **kwargs + cls, unary_operator, unary_operator_string, *input_tracers: "NPTracer", **kwargs ) -> "NPTracer": """Function to trace an unary operator. @@ -139,7 +140,7 @@ class NPTracer(BaseTracer): NPTracer: The output NPTracer containing the traced function """ assert len(input_tracers) == 1 - common_output_dtypes = self._manage_dtypes(unary_operator, *input_tracers) + common_output_dtypes = cls._manage_dtypes(unary_operator, *input_tracers) assert len(common_output_dtypes) == 1 traced_computation = ArbitraryFunction( @@ -149,93 +150,13 @@ class NPTracer(BaseTracer): op_kwargs=deepcopy(kwargs), op_name=unary_operator_string, ) - output_tracer = self.__class__( + output_tracer = cls( input_tracers, traced_computation=traced_computation, output_index=0, ) return output_tracer - def rint(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer": - """Function to trace numpy.rint. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - return self._unary_operator(numpy.rint, "np.rint", *input_tracers, **kwargs) - - def sin(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer": - """Function to trace numpy.sin. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - return self._unary_operator(numpy.sin, "np.sin", *input_tracers, **kwargs) - - def cos(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer": - """Function to trace numpy.cos. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - return self._unary_operator(numpy.cos, "np.cos", *input_tracers, **kwargs) - - def tan(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer": - """Function to trace numpy.tan. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - return self._unary_operator(numpy.tan, "np.tan", *input_tracers, **kwargs) - - def arcsin(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer": - """Function to trace numpy.arcsin. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - return self._unary_operator(numpy.arcsin, "np.arcsin", *input_tracers, **kwargs) - - def arccos(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer": - """Function to trace numpy.arccos. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - return self._unary_operator(numpy.arccos, "np.arccos", *input_tracers, **kwargs) - - def arctan(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer": - """Function to trace numpy.arctan. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - return self._unary_operator(numpy.arctan, "np.arctan", *input_tracers, **kwargs) - - def exp(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer": - """Function to trace numpy.exp. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - return self._unary_operator(numpy.exp, "np.exp", *input_tracers, **kwargs) - - def expm1(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer": - """Function to trace numpy.expm1. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - return self._unary_operator(numpy.expm1, "np.expm1", *input_tracers, **kwargs) - - def exp2(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer": - """Function to trace numpy.exp2. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - return self._unary_operator(numpy.exp2, "np.exp2", *input_tracers, **kwargs) - def dot(self, other_tracer: "NPTracer", **_kwargs) -> "NPTracer": """Function to trace numpy.dot. @@ -261,24 +182,124 @@ class NPTracer(BaseTracer): ) return output_tracer - UFUNC_ROUTING: Dict[numpy.ufunc, Callable] = { - numpy.rint: rint, - numpy.sin: sin, - numpy.cos: cos, - numpy.tan: tan, - numpy.arcsin: arcsin, - numpy.arccos: arccos, - numpy.arctan: arctan, - numpy.exp: exp, - numpy.expm1: expm1, - numpy.exp2: exp2, - } + LIST_OF_SUPPORTED_UFUNC: List[numpy.ufunc] = [ + # The commented functions are functions which don't work for the moment, often + # if not always because they require more than a single argument + # numpy.absolute, + # numpy.add, + numpy.arccos, + numpy.arccosh, + numpy.arcsin, + numpy.arcsinh, + numpy.arctan, + # numpy.arctan2, + numpy.arctanh, + # numpy.bitwise_and, + # numpy.bitwise_or, + # numpy.bitwise_xor, + numpy.cbrt, + numpy.ceil, + # numpy.conjugate, + # numpy.copysign, + numpy.cos, + numpy.cosh, + numpy.deg2rad, + numpy.degrees, + # numpy.divmod, + # numpy.equal, + numpy.exp, + numpy.exp2, + numpy.expm1, + numpy.fabs, + # numpy.float_power, + numpy.floor, + # numpy.floor_divide, + # numpy.fmax, + # numpy.fmin, + # numpy.fmod, + # numpy.frexp, + # numpy.gcd, + # numpy.greater, + # numpy.greater_equal, + # numpy.heaviside, + # numpy.hypot, + # numpy.invert, + # numpy.isfinite, + # numpy.isinf, + # numpy.isnan, + # numpy.isnat, + # numpy.lcm, + # numpy.ldexp, + # numpy.left_shift, + # numpy.less, + # numpy.less_equal, + numpy.log, + numpy.log10, + numpy.log1p, + numpy.log2, + # numpy.logaddexp, + # numpy.logaddexp2, + # numpy.logical_and, + # numpy.logical_not, + # numpy.logical_or, + # numpy.logical_xor, + # numpy.matmul, + # numpy.maximum, + # numpy.minimum, + # numpy.modf, + # numpy.multiply, + # numpy.negative, + # numpy.nextafter, + # numpy.not_equal, + # numpy.positive, + # numpy.power, + numpy.rad2deg, + numpy.radians, + # numpy.reciprocal, + # numpy.remainder, + # numpy.right_shift, + numpy.rint, + # numpy.sign, + # numpy.signbit, + numpy.sin, + numpy.sinh, + numpy.spacing, + numpy.sqrt, + # numpy.square, + # numpy.subtract, + numpy.tan, + numpy.tanh, + # numpy.true_divide, + numpy.trunc, + ] + + # We build UFUNC_ROUTING dynamically after the creation of the class, + # because of some limits of python or our unability to do it properly + # in the class with techniques which are compatible with the different + # coding checks we use + UFUNC_ROUTING: Dict[numpy.ufunc, Callable] = {} FUNC_ROUTING: Dict[Callable, Callable] = { numpy.dot: dot, } +def _get_fun(function: numpy.ufunc): + """Helper function to wrap _unary_operator in a lambda to populate NPTRACER.UFUNC_ROUTING.""" + + # We have to access this method to be able to build NPTracer.UFUNC_ROUTING + # dynamically + # pylint: disable=protected-access + return lambda *input_tracers, **kwargs: NPTracer._unary_operator( + function, f"np.{function.__name__}", *input_tracers, **kwargs + ) + # pylint: enable=protected-access + + +# We are populating NPTracer.UFUNC_ROUTING dynamically +NPTracer.UFUNC_ROUTING = {fun: _get_fun(fun) for fun in NPTracer.LIST_OF_SUPPORTED_UFUNC} + + def trace_numpy_function( function_to_trace: Callable, function_parameters: Dict[str, BaseValue] ) -> OPGraph: diff --git a/tests/hnumpy/test_tracing.py b/tests/hnumpy/test_tracing.py index 32d8e822f..a16223260 100644 --- a/tests/hnumpy/test_tracing.py +++ b/tests/hnumpy/test_tracing.py @@ -230,31 +230,6 @@ def test_tracing_astype( assert expected_output == evaluated_output -@pytest.mark.parametrize( - "function_to_trace", - [ - # We cannot call trace_numpy_function on some numpy function as getting the signature for - # these functions fails, so we wrap it in a lambda - # pylint: disable=unnecessary-lambda - pytest.param(lambda x: numpy.rint(x)), - pytest.param(lambda x: numpy.sin(x)), - pytest.param(lambda x: numpy.cos(x)), - pytest.param(lambda x: numpy.tan(x)), - pytest.param(lambda x: numpy.arcsin(x)), - pytest.param(lambda x: numpy.arccos(x)), - pytest.param(lambda x: numpy.arctan(x)), - pytest.param(lambda x: numpy.exp(x)), - pytest.param(lambda x: numpy.expm1(x)), - pytest.param(lambda x: numpy.exp2(x)), - # The next test case is only for coverage purposes, to trigger the unsupported method - # exception handling - pytest.param( - lambda x: numpy.add.reduce(x), - marks=pytest.mark.xfail(strict=True, raises=NotImplementedError), - ), - # pylint: enable=unnecessary-lambda - ], -) @pytest.mark.parametrize( "inputs,expected_output_node,expected_output_value", [ @@ -286,16 +261,40 @@ def test_tracing_astype( ), ], ) -def test_trace_hnumpy_supported_ufuncs( - function_to_trace, inputs, expected_output_node, expected_output_value -): +def test_trace_hnumpy_supported_ufuncs(inputs, expected_output_node, expected_output_value): """Function to trace supported numpy ufuncs""" - op_graph = tracing.trace_numpy_function(function_to_trace, inputs) + for function_to_trace_def in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC: - assert len(op_graph.output_nodes) == 1 - assert isinstance(op_graph.output_nodes[0], expected_output_node) - assert len(op_graph.output_nodes[0].outputs) == 1 - assert op_graph.output_nodes[0].outputs[0] == expected_output_value + # We really need a lambda (because numpy functions are not playing + # nice with inspect.signature), but pylint and flake8 are not happy + # with it + # pylint: disable=unnecessary-lambda,cell-var-from-loop + function_to_trace = lambda x: function_to_trace_def(x) # noqa: E731 + # pylint: enable=unnecessary-lambda,cell-var-from-loop + + op_graph = tracing.trace_numpy_function(function_to_trace, inputs) + + assert len(op_graph.output_nodes) == 1 + assert isinstance(op_graph.output_nodes[0], expected_output_node) + assert len(op_graph.output_nodes[0].outputs) == 1 + assert op_graph.output_nodes[0].outputs[0] == expected_output_value + + +def test_trace_hnumpy_ufuncs_not_supported(): + """Testing a failure case of trace_numpy_function""" + inputs = {"x": EncryptedScalar(Integer(128, is_signed=True))} + + # We really need a lambda (because numpy functions are not playing + # nice with inspect.signature), but pylint and flake8 are not happy + # with it + # pylint: disable=unnecessary-lambda + function_to_trace = lambda x: numpy.add.reduce(x) # noqa: E731 + # pylint: enable=unnecessary-lambda + + with pytest.raises(NotImplementedError) as excinfo: + tracing.trace_numpy_function(function_to_trace, inputs) + + assert "Only __call__ method is supported currently" in str(excinfo.value) @pytest.mark.parametrize( @@ -351,31 +350,23 @@ def test_trace_hnumpy_dot(function_to_trace, inputs, expected_output_node, expec assert op_graph.output_nodes[0].outputs[0] == expected_output_value -@pytest.mark.parametrize( - "np_function,expected_tracing_func", - [ - pytest.param(numpy.rint, tracing.NPTracer.rint), - pytest.param(numpy.sin, tracing.NPTracer.sin), - pytest.param(numpy.cos, tracing.NPTracer.cos), - pytest.param(numpy.tan, tracing.NPTracer.tan), - pytest.param(numpy.arcsin, tracing.NPTracer.arcsin), - pytest.param(numpy.arccos, tracing.NPTracer.arccos), - pytest.param(numpy.arctan, tracing.NPTracer.arctan), - pytest.param(numpy.exp, tracing.NPTracer.exp), - pytest.param(numpy.expm1, tracing.NPTracer.expm1), - pytest.param(numpy.exp2, tracing.NPTracer.exp2), - pytest.param(numpy.dot, tracing.NPTracer.dot), - # There is a need to test the case where the function fails, I chose numpy.conjugate which - # works on complex types, as we don't talk about complex types for now this looks like a - # good long term candidate to check for an unsupported function - pytest.param( - numpy.conjugate, None, marks=pytest.mark.xfail(strict=True, raises=NotImplementedError) - ), - ], -) -def test_nptracer_get_tracing_func_for_np_functions(np_function, expected_tracing_func): +def test_nptracer_get_tracing_func_for_np_functions(): """Test NPTracer get_tracing_func_for_np_function""" - assert tracing.NPTracer.get_tracing_func_for_np_function(np_function) == expected_tracing_func + + for np_function in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC: + expected_tracing_func = tracing.NPTracer.UFUNC_ROUTING[np_function] + + assert ( + tracing.NPTracer.get_tracing_func_for_np_function(np_function) == expected_tracing_func + ) + + +def test_nptracer_get_tracing_func_for_np_functions_not_implemented(): + """Check NPTracer in case of not-implemented function""" + with pytest.raises(NotImplementedError) as excinfo: + tracing.NPTracer.get_tracing_func_for_np_function(numpy.conjugate) + + assert "NPTracer does not yet manage the following func: conjugate" in str(excinfo.value) @pytest.mark.parametrize(