feat(nptracer): add dot tracing abilities

- remove no cover from Dot.label
- small refactor of BaseTracer to make _sanitize a method of the class
- small refactor of get_ufunc_numpy_output_dtype to manage funcs and ufuncs
- add function routing to NPTracer
- add dot tracing to NPTracer
- small refactor to get tracing functions for numpy funcs and ufuncs
This commit is contained in:
Arthur Meyre
2021-08-24 11:52:35 +02:00
parent 6d663ef63d
commit ff260b2cd2
6 changed files with 176 additions and 40 deletions

View File

@@ -351,6 +351,5 @@ class Dot(IntermediateNode):
and super().is_equivalent_to(other)
)
# TODO: Coverage will come with the ability to trace the operator in a subsequent PR
def label(self) -> str: # pragma: no cover
def label(self) -> str:
return "dot"

View File

@@ -53,6 +53,11 @@ class BaseTracer(ABC):
def _get_mix_values_func(cls):
return cls._mix_values_func
def _sanitize(self, inp) -> "BaseTracer":
    """Return inp as a tracer, converting it into a constant input tracer if needed.

    Args:
        inp: Either a BaseTracer, returned unchanged, or any other value, which is passed to
            self._make_const_input_tracer

    Returns:
        BaseTracer: inp itself or the tracer built from it
    """
    if isinstance(inp, BaseTracer):
        return inp
    return self._make_const_input_tracer(inp)
def instantiate_output_tracers(
self,
inputs: Iterable[Union["BaseTracer", Any]],
@@ -69,13 +74,9 @@ class BaseTracer(ABC):
Returns:
Tuple[BaseTracer, ...]: A tuple containing a BaseTracer per output function
"""
# For inputs which are actually constant, first convert into a tracer
def sanitize(inp):
if not isinstance(inp, BaseTracer):
return self._make_const_input_tracer(inp)
return inp
sanitized_inputs = [sanitize(inp) for inp in inputs]
# For inputs which are actually constant, first convert into a tracer
sanitized_inputs = [self._sanitize(inp) for inp in inputs]
additional_parameters = (
{IR_MIX_VALUES_FUNC_ARG_NAME: self._get_mix_values_func()}

View File

@@ -2,7 +2,7 @@
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict, List, Union
import numpy
from numpy.typing import DTypeLike
@@ -154,23 +154,25 @@ def get_base_value_for_numpy_or_python_constant_data(
return constant_data_value
def get_ufunc_numpy_output_dtype(
ufunc: numpy.ufunc,
def get_numpy_function_output_dtype(
function: Union[numpy.ufunc, Callable],
input_dtypes: List[BaseDataType],
) -> List[numpy.dtype]:
"""Function to record the output dtype of a numpy.ufunc given some input types.
"""Function to record the output dtype of a numpy function given some input types.
Args:
ufunc (numpy.ufunc): The numpy.ufunc whose output types need to be recorded
input_dtypes (List[BaseDataType]): Common dtypes in the same order as they will be used with
the ufunc inputs
function (Union[numpy.ufunc, Callable]): The numpy function whose output types need to
be recorded
input_dtypes (List[BaseDataType]): BaseDataTypes in the same order as they will be used with
the function inputs
Returns:
List[numpy.dtype]: The ordered numpy dtypes of the ufunc outputs
List[numpy.dtype]: The ordered numpy dtypes of the function outputs
"""
assert (
len(input_dtypes) == ufunc.nin
), f"Expected {ufunc.nin} types, got {len(input_dtypes)}: {input_dtypes}"
if isinstance(function, numpy.ufunc):
assert (
len(input_dtypes) == function.nin
), f"Expected {function.nin} types, got {len(input_dtypes)}: {input_dtypes}"
input_numpy_dtypes = [convert_base_data_type_to_numpy_dtype(dtype) for dtype in input_dtypes]
@@ -183,7 +185,7 @@ def get_ufunc_numpy_output_dtype(
dtype.type(1000.0 * numpy.random.random_sample()) for dtype in input_numpy_dtypes
)
outputs = ufunc(*dummy_inputs)
outputs = function(*dummy_inputs)
if not isinstance(outputs, tuple):
outputs = (outputs,)

View File

@@ -1,21 +1,21 @@
"""hnumpy tracing utilities."""
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Optional, Union
import numpy
from numpy.typing import DTypeLike
from ..common.data_types.dtypes_helpers import mix_values_determine_holding_dtype
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import ArbitraryFunction, Constant
from ..common.representation.intermediate import ArbitraryFunction, Constant, Dot
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
from ..common.values import BaseValue
from .np_dtypes_helpers import (
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES,
convert_numpy_dtype_to_base_data_type,
get_base_value_for_numpy_or_python_constant_data,
get_ufunc_numpy_output_dtype,
get_numpy_function_output_dtype,
)
SUPPORTED_TYPES_FOR_TRACING = (int, float, numpy.ndarray) + tuple(
@@ -39,13 +39,24 @@ class NPTracer(BaseTracer):
Read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch
"""
if method == "__call__":
tracing_func = self.get_tracing_func_for_np_ufunc(ufunc)
tracing_func = self.get_tracing_func_for_np_function(ufunc)
assert (
len(kwargs) == 0
), f"hnumpy does not support **kwargs currently for numpy ufuncs, ufunc: {ufunc}"
return tracing_func(self, *input_tracers, **kwargs)
raise NotImplementedError("Only __call__ method is supported currently")
def __array_function__(self, func, _types, args, kwargs):
    """Catch calls to numpy functions and route them to hnp functions if supported.

    Read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch

    Args:
        func: The numpy function that was called
        _types: Unused, only present to match the numpy dispatch signature
        args: The positional arguments of the original call
        kwargs: The keyword arguments of the original call, must be empty for now

    Raises:
        NotImplementedError: Raised by get_tracing_func_for_np_function if func is not
            supported by NPTracer

    Returns:
        The value returned by the tracing function registered for func
    """
    tracing_func = self.get_tracing_func_for_np_function(func)
    assert (
        len(kwargs) == 0
    ), f"hnumpy does not support **kwargs currently for numpy functions, func: {func}"
    # The tracing function is called unbound, so its first positional argument plays the
    # role of self. numpy dispatches here whenever any argument is a tracer, which means the
    # first one can be a plain constant (e.g. numpy.dot(array, tracer)): convert it into a
    # tracer so that the tracing function always receives a tracer as self.
    receiver = self._sanitize(args[0])
    return tracing_func(receiver, *args[1:], **kwargs)
def astype(self, numpy_dtype: DTypeLike, *args, **kwargs) -> "NPTracer":
r"""Support numpy astype feature.
@@ -77,22 +88,27 @@ class NPTracer(BaseTracer):
return output_tracer
@staticmethod
def get_tracing_func_for_np_ufunc(ufunc: numpy.ufunc) -> Callable:
"""Get the tracing function for a numpy ufunc.
def get_tracing_func_for_np_function(func: Union[numpy.ufunc, Callable]) -> Callable:
"""Get the tracing function for a numpy function.
Args:
ufunc (numpy.ufunc): The numpy ufunc that will be traced
func (Union[numpy.ufunc, Callable]): The numpy function that will be traced
Raises:
NotImplementedError: Raised if the passed ufunc is not supported by NPTracer
NotImplementedError: Raised if the passed function is not supported by NPTracer
Returns:
Callable: the tracing function that needs to be called to trace ufunc
Callable: the tracing function that needs to be called to trace func
"""
tracing_func = NPTracer.UFUNC_ROUTING.get(ufunc, None)
tracing_func: Optional[Callable]
if isinstance(func, numpy.ufunc):
tracing_func = NPTracer.UFUNC_ROUTING.get(func, None)
else:
tracing_func = NPTracer.FUNC_ROUTING.get(func, None)
if tracing_func is None:
raise NotImplementedError(
f"NPTracer does not yet manage the following ufunc: {ufunc.__name__}"
f"NPTracer does not yet manage the following func: {func.__name__}"
)
return tracing_func
@@ -105,8 +121,8 @@ class NPTracer(BaseTracer):
return self.__class__([], NPConstant(constant_data), 0)
@staticmethod
def _manage_dtypes(ufunc: numpy.ufunc, *input_tracers: "NPTracer"):
output_dtypes = get_ufunc_numpy_output_dtype(
def _manage_dtypes(ufunc: Union[numpy.ufunc, Callable], *input_tracers: BaseTracer):
output_dtypes = get_numpy_function_output_dtype(
ufunc, [input_tracer.output.data_type for input_tracer in input_tracers]
)
common_output_dtypes = [
@@ -132,7 +148,9 @@ class NPTracer(BaseTracer):
op_name="np.rint",
)
output_tracer = self.__class__(
input_tracers, traced_computation=traced_computation, output_index=0
input_tracers,
traced_computation=traced_computation,
output_index=0,
)
return output_tracer
@@ -154,7 +172,34 @@ class NPTracer(BaseTracer):
op_name="np.sin",
)
output_tracer = self.__class__(
input_tracers, traced_computation=traced_computation, output_index=0
input_tracers,
traced_computation=traced_computation,
output_index=0,
)
return output_tracer
def dot(self, other_tracer: "NPTracer", **_kwargs) -> "NPTracer":
    """Function to trace numpy.dot.

    Args:
        other_tracer (NPTracer): The second operand of the dot product, it is passed through
            self._sanitize before being used
        **_kwargs: Accepted but not used by this tracing function

    Returns:
        NPTracer: The output NPTracer containing the traced function
    """
    # self is the first operand, the sanitized other_tracer is the second one
    operands = (self, self._sanitize(other_tracer))

    output_dtypes = self._manage_dtypes(numpy.dot, *operands)
    assert len(output_dtypes) == 1

    dot_node = Dot(
        [operand.output for operand in operands],
        output_dtypes[0],
        delegate_evaluation_function=numpy.dot,
    )
    return self.__class__(
        operands,
        traced_computation=dot_node,
        output_index=0,
    )
@@ -163,6 +208,10 @@ class NPTracer(BaseTracer):
numpy.sin: sin,
}
# Routing table used by get_tracing_func_for_np_function for numpy functions that are not
# numpy.ufunc instances: maps the numpy function to the function that traces it
FUNC_ROUTING: Dict[Callable, Callable] = {
numpy.dot: dot,
}
def trace_numpy_function(
function_to_trace: Callable, function_parameters: Dict[str, BaseValue]

View File

@@ -1,11 +1,12 @@
"""Test file for hnumpy debugging functions"""
import numpy
import pytest
from hdk.common.data_types.integers import Integer
from hdk.common.debugging import draw_graph, get_printable_graph
from hdk.common.extensions.table import LookupTable
from hdk.common.values import ClearValue, EncryptedValue
from hdk.common.values import ClearValue, EncryptedTensor, EncryptedValue
from hdk.hnumpy import tracing
LOOKUP_TABLE_FROM_2B_TO_4B = LookupTable([9, 2, 4, 11])
@@ -178,6 +179,36 @@ def test_hnumpy_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph
)
@pytest.mark.parametrize(
    "lambda_f,params,ref_graph_str",
    [
        # pylint: disable=unnecessary-lambda
        pytest.param(
            lambda x, y: numpy.dot(x, y),
            {
                "x": EncryptedTensor(Integer(2, is_signed=False), shape=(3,)),
                "y": EncryptedTensor(Integer(2, is_signed=False), shape=(3,)),
            },
            "\n%0 = x\n%1 = y\n%2 = Dot(0, 1)\nreturn(%2)",
        ),
        # pylint: enable=unnecessary-lambda
    ],
)
def test_hnumpy_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
    """Test hnumpy get_printable_graph and draw_graph on graphs with dot"""
    graph = tracing.trace_numpy_function(lambda_f, params)

    # Drawing must simply run without raising, nothing is displayed
    draw_graph(graph, show=False)

    str_of_the_graph = get_printable_graph(graph)
    mismatch_message = (
        f"\n==================\nGot {str_of_the_graph}"
        f"\n==================\nExpected {ref_graph_str}"
        f"\n==================\n"
    )

    assert str_of_the_graph == ref_graph_str, mismatch_message
# Note that the bit-widths are not necessarily correct (e.g. a MUL of a 17b by a 23b
# returning 23b), since they are replaced later by the real bit-widths computed on the
# dataset

View File

@@ -291,10 +291,64 @@ def test_trace_hnumpy_supported_ufuncs(
@pytest.mark.parametrize(
"np_ufunc,expected_tracing_func",
"function_to_trace,inputs,expected_output_node,expected_output_value",
[
# pylint: disable=unnecessary-lambda
pytest.param(
lambda x, y: numpy.dot(x, y),
{
"x": EncryptedTensor(Integer(7, is_signed=False), shape=(10,)),
"y": EncryptedTensor(Integer(7, is_signed=False), shape=(10,)),
},
ir.Dot,
EncryptedValue(Integer(32, False)),
),
pytest.param(
lambda x, y: numpy.dot(x, y),
{
"x": EncryptedTensor(Float(64), shape=(42,)),
"y": EncryptedTensor(Float(64), shape=(10,)),
},
ir.Dot,
EncryptedValue(Float(64)),
),
pytest.param(
lambda x, y: numpy.dot(x, y),
{
"x": ClearTensor(Integer(64, is_signed=True), shape=(6,)),
"y": ClearTensor(Integer(64, is_signed=True), shape=(6,)),
},
ir.Dot,
ClearValue(Integer(64, is_signed=True)),
),
pytest.param(
lambda x: numpy.dot(x, numpy.array([1, 2, 3, 4, 5], dtype=numpy.int64)),
{
"x": EncryptedTensor(Integer(64, is_signed=True), shape=(5,)),
},
ir.Dot,
EncryptedValue(Integer(64, True)),
),
# pylint: enable=unnecessary-lambda
],
)
def test_trace_hnumpy_dot(function_to_trace, inputs, expected_output_node, expected_output_value):
    """Check the single output node and output value obtained when tracing numpy.dot"""
    op_graph = tracing.trace_numpy_function(function_to_trace, inputs)

    # Exactly one output node of the expected IR type
    assert len(op_graph.output_nodes) == 1
    output_node = op_graph.output_nodes[0]
    assert isinstance(output_node, expected_output_node)

    # That node has exactly one output holding the expected value
    assert len(output_node.outputs) == 1
    assert output_node.outputs[0] == expected_output_value
@pytest.mark.parametrize(
"np_function,expected_tracing_func",
[
pytest.param(numpy.rint, tracing.NPTracer.rint),
pytest.param(numpy.sin, tracing.NPTracer.sin),
pytest.param(numpy.dot, tracing.NPTracer.dot),
# The failure case must be tested too: numpy.conjugate is used because it operates on
# complex types, which are not handled for now, making it a good long-term candidate for
# an unsupported function
@@ -303,9 +357,9 @@ def test_trace_hnumpy_supported_ufuncs(
),
],
)
def test_nptracer_get_tracing_func_for_np_ufunc(np_ufunc, expected_tracing_func):
"""Test NPTracer get_tracing_func_for_np_ufunc"""
assert tracing.NPTracer.get_tracing_func_for_np_ufunc(np_ufunc) == expected_tracing_func
def test_nptracer_get_tracing_func_for_np_functions(np_function, expected_tracing_func):
    """Check that NPTracer.get_tracing_func_for_np_function returns the expected function"""
    tracing_func = tracing.NPTracer.get_tracing_func_for_np_function(np_function)
    assert tracing_func == expected_tracing_func
@pytest.mark.parametrize(