feat(hnp-tracing): add support for ufunc routing to NPTracer

- start tracing numpy.rint and manage dtypes
- update BaseTracer to accept iterables as inputs, because NPTracer does
not get a list, given the way numpy sends arguments to the functions
This commit is contained in:
Arthur Meyre
2021-08-10 17:09:59 +02:00
parent 19e68589d1
commit 7bdcfabbfe
4 changed files with 164 additions and 18 deletions

View File

@@ -1,7 +1,7 @@
"""This file holds the code that can be shared between tracers"""
from abc import ABC
from typing import List, Tuple, Type, Union
from typing import Iterable, List, Tuple, Type, Union
from ..data_types import BaseValue
from ..data_types.scalars import Scalars
@@ -17,17 +17,17 @@ class BaseTracer(ABC):
def __init__(
self,
inputs: List["BaseTracer"],
inputs: Iterable["BaseTracer"],
traced_computation: ir.IntermediateNode,
output_index: int,
) -> None:
self.inputs = inputs
self.inputs = list(inputs)
self.traced_computation = traced_computation
self.output = traced_computation.outputs[output_index]
def instantiate_output_tracers(
self,
inputs: List[Union["BaseTracer", Scalars]],
inputs: Iterable[Union["BaseTracer", Scalars]],
computation_to_trace: Type[ir.IntermediateNode],
) -> Tuple["BaseTracer", ...]:
"""Helper functions to instantiate all output BaseTracer for a given computation

View File

@@ -1,7 +1,7 @@
"""File to hold code to manage package and numpy dtypes"""
from copy import deepcopy
from typing import List, Union
from typing import List
import numpy
from numpy.typing import DTypeLike
@@ -94,14 +94,14 @@ def convert_common_dtype_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.dty
def get_ufunc_numpy_output_dtype(
ufunc: numpy.ufunc,
input_dtypes: Union[List[numpy.dtype], List[BaseDataType]],
input_dtypes: List[BaseDataType],
) -> List[numpy.dtype]:
"""Function to record the output dtype of a numpy.ufunc given some input types
Args:
ufunc (numpy.ufunc): The numpy.ufunc whose output types need to be recorded
input_dtypes (Union[List[numpy.dtype], List[BaseDataType]]): Either numpy or common dtypes
in the same order as they will be used with the ufunc inputs
input_dtypes (List[BaseDataType]): Common dtypes in the same order as they will be used with
the ufunc inputs
Returns:
List[numpy.dtype]: The ordered numpy dtypes of the ufunc outputs
@@ -110,12 +110,7 @@ def get_ufunc_numpy_output_dtype(
len(input_dtypes) == ufunc.nin
), f"Expected {ufunc.nin} types, got {len(input_dtypes)}: {input_dtypes}"
input_dtypes = [
numpy.dtype(convert_common_dtype_to_numpy_dtype(dtype))
if not isinstance(dtype, numpy.dtype)
else dtype
for dtype in input_dtypes
]
input_numpy_dtypes = [convert_common_dtype_to_numpy_dtype(dtype) for dtype in input_dtypes]
# Store numpy old error settings and ignore all errors in this function
# We ignore errors as we may call functions with invalid inputs just to get the proper output
@@ -123,7 +118,7 @@ def get_ufunc_numpy_output_dtype(
old_numpy_err_settings = numpy.seterr(all="ignore")
dummy_inputs = tuple(
dtype.type(1000.0 * numpy.random.random_sample()) for dtype in input_dtypes
dtype.type(1000.0 * numpy.random.random_sample()) for dtype in input_numpy_dtypes
)
outputs = ufunc(*dummy_inputs)

View File

@@ -1,5 +1,5 @@
"""hnumpy tracing utilities"""
from typing import Callable, Dict
from typing import Callable, Dict, Mapping
import numpy
from numpy.typing import DTypeLike
@@ -8,12 +8,28 @@ from ..common.data_types import BaseValue
from ..common.operator_graph import OPGraph
from ..common.representation import intermediate as ir
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
from .np_dtypes_helpers import convert_numpy_dtype_to_common_dtype
from .np_dtypes_helpers import (
convert_numpy_dtype_to_common_dtype,
get_ufunc_numpy_output_dtype,
)
class NPTracer(BaseTracer):
"""Tracer class for numpy operations"""
def __array_ufunc__(self, ufunc, method, *input_tracers, **kwargs):
    """
    Intercept numpy ufunc dispatch and forward supported calls to their tracing function.
    Only the plain "__call__" ufunc method is handled, any other method raises.
    read more: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch
    """
    if method != "__call__":
        raise NotImplementedError("Only __call__ method is supported currently")

    # The lookup runs before the kwargs check, so an unsupported ufunc is reported first
    tracing_func = self.get_tracing_func_for_np_ufunc(ufunc)
    assert (
        len(kwargs) == 0
    ), f"hnumpy does not support **kwargs currently for numpy ufuncs, ufunc: {ufunc}"
    # The tracer is passed explicitly as the first argument of the routed function
    return tracing_func(self, *input_tracers, **kwargs)
def astype(self, numpy_dtype: DTypeLike, *args, **kwargs) -> "NPTracer":
"""Support numpy astype feature, for now it only accepts a dtype and no additional
parameters, *args and **kwargs are accepted for interface compatibility only
@@ -36,9 +52,66 @@ class NPTracer(BaseTracer):
arbitrary_func=normalized_numpy_dtype.type,
output_dtype=output_dtype,
)
output_tracer = NPTracer([self], traced_computation=traced_computation, output_index=0)
output_tracer = self.__class__(
[self], traced_computation=traced_computation, output_index=0
)
return output_tracer
@staticmethod
def get_tracing_func_for_np_ufunc(ufunc: numpy.ufunc) -> Callable:
    """Look up the tracing function registered for a numpy ufunc

    Args:
        ufunc (numpy.ufunc): The numpy ufunc that will be traced

    Raises:
        NotImplementedError: Raised if the passed ufunc is not supported by NPTracer

    Returns:
        Callable: the tracing function that needs to be called to trace ufunc
    """
    routed_func = NPTracer.UFUNC_ROUTING.get(ufunc)
    if routed_func is not None:
        return routed_func

    raise NotImplementedError(
        f"NPTracer does not yet manage the following ufunc: {ufunc.__name__}"
    )
@staticmethod
def _manage_dtypes(ufunc: numpy.ufunc, *input_tracers: "NPTracer"):
    """Compute the common output dtypes of a ufunc applied to the given tracers

    Args:
        ufunc (numpy.ufunc): The numpy ufunc whose output dtypes are needed
        *input_tracers (NPTracer): The tracers passed to the ufunc, their output data types
            are used as the ufunc input dtypes, in the same order

    Returns:
        The list of common dtypes converted from the numpy output dtypes of the ufunc
    """
    input_dtypes = [tracer.output.data_type for tracer in input_tracers]
    numpy_output_dtypes = get_ufunc_numpy_output_dtype(ufunc, input_dtypes)
    return [convert_numpy_dtype_to_common_dtype(dtype) for dtype in numpy_output_dtypes]
def rint(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
    """Trace a call to numpy.rint as an ArbitraryFunction node

    Exactly one input tracer is expected, and numpy.rint is expected to yield a single output
    dtype for it. The keyword arguments are stored on the node as op_kwargs.

    Returns:
        NPTracer: The output NPTracer containing the traced function
    """
    assert len(input_tracers) == 1
    output_dtypes = self._manage_dtypes(numpy.rint, *input_tracers)
    assert len(output_dtypes) == 1

    rint_node = ir.ArbitraryFunction(
        input_base_value=input_tracers[0].output,
        arbitrary_func=numpy.rint,
        output_dtype=output_dtypes[0],
        op_kwargs=kwargs,
    )
    # Use the runtime class of self for the output tracer
    return self.__class__(input_tracers, traced_computation=rint_node, output_index=0)
# Routing table from supported numpy ufuncs to the function used to trace them.
# The values are the functions as named in this class body, which is why the table is declared
# after rint; __array_ufunc__ passes the tracer explicitly when calling the routed function.
UFUNC_ROUTING: Mapping[numpy.ufunc, Callable] = {
numpy.rint: rint,
}
def trace_numpy_function(
function_to_trace: Callable, function_parameters: Dict[str, BaseValue]

View File

@@ -201,3 +201,81 @@ def test_tracing_astype(
evaluated_output = node_results[output_node]
assert isinstance(evaluated_output, type(expected_output))
assert expected_output == evaluated_output
@pytest.mark.parametrize(
    "function_to_trace,inputs,expected_output_node,expected_output_value",
    [
        # numpy ufuncs are wrapped in a lambda because trace_numpy_function cannot retrieve
        # the signature of some numpy functions directly
        # pylint: disable=unnecessary-lambda
        pytest.param(
            lambda x: numpy.rint(x),
            {"x": EncryptedValue(Integer(7, is_signed=False))},
            ir.ArbitraryFunction,
            EncryptedValue(Float(64)),
        ),
        pytest.param(
            lambda x: numpy.rint(x),
            {"x": EncryptedValue(Integer(32, is_signed=True))},
            ir.ArbitraryFunction,
            EncryptedValue(Float(64)),
        ),
        pytest.param(
            lambda x: numpy.rint(x),
            {"x": EncryptedValue(Integer(64, is_signed=True))},
            ir.ArbitraryFunction,
            EncryptedValue(Float(64)),
        ),
        pytest.param(
            lambda x: numpy.rint(x),
            {"x": EncryptedValue(Integer(128, is_signed=True))},
            ir.ArbitraryFunction,
            None,
            marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
        ),
        pytest.param(
            lambda x: numpy.rint(x),
            {"x": EncryptedValue(Float(64))},
            ir.ArbitraryFunction,
            EncryptedValue(Float(64)),
        ),
        # Coverage-only case: a ufunc method other than __call__ must hit the unsupported
        # method error path
        pytest.param(
            lambda x: numpy.add.reduce(x),
            {"x": EncryptedValue(Integer(32, is_signed=True))},
            None,
            None,
            marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
        ),
        # pylint: enable=unnecessary-lambda
    ],
)
def test_trace_hnumpy_supported_ufuncs(
    function_to_trace, inputs, expected_output_node, expected_output_value
):
    """Check that tracing a supported numpy ufunc yields the expected single output node"""
    traced_graph = tracing.trace_numpy_function(function_to_trace, inputs)

    # The graph must end in exactly one node of the expected type
    graph_output_nodes = traced_graph.output_nodes
    assert len(graph_output_nodes) == 1
    final_node = graph_output_nodes[0]
    assert isinstance(final_node, expected_output_node)

    # That node must expose exactly one output, equal to the expected value
    final_node_outputs = final_node.outputs
    assert len(final_node_outputs) == 1
    assert final_node_outputs[0] == expected_output_value
@pytest.mark.parametrize(
    "np_ufunc,expected_tracing_func",
    [
        pytest.param(numpy.rint, tracing.NPTracer.rint),
        # The failing path also needs a test: numpy.conjugate works on complex types, which are
        # not handled for now, so it looks like a good long term candidate for an unsupported
        # function
        pytest.param(
            numpy.conjugate, None, marks=pytest.mark.xfail(strict=True, raises=NotImplementedError)
        ),
    ],
)
def test_nptracer_get_tracing_func_for_np_ufunc(np_ufunc, expected_tracing_func):
    """Check the ufunc to tracing function lookup done by NPTracer"""
    resolved_tracing_func = tracing.NPTracer.get_tracing_func_for_np_ufunc(np_ufunc)
    assert resolved_tracing_func == expected_tracing_func