From d48c4dba32597eaef4e3a95411b8723e769ac77a Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Mon, 16 Aug 2021 12:45:52 +0200 Subject: [PATCH] dev(NPTracer): add support for sin - re-organize numpy tracing tests refs #126 --- hdk/hnumpy/tracing.py | 23 +++++++++++++++++++++++ tests/hnumpy/test_tracing.py | 32 ++++++++++++++++---------------- 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/hdk/hnumpy/tracing.py b/hdk/hnumpy/tracing.py index 1a4996061..74416b2c7 100644 --- a/hdk/hnumpy/tracing.py +++ b/hdk/hnumpy/tracing.py @@ -113,8 +113,31 @@ class NPTracer(BaseTracer): ) return output_tracer + def sin(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer": + """Function to trace numpy.sin. + + Returns: + NPTracer: The output NPTracer containing the traced function + """ + assert len(input_tracers) == 1 + common_output_dtypes = self._manage_dtypes(numpy.sin, *input_tracers) + assert len(common_output_dtypes) == 1 + + traced_computation = ir.ArbitraryFunction( + input_base_value=input_tracers[0].output, + arbitrary_func=numpy.sin, + output_dtype=common_output_dtypes[0], + op_kwargs=deepcopy(kwargs), + op_name="numpy.sin", + ) + output_tracer = self.__class__( + input_tracers, traced_computation=traced_computation, output_index=0 + ) + return output_tracer + UFUNC_ROUTING: Mapping[numpy.ufunc, Callable] = { numpy.rint: rint, + numpy.sin: sin, } diff --git a/tests/hnumpy/test_tracing.py b/tests/hnumpy/test_tracing.py index e361d2dd1..3ede097c9 100644 --- a/tests/hnumpy/test_tracing.py +++ b/tests/hnumpy/test_tracing.py @@ -205,52 +205,51 @@ def test_tracing_astype( @pytest.mark.parametrize( - "function_to_trace,inputs,expected_output_node,expected_output_value", + "function_to_trace", [ # We cannot call trace_numpy_function on some numpy function as getting the signature for # these functions fails, so we wrap it in a lambda # pylint: disable=unnecessary-lambda + pytest.param(lambda x: numpy.rint(x)), + pytest.param(lambda x: numpy.sin(x)), + # The next test case is only for coverage purposes, to trigger the unsupported method + # exception handling + pytest.param( + lambda x: numpy.add.reduce(x), + marks=pytest.mark.xfail(strict=True, raises=NotImplementedError), + ), + # pylint: enable=unnecessary-lambda + ], +) +@pytest.mark.parametrize( + "inputs,expected_output_node,expected_output_value", + [ pytest.param( - lambda x: numpy.rint(x), {"x": EncryptedValue(Integer(7, is_signed=False))}, ir.ArbitraryFunction, EncryptedValue(Float(64)), ), pytest.param( - lambda x: numpy.rint(x), {"x": EncryptedValue(Integer(32, is_signed=True))}, ir.ArbitraryFunction, EncryptedValue(Float(64)), ), pytest.param( - lambda x: numpy.rint(x), {"x": EncryptedValue(Integer(64, is_signed=True))}, ir.ArbitraryFunction, EncryptedValue(Float(64)), ), pytest.param( - lambda x: numpy.rint(x), {"x": EncryptedValue(Integer(128, is_signed=True))}, ir.ArbitraryFunction, None, marks=pytest.mark.xfail(strict=True, raises=NotImplementedError), ), pytest.param( - lambda x: numpy.rint(x), {"x": EncryptedValue(Float(64))}, ir.ArbitraryFunction, EncryptedValue(Float(64)), ), - # The next test case is only for coverage purposes, to trigger the unsupported method - # exception handling - pytest.param( - lambda x: numpy.add.reduce(x), - {"x": EncryptedValue(Integer(32, is_signed=True))}, - None, - None, - marks=pytest.mark.xfail(strict=True, raises=NotImplementedError), - ), - # pylint: enable=unnecessary-lambda ], ) def test_trace_hnumpy_supported_ufuncs( @@ -269,6 +268,7 @@ def test_trace_hnumpy_supported_ufuncs( "np_ufunc,expected_tracing_func", [ pytest.param(numpy.rint, tracing.NPTracer.rint), + pytest.param(numpy.sin, tracing.NPTracer.sin), # There is a need to test the case where the function fails, I chose numpy.conjugate which # works on complex types, as we don't talk about complex types for now this looks like a # good long term candidate to check for an unsupported function