refactor: refactor ConstantInput to be flexible

- refactor to take a function to generate the proper BaseValue to store in
its output
- refactor BaseTracer to force inheriting tracers to indicate how to build
a ConstantInput tracer
- remove "as import" for intermediate in hnumpy/tracing.py
- update compile to manage python dtypes
This commit is contained in:
Arthur Meyre
2021-08-19 16:56:31 +02:00
parent c528d72e62
commit 9a0c108d4b
6 changed files with 95 additions and 22 deletions

View File

@@ -152,10 +152,13 @@ class ConstantInput(IntermediateNode):
def __init__(
    self,
    constant_data: Any,
    get_base_value_for_data_func: Callable[
        [Any], Callable[..., BaseValue]
    ] = get_base_value_for_python_constant_data,
) -> None:
    """Build a constant node whose single output describes `constant_data`.

    Args:
        constant_data (Any): The constant stored in the node.
        get_base_value_for_data_func (Callable[[Any], Callable[..., BaseValue]]): Function
            that receives the constant and returns a callable building the BaseValue used
            as the node output. Defaults to get_base_value_for_python_constant_data.
    """
    super().__init__([])
    # Only the injected helper is consulted: calling the python-only helper first would be
    # wasted work and could reject data that only the injected helper knows how to handle.
    base_value_class = get_base_value_for_data_func(constant_data)
    self._constant_data = constant_data
    # A constant is stored as a clear (non encrypted) value.
    self.outputs = [base_value_class(is_encrypted=False)]

View File

@@ -36,6 +36,17 @@ class BaseTracer(ABC):
"""
return isinstance(other, self.__class__)
@abstractmethod
def _make_const_input_tracer(self, constant_data: Any) -> "BaseTracer":
"""Helper function to create a tracer for a constant input.
Abstract hook: it has no implementation here and must be provided by each
concrete tracer class. It is used to wrap inputs that are not already
BaseTracer instances before they are processed as tracers.
Args:
constant_data (Any): The constant to store.
Returns:
BaseTracer: The BaseTracer for that constant.
"""
def instantiate_output_tracers(
self,
inputs: Iterable[Union["BaseTracer", Any]],
@@ -55,7 +66,7 @@ class BaseTracer(ABC):
# For inputs which are actually constant, first convert into a tracer
def sanitize(inp):
    """Return `inp` unchanged if it is a tracer, otherwise wrap it in a constant tracer."""
    if isinstance(inp, BaseTracer):
        return inp
    # Delegate to the tracer's own hook so that each tracer class decides how its
    # constant inputs are built.
    return self._make_const_input_tracer(inp)
sanitized_inputs = [sanitize(inp) for inp in inputs]
@@ -128,16 +139,3 @@ class BaseTracer(ABC):
# the order, we need to do as in __rmul__, ie mostly a copy of __mul__ +
# some changes
__rmul__ = __mul__
def make_const_input_tracer(tracer_class: Type[BaseTracer], constant_data: Any) -> BaseTracer:
    """Create a tracer of the requested class wrapping a constant input node.

    Args:
        tracer_class (Type[BaseTracer]): The tracer class to instantiate.
        constant_data (Any): The constant stored in the ir.ConstantInput node.

    Returns:
        BaseTracer: A `tracer_class` instance built with no input tracers, the new
            ir.ConstantInput node and 0 as last argument.
    """
    constant_node = ir.ConstantInput(constant_data)
    return tracer_class([], constant_node, 0)

View File

@@ -17,6 +17,7 @@ from ..common.operator_graph import OPGraph
from ..common.optimization.topological import fuse_float_operations
from ..common.representation import intermediate as ir
from ..hnumpy.tracing import trace_numpy_function
from .np_dtypes_helpers import get_base_data_type_for_numpy_or_python_constant_data
def compile_numpy_function_into_op_graph(
@@ -74,7 +75,9 @@ def compile_numpy_function_into_op_graph(
node_bounds = eval_op_graph_bounds_on_dataset(op_graph, dataset)
# Update the graph accordingly: after that, we have the compilable graph
op_graph.update_values_with_bounds(node_bounds)
op_graph.update_values_with_bounds(
node_bounds, get_base_data_type_for_numpy_or_python_constant_data
)
# Make sure the graph can be lowered to MLIR
if not is_graph_values_compatible_with_mlir(op_graph):

View File

@@ -1,15 +1,21 @@
"""File to hold code to manage package and numpy dtypes."""
from copy import deepcopy
from typing import Dict, List
from functools import partial
from typing import Any, Callable, Dict, List
import numpy
from numpy.typing import DTypeLike
from ..common.data_types.base import BaseDataType
from ..common.data_types.dtypes_helpers import BASE_DATA_TYPES
from ..common.data_types.dtypes_helpers import (
BASE_DATA_TYPES,
get_base_data_type_for_python_constant_data,
get_base_value_for_python_constant_data,
)
from ..common.data_types.floats import Float
from ..common.data_types.integers import Integer
from ..common.data_types.values import BaseValue, ScalarValue
NUMPY_TO_HDK_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = {
numpy.dtype(numpy.int32): Integer(32, is_signed=True),
@@ -92,6 +98,58 @@ def convert_base_data_type_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.d
return type_to_return
def get_base_data_type_for_numpy_or_python_constant_data(constant_data: Any) -> BaseDataType:
    """Find the BaseDataType able to hold a numpy or python constant.

    Args:
        constant_data (Any): The constant to inspect. It must be an int, a float or an
            instance of one of SUPPORTED_NUMPY_DTYPES_CLASS_TYPES.

    Raises:
        AssertionError: If `constant_data` is of an unsupported type.

    Returns:
        BaseDataType: The data type matching `constant_data`.
    """
    assert isinstance(
        constant_data, (int, float, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
    ), f"Unsupported constant data of type {type(constant_data)}"

    if isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES):
        # NOTE(review): the value itself (not its dtype) is handed to the converter;
        # presumably the converter normalizes it - confirm against its implementation.
        return convert_numpy_dtype_to_base_data_type(constant_data)

    # Anything else that passed the assertion is handled by the python helper.
    return get_base_data_type_for_python_constant_data(constant_data)
def get_base_value_for_numpy_or_python_constant_data(
    constant_data: Any,
) -> Callable[..., BaseValue]:
    """Helper function to determine the BaseValue and BaseDataType to hold the input constant data.

    This function is able to handle numpy types in addition to python int and float.

    Args:
        constant_data (Any): The constant data for which to determine the
            corresponding BaseValue and BaseDataType.

    Raises:
        AssertionError: If `constant_data` is of an unsupported type.

    Returns:
        Callable[..., BaseValue]: A callable that returns the proper BaseValue when called
            with `is_encrypted` as keyword argument. For numpy data it is a partial of
            ScalarValue with `data_type` already bound; for python data it is whatever
            get_base_value_for_python_constant_data returns.
    """
    assert isinstance(
        constant_data, (int, float, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
    ), f"Unsupported constant data of type {type(constant_data)}"

    if isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES):
        # The data type is only needed on this branch, so it is resolved here rather
        # than up front.
        base_dtype = get_base_data_type_for_numpy_or_python_constant_data(constant_data)
        return partial(ScalarValue, data_type=base_dtype)

    return get_base_value_for_python_constant_data(constant_data)
def get_ufunc_numpy_output_dtype(
ufunc: numpy.ufunc,
input_dtypes: List[BaseDataType],

View File

@@ -1,5 +1,6 @@
"""hnumpy tracing utilities."""
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict
import numpy
@@ -7,11 +8,12 @@ from numpy.typing import DTypeLike
from ..common.data_types import BaseValue
from ..common.operator_graph import OPGraph
from ..common.representation import intermediate as ir
from ..common.representation.intermediate import ArbitraryFunction, ConstantInput
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
from .np_dtypes_helpers import (
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES,
convert_numpy_dtype_to_base_data_type,
get_base_value_for_numpy_or_python_constant_data,
get_ufunc_numpy_output_dtype,
)
@@ -19,6 +21,11 @@ SUPPORTED_TYPES_FOR_TRACING = (int, float, numpy.ndarray) + tuple(
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES
)
# ConstantInput constructor pre-bound to the numpy-aware base value helper, so
# constants are described through get_base_value_for_numpy_or_python_constant_data
# instead of ConstantInput's default helper.
NPConstantInput = partial(
ConstantInput,
get_base_value_for_data_func=get_base_value_for_numpy_or_python_constant_data,
)
class NPTracer(BaseTracer):
"""Tracer class for numpy operations."""
@@ -55,7 +62,7 @@ class NPTracer(BaseTracer):
normalized_numpy_dtype = numpy.dtype(numpy_dtype)
output_dtype = convert_numpy_dtype_to_base_data_type(numpy_dtype)
traced_computation = ir.ArbitraryFunction(
traced_computation = ArbitraryFunction(
input_base_value=self.output,
arbitrary_func=normalized_numpy_dtype.type,
output_dtype=output_dtype,
@@ -91,6 +98,9 @@ class NPTracer(BaseTracer):
other, SUPPORTED_TYPES_FOR_TRACING
)
def _make_const_input_tracer(self, constant_data: Any) -> "NPTracer":
    """Create a tracer of this tracer's class for a constant input.

    Args:
        constant_data (Any): The constant to store in an NPConstantInput node.

    Returns:
        NPTracer: A tracer built with no input tracers, the new NPConstantInput node
            and 0 as last argument.
    """
    constant_node = NPConstantInput(constant_data)
    return self.__class__([], constant_node, 0)
@staticmethod
def _manage_dtypes(ufunc: numpy.ufunc, *input_tracers: "NPTracer"):
output_dtypes = get_ufunc_numpy_output_dtype(
@@ -111,7 +121,7 @@ class NPTracer(BaseTracer):
common_output_dtypes = self._manage_dtypes(numpy.rint, *input_tracers)
assert len(common_output_dtypes) == 1
traced_computation = ir.ArbitraryFunction(
traced_computation = ArbitraryFunction(
input_base_value=input_tracers[0].output,
arbitrary_func=numpy.rint,
output_dtype=common_output_dtypes[0],
@@ -133,7 +143,7 @@ class NPTracer(BaseTracer):
common_output_dtypes = self._manage_dtypes(numpy.sin, *input_tracers)
assert len(common_output_dtypes) == 1
traced_computation = ir.ArbitraryFunction(
traced_computation = ArbitraryFunction(
input_base_value=input_tracers[0].output,
arbitrary_func=numpy.sin,
output_dtype=common_output_dtypes[0],

View File

@@ -71,6 +71,7 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
"function,input_ranges,list_of_arg_names",
[
pytest.param(lambda x: x + 42, ((0, 2),), ["x"]),
pytest.param(lambda x: x + numpy.int32(42), ((0, 2),), ["x"]),
pytest.param(lambda x: x * 2, ((0, 2),), ["x"]),
pytest.param(lambda x: 8 - x, ((0, 2),), ["x"]),
pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]),