feat: adding arcsin, arccos, arctan in the managed operators

refs #126
closes #257
This commit is contained in:
Benoit Chevallier-Mames
2021-08-31 16:12:11 +02:00
committed by Benoit Chevallier
parent 4c77f08854
commit e90df9c0b7
2 changed files with 33 additions and 0 deletions

View File

@@ -188,6 +188,30 @@ class NPTracer(BaseTracer):
"""
return self._unary_operator(numpy.tan, "np.tan", *input_tracers, **kwargs)
def arcsin(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
"""Function to trace numpy.arcsin.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
return self._unary_operator(numpy.arcsin, "np.arcsin", *input_tracers, **kwargs)
def arccos(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
"""Function to trace numpy.arccos.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
return self._unary_operator(numpy.arccos, "np.arccos", *input_tracers, **kwargs)
def arctan(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
"""Function to trace numpy.arctan.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
return self._unary_operator(numpy.arctan, "np.arctan", *input_tracers, **kwargs)
def dot(self, other_tracer: "NPTracer", **_kwargs) -> "NPTracer":
"""Function to trace numpy.dot.
@@ -218,6 +242,9 @@ class NPTracer(BaseTracer):
numpy.sin: sin,
numpy.cos: cos,
numpy.tan: tan,
numpy.arcsin: arcsin,
numpy.arccos: arccos,
numpy.arctan: arctan,
}
FUNC_ROUTING: Dict[Callable, Callable] = {

View File

@@ -240,6 +240,9 @@ def test_tracing_astype(
pytest.param(lambda x: numpy.sin(x)),
pytest.param(lambda x: numpy.cos(x)),
pytest.param(lambda x: numpy.tan(x)),
pytest.param(lambda x: numpy.arcsin(x)),
pytest.param(lambda x: numpy.arccos(x)),
pytest.param(lambda x: numpy.arctan(x)),
# The next test case is only for coverage purposes, to trigger the unsupported method
# exception handling
pytest.param(
@@ -352,6 +355,9 @@ def test_trace_hnumpy_dot(function_to_trace, inputs, expected_output_node, expec
pytest.param(numpy.sin, tracing.NPTracer.sin),
pytest.param(numpy.cos, tracing.NPTracer.cos),
pytest.param(numpy.tan, tracing.NPTracer.tan),
pytest.param(numpy.arcsin, tracing.NPTracer.arcsin),
pytest.param(numpy.arccos, tracing.NPTracer.arccos),
pytest.param(numpy.arctan, tracing.NPTracer.arctan),
pytest.param(numpy.dot, tracing.NPTracer.dot),
# There is a need to test the case where the function fails, I chose numpy.conjugate which
# works on complex types, as we don't talk about complex types for now this looks like a