dev(NPTracer): add support for sin

- re-organize numpy tracing tests

refs #126
This commit is contained in:
Arthur Meyre
2021-08-16 12:45:52 +02:00
parent 825d6422d0
commit d48c4dba32
2 changed files with 39 additions and 16 deletions

View File

@@ -113,8 +113,31 @@ class NPTracer(BaseTracer):
)
return output_tracer
def sin(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
"""Function to trace numpy.sin.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
assert len(input_tracers) == 1
common_output_dtypes = self._manage_dtypes(numpy.sin, *input_tracers)
assert len(common_output_dtypes) == 1
traced_computation = ir.ArbitraryFunction(
input_base_value=input_tracers[0].output,
arbitrary_func=numpy.sin,
output_dtype=common_output_dtypes[0],
op_kwargs=deepcopy(kwargs),
op_name="numpy.sin",
)
output_tracer = self.__class__(
input_tracers, traced_computation=traced_computation, output_index=0
)
return output_tracer
UFUNC_ROUTING: Mapping[numpy.ufunc, Callable] = {
numpy.rint: rint,
numpy.sin: sin,
}

View File

@@ -205,52 +205,51 @@ def test_tracing_astype(
@pytest.mark.parametrize(
"function_to_trace,inputs,expected_output_node,expected_output_value",
"function_to_trace",
[
# We cannot call trace_numpy_function on some numpy function as getting the signature for
# these functions fails, so we wrap it in a lambda
# pylint: disable=unnecessary-lambda
pytest.param(lambda x: numpy.rint(x)),
pytest.param(lambda x: numpy.sin(x)),
# The next test case is only for coverage purposes, to trigger the unsupported method
# exception handling
pytest.param(
lambda x: numpy.add.reduce(x),
marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
),
# pylint: enable=unnecessary-lambda
],
)
@pytest.mark.parametrize(
"inputs,expected_output_node,expected_output_value",
[
pytest.param(
lambda x: numpy.rint(x),
{"x": EncryptedValue(Integer(7, is_signed=False))},
ir.ArbitraryFunction,
EncryptedValue(Float(64)),
),
pytest.param(
lambda x: numpy.rint(x),
{"x": EncryptedValue(Integer(32, is_signed=True))},
ir.ArbitraryFunction,
EncryptedValue(Float(64)),
),
pytest.param(
lambda x: numpy.rint(x),
{"x": EncryptedValue(Integer(64, is_signed=True))},
ir.ArbitraryFunction,
EncryptedValue(Float(64)),
),
pytest.param(
lambda x: numpy.rint(x),
{"x": EncryptedValue(Integer(128, is_signed=True))},
ir.ArbitraryFunction,
None,
marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
),
pytest.param(
lambda x: numpy.rint(x),
{"x": EncryptedValue(Float(64))},
ir.ArbitraryFunction,
EncryptedValue(Float(64)),
),
# The next test case is only for coverage purposes, to trigger the unsupported method
# exception handling
pytest.param(
lambda x: numpy.add.reduce(x),
{"x": EncryptedValue(Integer(32, is_signed=True))},
None,
None,
marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
),
# pylint: enable=unnecessary-lambda
],
)
def test_trace_hnumpy_supported_ufuncs(
@@ -269,6 +268,7 @@ def test_trace_hnumpy_supported_ufuncs(
"np_ufunc,expected_tracing_func",
[
pytest.param(numpy.rint, tracing.NPTracer.rint),
pytest.param(numpy.sin, tracing.NPTracer.sin),
# There is a need to test the case where the function fails, I chose numpy.conjugate which
# works on complex types, as we don't talk about complex types for now this looks like a
# good long term candidate to check for an unsupported function