mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: adding exp functions in the managed operators
refs #126 closes #260
This commit is contained in:
committed by
Benoit Chevallier
parent
e90df9c0b7
commit
0d84f8c5f5
@@ -212,6 +212,30 @@ class NPTracer(BaseTracer):
|
||||
"""
|
||||
return self._unary_operator(numpy.arctan, "np.arctan", *input_tracers, **kwargs)
|
||||
|
||||
def exp(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.exp.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
return self._unary_operator(numpy.exp, "np.exp", *input_tracers, **kwargs)
|
||||
|
||||
def expm1(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.expm1.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
return self._unary_operator(numpy.expm1, "np.expm1", *input_tracers, **kwargs)
|
||||
|
||||
def exp2(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.exp2.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
return self._unary_operator(numpy.exp2, "np.exp2", *input_tracers, **kwargs)
|
||||
|
||||
def dot(self, other_tracer: "NPTracer", **_kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.dot.
|
||||
|
||||
@@ -245,6 +269,9 @@ class NPTracer(BaseTracer):
|
||||
numpy.arcsin: arcsin,
|
||||
numpy.arccos: arccos,
|
||||
numpy.arctan: arctan,
|
||||
numpy.exp: exp,
|
||||
numpy.expm1: expm1,
|
||||
numpy.exp2: exp2,
|
||||
}
|
||||
|
||||
FUNC_ROUTING: Dict[Callable, Callable] = {
|
||||
|
||||
@@ -243,6 +243,9 @@ def test_tracing_astype(
|
||||
pytest.param(lambda x: numpy.arcsin(x)),
|
||||
pytest.param(lambda x: numpy.arccos(x)),
|
||||
pytest.param(lambda x: numpy.arctan(x)),
|
||||
pytest.param(lambda x: numpy.exp(x)),
|
||||
pytest.param(lambda x: numpy.expm1(x)),
|
||||
pytest.param(lambda x: numpy.exp2(x)),
|
||||
# The next test case is only for coverage purposes, to trigger the unsupported method
|
||||
# exception handling
|
||||
pytest.param(
|
||||
@@ -358,6 +361,9 @@ def test_trace_hnumpy_dot(function_to_trace, inputs, expected_output_node, expec
|
||||
pytest.param(numpy.arcsin, tracing.NPTracer.arcsin),
|
||||
pytest.param(numpy.arccos, tracing.NPTracer.arccos),
|
||||
pytest.param(numpy.arctan, tracing.NPTracer.arctan),
|
||||
pytest.param(numpy.exp, tracing.NPTracer.exp),
|
||||
pytest.param(numpy.expm1, tracing.NPTracer.expm1),
|
||||
pytest.param(numpy.exp2, tracing.NPTracer.exp2),
|
||||
pytest.param(numpy.dot, tracing.NPTracer.dot),
|
||||
# There is a need to test the case where the function fails, I chose numpy.conjugate which
|
||||
# works on complex types, as we don't talk about complex types for now this looks like a
|
||||
|
||||
Reference in New Issue
Block a user