mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
dev(NPTracer): add support for sin
- re-organize numpy tracing tests refs #126
This commit is contained in:
@@ -113,8 +113,31 @@ class NPTracer(BaseTracer):
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
def sin(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
|
||||
"""Function to trace numpy.sin.
|
||||
|
||||
Returns:
|
||||
NPTracer: The output NPTracer containing the traced function
|
||||
"""
|
||||
assert len(input_tracers) == 1
|
||||
common_output_dtypes = self._manage_dtypes(numpy.sin, *input_tracers)
|
||||
assert len(common_output_dtypes) == 1
|
||||
|
||||
traced_computation = ir.ArbitraryFunction(
|
||||
input_base_value=input_tracers[0].output,
|
||||
arbitrary_func=numpy.sin,
|
||||
output_dtype=common_output_dtypes[0],
|
||||
op_kwargs=deepcopy(kwargs),
|
||||
op_name="numpy.sin",
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
input_tracers, traced_computation=traced_computation, output_index=0
|
||||
)
|
||||
return output_tracer
|
||||
|
||||
UFUNC_ROUTING: Mapping[numpy.ufunc, Callable] = {
|
||||
numpy.rint: rint,
|
||||
numpy.sin: sin,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -205,52 +205,51 @@ def test_tracing_astype(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,inputs,expected_output_node,expected_output_value",
|
||||
"function_to_trace",
|
||||
[
|
||||
# We cannot call trace_numpy_function on some numpy function as getting the signature for
|
||||
# these functions fails, so we wrap it in a lambda
|
||||
# pylint: disable=unnecessary-lambda
|
||||
pytest.param(lambda x: numpy.rint(x)),
|
||||
pytest.param(lambda x: numpy.sin(x)),
|
||||
# The next test case is only for coverage purposes, to trigger the unsupported method
|
||||
# exception handling
|
||||
pytest.param(
|
||||
lambda x: numpy.add.reduce(x),
|
||||
marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
|
||||
),
|
||||
# pylint: enable=unnecessary-lambda
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"inputs,expected_output_node,expected_output_value",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: numpy.rint(x),
|
||||
{"x": EncryptedValue(Integer(7, is_signed=False))},
|
||||
ir.ArbitraryFunction,
|
||||
EncryptedValue(Float(64)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.rint(x),
|
||||
{"x": EncryptedValue(Integer(32, is_signed=True))},
|
||||
ir.ArbitraryFunction,
|
||||
EncryptedValue(Float(64)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.rint(x),
|
||||
{"x": EncryptedValue(Integer(64, is_signed=True))},
|
||||
ir.ArbitraryFunction,
|
||||
EncryptedValue(Float(64)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.rint(x),
|
||||
{"x": EncryptedValue(Integer(128, is_signed=True))},
|
||||
ir.ArbitraryFunction,
|
||||
None,
|
||||
marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.rint(x),
|
||||
{"x": EncryptedValue(Float(64))},
|
||||
ir.ArbitraryFunction,
|
||||
EncryptedValue(Float(64)),
|
||||
),
|
||||
# The next test case is only for coverage purposes, to trigger the unsupported method
|
||||
# exception handling
|
||||
pytest.param(
|
||||
lambda x: numpy.add.reduce(x),
|
||||
{"x": EncryptedValue(Integer(32, is_signed=True))},
|
||||
None,
|
||||
None,
|
||||
marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
|
||||
),
|
||||
# pylint: enable=unnecessary-lambda
|
||||
],
|
||||
)
|
||||
def test_trace_hnumpy_supported_ufuncs(
|
||||
@@ -269,6 +268,7 @@ def test_trace_hnumpy_supported_ufuncs(
|
||||
"np_ufunc,expected_tracing_func",
|
||||
[
|
||||
pytest.param(numpy.rint, tracing.NPTracer.rint),
|
||||
pytest.param(numpy.sin, tracing.NPTracer.sin),
|
||||
# There is a need to test the case where the function fails, I chose numpy.conjugate which
|
||||
# works on complex types, as we don't talk about complex types for now this looks like a
|
||||
# good long term candidate to check for an unsupported function
|
||||
|
||||
Reference in New Issue
Block a user