mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
committed by
Benoit Chevallier
parent
758a5727dc
commit
802f7943b1
@@ -130,22 +130,24 @@ class NPTracer(BaseTracer):
|
||||
]
|
||||
return common_output_dtypes
|
||||
|
||||
def rint(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.rint.
|
||||
def _unary_operator(
|
||||
self, unary_operator, unary_operator_string, *input_tracers: "NPTracer", **kwargs
|
||||
) -> "NPTracer":
|
||||
"""Function to trace an unary operator.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
assert len(input_tracers) == 1
|
||||
common_output_dtypes = self._manage_dtypes(numpy.rint, *input_tracers)
|
||||
common_output_dtypes = self._manage_dtypes(unary_operator, *input_tracers)
|
||||
assert len(common_output_dtypes) == 1
|
||||
|
||||
traced_computation = ArbitraryFunction(
|
||||
input_base_value=input_tracers[0].output,
|
||||
arbitrary_func=numpy.rint,
|
||||
arbitrary_func=unary_operator,
|
||||
output_dtype=common_output_dtypes[0],
|
||||
op_kwargs=deepcopy(kwargs),
|
||||
op_name="np.rint",
|
||||
op_name=unary_operator_string,
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
input_tracers,
|
||||
@@ -154,29 +156,29 @@ class NPTracer(BaseTracer):
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def rint(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.rint.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
return self._unary_operator(numpy.rint, "np.rint", *input_tracers, **kwargs)
|
||||
|
||||
def sin(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.sin.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
assert len(input_tracers) == 1
|
||||
common_output_dtypes = self._manage_dtypes(numpy.sin, *input_tracers)
|
||||
assert len(common_output_dtypes) == 1
|
||||
return self._unary_operator(numpy.sin, "np.sin", *input_tracers, **kwargs)
|
||||
|
||||
traced_computation = ArbitraryFunction(
|
||||
input_base_value=input_tracers[0].output,
|
||||
arbitrary_func=numpy.sin,
|
||||
output_dtype=common_output_dtypes[0],
|
||||
op_kwargs=deepcopy(kwargs),
|
||||
op_name="np.sin",
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
input_tracers,
|
||||
traced_computation=traced_computation,
|
||||
output_index=0,
|
||||
)
|
||||
return output_tracer
|
||||
def cos(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.cos.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
return self._unary_operator(numpy.cos, "np.cos", *input_tracers, **kwargs)
|
||||
|
||||
def dot(self, other_tracer: "NPTracer", **_kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.dot.
|
||||
@@ -206,6 +208,7 @@ class NPTracer(BaseTracer):
|
||||
UFUNC_ROUTING: Dict[numpy.ufunc, Callable] = {
|
||||
numpy.rint: rint,
|
||||
numpy.sin: sin,
|
||||
numpy.cos: cos,
|
||||
}
|
||||
|
||||
FUNC_ROUTING: Dict[Callable, Callable] = {
|
||||
|
||||
@@ -238,6 +238,7 @@ def test_tracing_astype(
|
||||
# pylint: disable=unnecessary-lambda
|
||||
pytest.param(lambda x: numpy.rint(x)),
|
||||
pytest.param(lambda x: numpy.sin(x)),
|
||||
pytest.param(lambda x: numpy.cos(x)),
|
||||
# The next test case is only for coverage purposes, to trigger the unsupported method
|
||||
# exception handling
|
||||
pytest.param(
|
||||
@@ -348,6 +349,7 @@ def test_trace_hnumpy_dot(function_to_trace, inputs, expected_output_node, expec
|
||||
[
|
||||
pytest.param(numpy.rint, tracing.NPTracer.rint),
|
||||
pytest.param(numpy.sin, tracing.NPTracer.sin),
|
||||
pytest.param(numpy.cos, tracing.NPTracer.cos),
|
||||
pytest.param(numpy.dot, tracing.NPTracer.dot),
|
||||
# There is a need to test the case where the function fails, I chose numpy.conjugate which
|
||||
# works on complex types, as we don't talk about complex types for now this looks like a
|
||||
|
||||
Reference in New Issue
Block a user