diff --git a/hdk/hnumpy/tracing.py b/hdk/hnumpy/tracing.py index 9fd8b86d1..a5ff23013 100644 --- a/hdk/hnumpy/tracing.py +++ b/hdk/hnumpy/tracing.py @@ -212,6 +212,30 @@ class NPTracer(BaseTracer): """ return self._unary_operator(numpy.arctan, "np.arctan", *input_tracers, **kwargs) + def exp(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer": + """Function to trace numpy.exp. + + Returns: + NPTracer: The output NPTracer containing the traced function + """ + return self._unary_operator(numpy.exp, "np.exp", *input_tracers, **kwargs) + + def expm1(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer": + """Function to trace numpy.expm1. + + Returns: + NPTracer: The output NPTracer containing the traced function + """ + return self._unary_operator(numpy.expm1, "np.expm1", *input_tracers, **kwargs) + + def exp2(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer": + """Function to trace numpy.exp2. + + Returns: + NPTracer: The output NPTracer containing the traced function + """ + return self._unary_operator(numpy.exp2, "np.exp2", *input_tracers, **kwargs) + def dot(self, other_tracer: "NPTracer", **_kwargs) -> "NPTracer": """Function to trace numpy.dot. @@ -245,6 +269,9 @@ class NPTracer(BaseTracer): numpy.arcsin: arcsin, numpy.arccos: arccos, numpy.arctan: arctan, + numpy.exp: exp, + numpy.expm1: expm1, + numpy.exp2: exp2, } FUNC_ROUTING: Dict[Callable, Callable] = { diff --git a/tests/hnumpy/test_tracing.py b/tests/hnumpy/test_tracing.py index cb8b96985..7bc4a4538 100644 --- a/tests/hnumpy/test_tracing.py +++ b/tests/hnumpy/test_tracing.py @@ -243,6 +243,9 @@ def test_tracing_astype( pytest.param(lambda x: numpy.arcsin(x)), pytest.param(lambda x: numpy.arccos(x)), pytest.param(lambda x: numpy.arctan(x)), + pytest.param(lambda x: numpy.exp(x)), + pytest.param(lambda x: numpy.expm1(x)), + pytest.param(lambda x: numpy.exp2(x)), # The next test case is only for coverage purposes, to trigger the unsupported method # exception handling pytest.param( @@ -358,6 +361,9 @@ def test_trace_hnumpy_dot(function_to_trace, inputs, expected_output_node, expec pytest.param(numpy.arcsin, tracing.NPTracer.arcsin), pytest.param(numpy.arccos, tracing.NPTracer.arccos), pytest.param(numpy.arctan, tracing.NPTracer.arctan), + pytest.param(numpy.exp, tracing.NPTracer.exp), + pytest.param(numpy.expm1, tracing.NPTracer.expm1), + pytest.param(numpy.exp2, tracing.NPTracer.exp2), pytest.param(numpy.dot, tracing.NPTracer.dot), # There is a need to test the case where the function fails, I chose numpy.conjugate which # works on complex types, as we don't talk about complex types for now this looks like a