From ff260b2cd23bdb22b4c6e7fd3b0bbf30e64d862c Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Tue, 24 Aug 2021 11:52:35 +0200 Subject: [PATCH] feat(nptracer): add dot tracing abilities - remove no cover from Dot.label - small refactor of BaseTracer to make _sanitize a class method - small refactor of get_ufunc_numpy_output_dtype to manage funcs and ufuncs - add function routing to NPTracer - add dot tracing to NPTracer - small refactor to get tracing functions for numpy funcs and ufuncs --- hdk/common/representation/intermediate.py | 3 +- hdk/common/tracing/base_tracer.py | 13 ++-- hdk/hnumpy/np_dtypes_helpers.py | 26 ++++---- hdk/hnumpy/tracing.py | 79 ++++++++++++++++++----- tests/hnumpy/test_debugging.py | 33 +++++++++- tests/hnumpy/test_tracing.py | 62 ++++++++++++++++-- 6 files changed, 176 insertions(+), 40 deletions(-) diff --git a/hdk/common/representation/intermediate.py b/hdk/common/representation/intermediate.py index 2e563b301..d0cfef740 100644 --- a/hdk/common/representation/intermediate.py +++ b/hdk/common/representation/intermediate.py @@ -351,6 +351,5 @@ class Dot(IntermediateNode): and super().is_equivalent_to(other) ) - # TODO: Coverage will come with the ability to trace the operator in a subsequent PR - def label(self) -> str: # pragma: no cover + def label(self) -> str: return "dot" diff --git a/hdk/common/tracing/base_tracer.py b/hdk/common/tracing/base_tracer.py index 689d88f9a..851362802 100644 --- a/hdk/common/tracing/base_tracer.py +++ b/hdk/common/tracing/base_tracer.py @@ -53,6 +53,11 @@ class BaseTracer(ABC): def _get_mix_values_func(cls): return cls._mix_values_func + def _sanitize(self, inp) -> "BaseTracer": + if not isinstance(inp, BaseTracer): + return self._make_const_input_tracer(inp) + return inp + def instantiate_output_tracers( self, inputs: Iterable[Union["BaseTracer", Any]], @@ -69,13 +74,9 @@ class BaseTracer(ABC): Returns: Tuple[BaseTracer, ...]: A tuple containing an BaseTracer per output function """ - # For inputs 
which are actually constant, first convert into a tracer - def sanitize(inp): - if not isinstance(inp, BaseTracer): - return self._make_const_input_tracer(inp) - return inp - sanitized_inputs = [sanitize(inp) for inp in inputs] + # For inputs which are actually constant, first convert into a tracer + sanitized_inputs = [self._sanitize(inp) for inp in inputs] additional_parameters = ( {IR_MIX_VALUES_FUNC_ARG_NAME: self._get_mix_values_func()} diff --git a/hdk/hnumpy/np_dtypes_helpers.py b/hdk/hnumpy/np_dtypes_helpers.py index 820405005..d755bcbcc 100644 --- a/hdk/hnumpy/np_dtypes_helpers.py +++ b/hdk/hnumpy/np_dtypes_helpers.py @@ -2,7 +2,7 @@ from copy import deepcopy from functools import partial -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Union import numpy from numpy.typing import DTypeLike @@ -154,23 +154,25 @@ def get_base_value_for_numpy_or_python_constant_data( return constant_data_value -def get_ufunc_numpy_output_dtype( - ufunc: numpy.ufunc, +def get_numpy_function_output_dtype( + function: Union[numpy.ufunc, Callable], input_dtypes: List[BaseDataType], ) -> List[numpy.dtype]: - """Function to record the output dtype of a numpy.ufunc given some input types. + """Function to record the output dtype of a numpy function given some input types. 
Args: - ufunc (numpy.ufunc): The numpy.ufunc whose output types need to be recorded - input_dtypes (List[BaseDataType]): Common dtypes in the same order as they will be used with - the ufunc inputs + function (Union[numpy.ufunc, Callable]): The numpy function whose output types need to + be recorded + input_dtypes (List[BaseDataType]): BaseDataTypes in the same order as they will be used with + the function inputs Returns: - List[numpy.dtype]: The ordered numpy dtypes of the ufunc outputs + List[numpy.dtype]: The ordered numpy dtypes of the function outputs """ - assert ( - len(input_dtypes) == ufunc.nin - ), f"Expected {ufunc.nin} types, got {len(input_dtypes)}: {input_dtypes}" + if isinstance(function, numpy.ufunc): + assert ( + len(input_dtypes) == function.nin + ), f"Expected {function.nin} types, got {len(input_dtypes)}: {input_dtypes}" input_numpy_dtypes = [convert_base_data_type_to_numpy_dtype(dtype) for dtype in input_dtypes] @@ -183,7 +185,7 @@ def get_ufunc_numpy_output_dtype( dtype.type(1000.0 * numpy.random.random_sample()) for dtype in input_numpy_dtypes ) - outputs = ufunc(*dummy_inputs) + outputs = function(*dummy_inputs) if not isinstance(outputs, tuple): outputs = (outputs,) diff --git a/hdk/hnumpy/tracing.py b/hdk/hnumpy/tracing.py index eff5aea12..fbd11b881 100644 --- a/hdk/hnumpy/tracing.py +++ b/hdk/hnumpy/tracing.py @@ -1,21 +1,21 @@ """hnumpy tracing utilities.""" from copy import deepcopy from functools import partial -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Optional, Union import numpy from numpy.typing import DTypeLike from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype from ..common.operator_graph import OPGraph -from ..common.representation.intermediate import ArbitraryFunction, Constant +from ..common.representation.intermediate import ArbitraryFunction, Constant, Dot from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters from 
..common.values import BaseValue from .np_dtypes_helpers import ( SUPPORTED_NUMPY_DTYPES_CLASS_TYPES, convert_numpy_dtype_to_base_data_type, get_base_value_for_numpy_or_python_constant_data, - get_ufunc_numpy_output_dtype, + get_numpy_function_output_dtype, ) SUPPORTED_TYPES_FOR_TRACING = (int, float, numpy.ndarray) + tuple( @@ -39,13 +39,24 @@ class NPTracer(BaseTracer): Read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch """ if method == "__call__": - tracing_func = self.get_tracing_func_for_np_ufunc(ufunc) + tracing_func = self.get_tracing_func_for_np_function(ufunc) assert ( len(kwargs) == 0 ), f"hnumpy does not support **kwargs currently for numpy ufuncs, ufunc: {ufunc}" return tracing_func(self, *input_tracers, **kwargs) raise NotImplementedError("Only __call__ method is supported currently") + def __array_function__(self, func, _types, args, kwargs): + """Catch calls to numpy functions and routes them to hnp functions if supported. + + Read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch + """ + tracing_func = self.get_tracing_func_for_np_function(func) + assert ( + len(kwargs) == 0 + ), f"hnumpy does not support **kwargs currently for numpy functions, func: {func}" + return tracing_func(*args, **kwargs) + def astype(self, numpy_dtype: DTypeLike, *args, **kwargs) -> "NPTracer": r"""Support numpy astype feature. @@ -77,22 +88,27 @@ class NPTracer(BaseTracer): return output_tracer @staticmethod - def get_tracing_func_for_np_ufunc(ufunc: numpy.ufunc) -> Callable: - """Get the tracing function for a numpy ufunc. + def get_tracing_func_for_np_function(func: Union[numpy.ufunc, Callable]) -> Callable: + """Get the tracing function for a numpy function.
Args: - ufunc (numpy.ufunc): The numpy ufunc that will be traced + func (Union[numpy.ufunc, Callable]): The numpy function that will be traced Raises: - NotImplementedError: Raised if the passed ufunc is not supported by NPTracer + NotImplementedError: Raised if the passed function is not supported by NPTracer Returns: - Callable: the tracing function that needs to be called to trace ufunc + Callable: the tracing function that needs to be called to trace func """ - tracing_func = NPTracer.UFUNC_ROUTING.get(ufunc, None) + tracing_func: Optional[Callable] + if isinstance(func, numpy.ufunc): + tracing_func = NPTracer.UFUNC_ROUTING.get(func, None) + else: + tracing_func = NPTracer.FUNC_ROUTING.get(func, None) + if tracing_func is None: raise NotImplementedError( - f"NPTracer does not yet manage the following ufunc: {ufunc.__name__}" + f"NPTracer does not yet manage the following func: {func.__name__}" ) return tracing_func @@ -105,8 +121,8 @@ class NPTracer(BaseTracer): return self.__class__([], NPConstant(constant_data), 0) @staticmethod - def _manage_dtypes(ufunc: numpy.ufunc, *input_tracers: "NPTracer"): - output_dtypes = get_ufunc_numpy_output_dtype( + def _manage_dtypes(ufunc: Union[numpy.ufunc, Callable], *input_tracers: BaseTracer): + output_dtypes = get_numpy_function_output_dtype( ufunc, [input_tracer.output.data_type for input_tracer in input_tracers] ) common_output_dtypes = [ @@ -132,7 +148,9 @@ class NPTracer(BaseTracer): op_name="np.rint", ) output_tracer = self.__class__( - input_tracers, traced_computation=traced_computation, output_index=0 + input_tracers, + traced_computation=traced_computation, + output_index=0, ) return output_tracer @@ -154,7 +172,34 @@ class NPTracer(BaseTracer): op_name="np.sin", ) output_tracer = self.__class__( - input_tracers, traced_computation=traced_computation, output_index=0 + input_tracers, + traced_computation=traced_computation, + output_index=0, + ) + return output_tracer + + def dot(self, other_tracer: "NPTracer", 
**_kwargs) -> "NPTracer": + """Function to trace numpy.dot. + + Returns: + NPTracer: The output NPTracer containing the traced function + """ + # other_tracer is the second operand of the dot product, constants are converted to tracers + dot_inputs = (self, self._sanitize(other_tracer)) + + common_output_dtypes = self._manage_dtypes(numpy.dot, *dot_inputs) + assert len(common_output_dtypes) == 1 + + traced_computation = Dot( + [input_tracer.output for input_tracer in dot_inputs], + common_output_dtypes[0], + delegate_evaluation_function=numpy.dot, + ) + + output_tracer = self.__class__( + dot_inputs, + traced_computation=traced_computation, + output_index=0, ) return output_tracer @@ -163,6 +208,10 @@ class NPTracer(BaseTracer): numpy.sin: sin, } + FUNC_ROUTING: Dict[Callable, Callable] = { + numpy.dot: dot, + } + def trace_numpy_function( function_to_trace: Callable, function_parameters: Dict[str, BaseValue] diff --git a/tests/hnumpy/test_debugging.py b/tests/hnumpy/test_debugging.py index 65679d228..5ef0aaed0 100644 --- a/tests/hnumpy/test_debugging.py +++ b/tests/hnumpy/test_debugging.py @@ -1,11 +1,12 @@ """Test file for hnumpy debugging functions""" +import numpy import pytest from hdk.common.data_types.integers import Integer from hdk.common.debugging import draw_graph, get_printable_graph from hdk.common.extensions.table import LookupTable -from hdk.common.values import ClearValue, EncryptedValue +from hdk.common.values import ClearValue, EncryptedTensor, EncryptedValue from hdk.hnumpy import tracing LOOKUP_TABLE_FROM_2B_TO_4B = LookupTable([9, 2, 4, 11]) @@ -178,6 +179,36 @@ def test_hnumpy_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph ) +@pytest.mark.parametrize( + "lambda_f,params,ref_graph_str", + [ + # pylint: disable=unnecessary-lambda + ( + lambda x, y: numpy.dot(x, y), + { + "x": EncryptedTensor(Integer(2, is_signed=False), shape=(3,)), + "y": EncryptedTensor(Integer(2, is_signed=False), shape=(3,)), + }, + "\n%0 = x\n%1 = y\n%2 = Dot(0, 1)\nreturn(%2)", + ), +
# pylint: enable=unnecessary-lambda + ], +) +def test_hnumpy_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str): + "Test hnumpy get_printable_graph and draw_graph on graphs with dot" + graph = tracing.trace_numpy_function(lambda_f, params) + + draw_graph(graph, show=False) + + str_of_the_graph = get_printable_graph(graph) + + assert str_of_the_graph == ref_graph_str, ( + f"\n==================\nGot {str_of_the_graph}" + f"\n==================\nExpected {ref_graph_str}" + f"\n==================\n" + ) + + # Remark that the bitwidths are not particularly correct (eg, a MUL of a 17b times 23b # returning 23b), since they are replaced later by the real bitwidths computed on the # dataset diff --git a/tests/hnumpy/test_tracing.py b/tests/hnumpy/test_tracing.py index 529b38fd5..338591d71 100644 --- a/tests/hnumpy/test_tracing.py +++ b/tests/hnumpy/test_tracing.py @@ -291,10 +291,64 @@ def test_trace_hnumpy_supported_ufuncs( @pytest.mark.parametrize( - "np_ufunc,expected_tracing_func", + "function_to_trace,inputs,expected_output_node,expected_output_value", + [ + # pylint: disable=unnecessary-lambda + pytest.param( + lambda x, y: numpy.dot(x, y), + { + "x": EncryptedTensor(Integer(7, is_signed=False), shape=(10,)), + "y": EncryptedTensor(Integer(7, is_signed=False), shape=(10,)), + }, + ir.Dot, + EncryptedValue(Integer(32, False)), + ), + pytest.param( + lambda x, y: numpy.dot(x, y), + { + "x": EncryptedTensor(Float(64), shape=(42,)), + "y": EncryptedTensor(Float(64), shape=(10,)), + }, + ir.Dot, + EncryptedValue(Float(64)), + ), + pytest.param( + lambda x, y: numpy.dot(x, y), + { + "x": ClearTensor(Integer(64, is_signed=True), shape=(6,)), + "y": ClearTensor(Integer(64, is_signed=True), shape=(6,)), + }, + ir.Dot, + ClearValue(Integer(64, is_signed=True)), + ), + pytest.param( + lambda x: numpy.dot(x, numpy.array([1, 2, 3, 4, 5], dtype=numpy.int64)), + { + "x": EncryptedTensor(Integer(64, is_signed=True), shape=(5,)), + }, + ir.Dot, + 
EncryptedValue(Integer(64, True)), + ), + # pylint: enable=unnecessary-lambda + ], +) +def test_trace_hnumpy_dot(function_to_trace, inputs, expected_output_node, expected_output_value): + """Function to test dot tracing""" + + op_graph = tracing.trace_numpy_function(function_to_trace, inputs) + + assert len(op_graph.output_nodes) == 1 + assert isinstance(op_graph.output_nodes[0], expected_output_node) + assert len(op_graph.output_nodes[0].outputs) == 1 + assert op_graph.output_nodes[0].outputs[0] == expected_output_value + + +@pytest.mark.parametrize( + "np_function,expected_tracing_func", [ pytest.param(numpy.rint, tracing.NPTracer.rint), pytest.param(numpy.sin, tracing.NPTracer.sin), + pytest.param(numpy.dot, tracing.NPTracer.dot), # There is a need to test the case where the function fails, I chose numpy.conjugate which # works on complex types, as we don't talk about complex types for now this looks like a # good long term candidate to check for an unsupported function @@ -303,9 +357,9 @@ def test_trace_hnumpy_supported_ufuncs( ), ], ) -def test_nptracer_get_tracing_func_for_np_ufunc(np_ufunc, expected_tracing_func): - """Test NPTracer get_tracing_func_for_np_ufunc""" - assert tracing.NPTracer.get_tracing_func_for_np_ufunc(np_ufunc) == expected_tracing_func +def test_nptracer_get_tracing_func_for_np_functions(np_function, expected_tracing_func): + """Test NPTracer get_tracing_func_for_np_function""" + assert tracing.NPTracer.get_tracing_func_for_np_function(np_function) == expected_tracing_func @pytest.mark.parametrize(