diff --git a/hdk/hnumpy/tracing.py b/hdk/hnumpy/tracing.py index fbd11b881..9ccf24b5b 100644 --- a/hdk/hnumpy/tracing.py +++ b/hdk/hnumpy/tracing.py @@ -130,22 +130,24 @@ class NPTracer(BaseTracer): ] return common_output_dtypes - def rint(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer": - """Function to trace numpy.rint. + def _unary_operator( + self, unary_operator, unary_operator_string, *input_tracers: "NPTracer", **kwargs + ) -> "NPTracer": + """Function to trace an unary operator. Returns: NPTracer: The output NPTracer containing the traced function """ assert len(input_tracers) == 1 - common_output_dtypes = self._manage_dtypes(numpy.rint, *input_tracers) + common_output_dtypes = self._manage_dtypes(unary_operator, *input_tracers) assert len(common_output_dtypes) == 1 traced_computation = ArbitraryFunction( input_base_value=input_tracers[0].output, - arbitrary_func=numpy.rint, + arbitrary_func=unary_operator, output_dtype=common_output_dtypes[0], op_kwargs=deepcopy(kwargs), - op_name="np.rint", + op_name=unary_operator_string, ) output_tracer = self.__class__( input_tracers, @@ -154,29 +156,29 @@ class NPTracer(BaseTracer): ) return output_tracer + def rint(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer": + """Function to trace numpy.rint. + + Returns: + NPTracer: The output NPTracer containing the traced function + """ + return self._unary_operator(numpy.rint, "np.rint", *input_tracers, **kwargs) + def sin(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer": """Function to trace numpy.sin. Returns: NPTracer: The output NPTracer containing the traced function """ - assert len(input_tracers) == 1 - common_output_dtypes = self._manage_dtypes(numpy.sin, *input_tracers) - assert len(common_output_dtypes) == 1 + return self._unary_operator(numpy.sin, "np.sin", *input_tracers, **kwargs) - traced_computation = ArbitraryFunction( - input_base_value=input_tracers[0].output, - arbitrary_func=numpy.sin, - output_dtype=common_output_dtypes[0], - op_kwargs=deepcopy(kwargs), - op_name="np.sin", - ) - output_tracer = self.__class__( - input_tracers, - traced_computation=traced_computation, - output_index=0, - ) - return output_tracer + def cos(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer": + """Function to trace numpy.cos. + + Returns: + NPTracer: The output NPTracer containing the traced function + """ + return self._unary_operator(numpy.cos, "np.cos", *input_tracers, **kwargs) def dot(self, other_tracer: "NPTracer", **_kwargs) -> "NPTracer": """Function to trace numpy.dot. @@ -206,6 +208,7 @@ class NPTracer(BaseTracer): UFUNC_ROUTING: Dict[numpy.ufunc, Callable] = { numpy.rint: rint, numpy.sin: sin, + numpy.cos: cos, } FUNC_ROUTING: Dict[Callable, Callable] = { diff --git a/tests/hnumpy/test_tracing.py b/tests/hnumpy/test_tracing.py index 5e8a76ff6..6fd5de1a8 100644 --- a/tests/hnumpy/test_tracing.py +++ b/tests/hnumpy/test_tracing.py @@ -238,6 +238,7 @@ def test_tracing_astype( # pylint: disable=unnecessary-lambda pytest.param(lambda x: numpy.rint(x)), pytest.param(lambda x: numpy.sin(x)), + pytest.param(lambda x: numpy.cos(x)), # The next test case is only for coverage purposes, to trigger the unsupported method # exception handling pytest.param( @@ -348,6 +349,7 @@ def test_trace_hnumpy_dot(function_to_trace, inputs, expected_output_node, expec [ pytest.param(numpy.rint, tracing.NPTracer.rint), pytest.param(numpy.sin, tracing.NPTracer.sin), + pytest.param(numpy.cos, tracing.NPTracer.cos), pytest.param(numpy.dot, tracing.NPTracer.dot), # There is a need to test the case where the function fails, I chose numpy.conjugate which # works on complex types, as we don't talk about complex types for now this looks like a