feat: factorize the ufunc management and add lot of ufunc's

refs #126
This commit is contained in:
Benoit Chevallier-Mames
2021-08-31 17:28:52 +02:00
committed by Benoit Chevallier
parent b582e68cd0
commit e9c5ce27bb
2 changed files with 166 additions and 154 deletions

View File

@@ -1,7 +1,7 @@
"""hnumpy tracing utilities."""
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
import numpy
from numpy.typing import DTypeLike
@@ -43,7 +43,7 @@ class NPTracer(BaseTracer):
assert (
len(kwargs) == 0
), f"hnumpy does not support **kwargs currently for numpy ufuncs, ufunc: {ufunc}"
return tracing_func(self, *input_tracers, **kwargs)
return tracing_func(*input_tracers, **kwargs)
raise NotImplementedError("Only __call__ method is supported currently")
def __array_function__(self, func, _types, args, kwargs):
@@ -130,8 +130,9 @@ class NPTracer(BaseTracer):
]
return common_output_dtypes
@classmethod
def _unary_operator(
self, unary_operator, unary_operator_string, *input_tracers: "NPTracer", **kwargs
cls, unary_operator, unary_operator_string, *input_tracers: "NPTracer", **kwargs
) -> "NPTracer":
"""Function to trace an unary operator.
@@ -139,7 +140,7 @@ class NPTracer(BaseTracer):
NPTracer: The output NPTracer containing the traced function
"""
assert len(input_tracers) == 1
common_output_dtypes = self._manage_dtypes(unary_operator, *input_tracers)
common_output_dtypes = cls._manage_dtypes(unary_operator, *input_tracers)
assert len(common_output_dtypes) == 1
traced_computation = ArbitraryFunction(
@@ -149,93 +150,13 @@ class NPTracer(BaseTracer):
op_kwargs=deepcopy(kwargs),
op_name=unary_operator_string,
)
output_tracer = self.__class__(
output_tracer = cls(
input_tracers,
traced_computation=traced_computation,
output_index=0,
)
return output_tracer
def rint(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
"""Function to trace numpy.rint.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
return self._unary_operator(numpy.rint, "np.rint", *input_tracers, **kwargs)
def sin(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
"""Function to trace numpy.sin.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
return self._unary_operator(numpy.sin, "np.sin", *input_tracers, **kwargs)
def cos(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
"""Function to trace numpy.cos.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
return self._unary_operator(numpy.cos, "np.cos", *input_tracers, **kwargs)
def tan(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
"""Function to trace numpy.tan.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
return self._unary_operator(numpy.tan, "np.tan", *input_tracers, **kwargs)
def arcsin(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
"""Function to trace numpy.arcsin.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
return self._unary_operator(numpy.arcsin, "np.arcsin", *input_tracers, **kwargs)
def arccos(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
"""Function to trace numpy.arccos.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
return self._unary_operator(numpy.arccos, "np.arccos", *input_tracers, **kwargs)
def arctan(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
"""Function to trace numpy.arctan.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
return self._unary_operator(numpy.arctan, "np.arctan", *input_tracers, **kwargs)
def exp(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
"""Function to trace numpy.exp.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
return self._unary_operator(numpy.exp, "np.exp", *input_tracers, **kwargs)
def expm1(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
"""Function to trace numpy.expm1.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
return self._unary_operator(numpy.expm1, "np.expm1", *input_tracers, **kwargs)
def exp2(self, *input_tracers: "NPTracer", **kwargs) -> "NPTracer":
"""Function to trace numpy.exp2.
Returns:
NPTracer: The output NPTracer containing the traced function
"""
return self._unary_operator(numpy.exp2, "np.exp2", *input_tracers, **kwargs)
def dot(self, other_tracer: "NPTracer", **_kwargs) -> "NPTracer":
"""Function to trace numpy.dot.
@@ -261,24 +182,124 @@ class NPTracer(BaseTracer):
)
return output_tracer
UFUNC_ROUTING: Dict[numpy.ufunc, Callable] = {
numpy.rint: rint,
numpy.sin: sin,
numpy.cos: cos,
numpy.tan: tan,
numpy.arcsin: arcsin,
numpy.arccos: arccos,
numpy.arctan: arctan,
numpy.exp: exp,
numpy.expm1: expm1,
numpy.exp2: exp2,
}
LIST_OF_SUPPORTED_UFUNC: List[numpy.ufunc] = [
# The commented functions are functions which don't work for the moment, often
# if not always because they require more than a single argument
# numpy.absolute,
# numpy.add,
numpy.arccos,
numpy.arccosh,
numpy.arcsin,
numpy.arcsinh,
numpy.arctan,
# numpy.arctan2,
numpy.arctanh,
# numpy.bitwise_and,
# numpy.bitwise_or,
# numpy.bitwise_xor,
numpy.cbrt,
numpy.ceil,
# numpy.conjugate,
# numpy.copysign,
numpy.cos,
numpy.cosh,
numpy.deg2rad,
numpy.degrees,
# numpy.divmod,
# numpy.equal,
numpy.exp,
numpy.exp2,
numpy.expm1,
numpy.fabs,
# numpy.float_power,
numpy.floor,
# numpy.floor_divide,
# numpy.fmax,
# numpy.fmin,
# numpy.fmod,
# numpy.frexp,
# numpy.gcd,
# numpy.greater,
# numpy.greater_equal,
# numpy.heaviside,
# numpy.hypot,
# numpy.invert,
# numpy.isfinite,
# numpy.isinf,
# numpy.isnan,
# numpy.isnat,
# numpy.lcm,
# numpy.ldexp,
# numpy.left_shift,
# numpy.less,
# numpy.less_equal,
numpy.log,
numpy.log10,
numpy.log1p,
numpy.log2,
# numpy.logaddexp,
# numpy.logaddexp2,
# numpy.logical_and,
# numpy.logical_not,
# numpy.logical_or,
# numpy.logical_xor,
# numpy.matmul,
# numpy.maximum,
# numpy.minimum,
# numpy.modf,
# numpy.multiply,
# numpy.negative,
# numpy.nextafter,
# numpy.not_equal,
# numpy.positive,
# numpy.power,
numpy.rad2deg,
numpy.radians,
# numpy.reciprocal,
# numpy.remainder,
# numpy.right_shift,
numpy.rint,
# numpy.sign,
# numpy.signbit,
numpy.sin,
numpy.sinh,
numpy.spacing,
numpy.sqrt,
# numpy.square,
# numpy.subtract,
numpy.tan,
numpy.tanh,
# numpy.true_divide,
numpy.trunc,
]
# We build UFUNC_ROUTING dynamically after the creation of the class,
# because of some limits of python or our unability to do it properly
# in the class with techniques which are compatible with the different
# coding checks we use
UFUNC_ROUTING: Dict[numpy.ufunc, Callable] = {}
FUNC_ROUTING: Dict[Callable, Callable] = {
numpy.dot: dot,
}
def _get_fun(function: numpy.ufunc):
"""Helper function to wrap _unary_operator in a lambda to populate NPTRACER.UFUNC_ROUTING."""
# We have to access this method to be able to build NPTracer.UFUNC_ROUTING
# dynamically
# pylint: disable=protected-access
return lambda *input_tracers, **kwargs: NPTracer._unary_operator(
function, f"np.{function.__name__}", *input_tracers, **kwargs
)
# pylint: enable=protected-access
# We are populating NPTracer.UFUNC_ROUTING dynamically
NPTracer.UFUNC_ROUTING = {fun: _get_fun(fun) for fun in NPTracer.LIST_OF_SUPPORTED_UFUNC}
def trace_numpy_function(
function_to_trace: Callable, function_parameters: Dict[str, BaseValue]
) -> OPGraph:

View File

@@ -230,31 +230,6 @@ def test_tracing_astype(
assert expected_output == evaluated_output
@pytest.mark.parametrize(
"function_to_trace",
[
# We cannot call trace_numpy_function on some numpy function as getting the signature for
# these functions fails, so we wrap it in a lambda
# pylint: disable=unnecessary-lambda
pytest.param(lambda x: numpy.rint(x)),
pytest.param(lambda x: numpy.sin(x)),
pytest.param(lambda x: numpy.cos(x)),
pytest.param(lambda x: numpy.tan(x)),
pytest.param(lambda x: numpy.arcsin(x)),
pytest.param(lambda x: numpy.arccos(x)),
pytest.param(lambda x: numpy.arctan(x)),
pytest.param(lambda x: numpy.exp(x)),
pytest.param(lambda x: numpy.expm1(x)),
pytest.param(lambda x: numpy.exp2(x)),
# The next test case is only for coverage purposes, to trigger the unsupported method
# exception handling
pytest.param(
lambda x: numpy.add.reduce(x),
marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
),
# pylint: enable=unnecessary-lambda
],
)
@pytest.mark.parametrize(
"inputs,expected_output_node,expected_output_value",
[
@@ -286,16 +261,40 @@ def test_tracing_astype(
),
],
)
def test_trace_hnumpy_supported_ufuncs(
function_to_trace, inputs, expected_output_node, expected_output_value
):
def test_trace_hnumpy_supported_ufuncs(inputs, expected_output_node, expected_output_value):
"""Function to trace supported numpy ufuncs"""
op_graph = tracing.trace_numpy_function(function_to_trace, inputs)
for function_to_trace_def in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC:
assert len(op_graph.output_nodes) == 1
assert isinstance(op_graph.output_nodes[0], expected_output_node)
assert len(op_graph.output_nodes[0].outputs) == 1
assert op_graph.output_nodes[0].outputs[0] == expected_output_value
# We really need a lambda (because numpy functions are not playing
# nice with inspect.signature), but pylint and flake8 are not happy
# with it
# pylint: disable=unnecessary-lambda,cell-var-from-loop
function_to_trace = lambda x: function_to_trace_def(x) # noqa: E731
# pylint: enable=unnecessary-lambda,cell-var-from-loop
op_graph = tracing.trace_numpy_function(function_to_trace, inputs)
assert len(op_graph.output_nodes) == 1
assert isinstance(op_graph.output_nodes[0], expected_output_node)
assert len(op_graph.output_nodes[0].outputs) == 1
assert op_graph.output_nodes[0].outputs[0] == expected_output_value
def test_trace_hnumpy_ufuncs_not_supported():
"""Testing a failure case of trace_numpy_function"""
inputs = {"x": EncryptedScalar(Integer(128, is_signed=True))}
# We really need a lambda (because numpy functions are not playing
# nice with inspect.signature), but pylint and flake8 are not happy
# with it
# pylint: disable=unnecessary-lambda
function_to_trace = lambda x: numpy.add.reduce(x) # noqa: E731
# pylint: enable=unnecessary-lambda
with pytest.raises(NotImplementedError) as excinfo:
tracing.trace_numpy_function(function_to_trace, inputs)
assert "Only __call__ method is supported currently" in str(excinfo.value)
@pytest.mark.parametrize(
@@ -351,31 +350,23 @@ def test_trace_hnumpy_dot(function_to_trace, inputs, expected_output_node, expec
assert op_graph.output_nodes[0].outputs[0] == expected_output_value
@pytest.mark.parametrize(
"np_function,expected_tracing_func",
[
pytest.param(numpy.rint, tracing.NPTracer.rint),
pytest.param(numpy.sin, tracing.NPTracer.sin),
pytest.param(numpy.cos, tracing.NPTracer.cos),
pytest.param(numpy.tan, tracing.NPTracer.tan),
pytest.param(numpy.arcsin, tracing.NPTracer.arcsin),
pytest.param(numpy.arccos, tracing.NPTracer.arccos),
pytest.param(numpy.arctan, tracing.NPTracer.arctan),
pytest.param(numpy.exp, tracing.NPTracer.exp),
pytest.param(numpy.expm1, tracing.NPTracer.expm1),
pytest.param(numpy.exp2, tracing.NPTracer.exp2),
pytest.param(numpy.dot, tracing.NPTracer.dot),
# There is a need to test the case where the function fails, I chose numpy.conjugate which
# works on complex types, as we don't talk about complex types for now this looks like a
# good long term candidate to check for an unsupported function
pytest.param(
numpy.conjugate, None, marks=pytest.mark.xfail(strict=True, raises=NotImplementedError)
),
],
)
def test_nptracer_get_tracing_func_for_np_functions(np_function, expected_tracing_func):
def test_nptracer_get_tracing_func_for_np_functions():
"""Test NPTracer get_tracing_func_for_np_function"""
assert tracing.NPTracer.get_tracing_func_for_np_function(np_function) == expected_tracing_func
for np_function in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC:
expected_tracing_func = tracing.NPTracer.UFUNC_ROUTING[np_function]
assert (
tracing.NPTracer.get_tracing_func_for_np_function(np_function) == expected_tracing_func
)
def test_nptracer_get_tracing_func_for_np_functions_not_implemented():
"""Check NPTracer in case of not-implemented function"""
with pytest.raises(NotImplementedError) as excinfo:
tracing.NPTracer.get_tracing_func_for_np_function(numpy.conjugate)
assert "NPTracer does not yet manage the following func: conjugate" in str(excinfo.value)
@pytest.mark.parametrize(