From 9a0c108d4b316b4746b1d60b212995b52f5530da Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Thu, 19 Aug 2021 16:56:31 +0200 Subject: [PATCH] refactor: refactor ConstantInput to be flexible - refactor to take a function to generate the proper BaseValue to store in its output - refactor BaseTracer to force inheriting tracers to indicate how to build a ConstantInput tracer - remove "as import" for intermediate in hnumpy/tracing.py - update compile to manage python dtypes --- hdk/common/representation/intermediate.py | 5 +- hdk/common/tracing/base_tracer.py | 26 +++++----- hdk/hnumpy/compile.py | 5 +- hdk/hnumpy/np_dtypes_helpers.py | 62 ++++++++++++++++++++++- hdk/hnumpy/tracing.py | 18 +++++-- tests/hnumpy/test_compile.py | 1 + 6 files changed, 95 insertions(+), 22 deletions(-) diff --git a/hdk/common/representation/intermediate.py b/hdk/common/representation/intermediate.py index f591ca1cd..b2b26db65 100644 --- a/hdk/common/representation/intermediate.py +++ b/hdk/common/representation/intermediate.py @@ -152,10 +152,13 @@ class ConstantInput(IntermediateNode): def __init__( self, constant_data: Any, + get_base_value_for_data_func: Callable[ + [Any], Callable[..., BaseValue] + ] = get_base_value_for_python_constant_data, ) -> None: super().__init__([]) - base_value_class = get_base_value_for_python_constant_data(constant_data) + base_value_class = get_base_value_for_data_func(constant_data) self._constant_data = constant_data self.outputs = [base_value_class(is_encrypted=False)] diff --git a/hdk/common/tracing/base_tracer.py b/hdk/common/tracing/base_tracer.py index 83a52af07..e6774147a 100644 --- a/hdk/common/tracing/base_tracer.py +++ b/hdk/common/tracing/base_tracer.py @@ -36,6 +36,17 @@ class BaseTracer(ABC): """ return isinstance(other, self.__class__) + @abstractmethod + def _make_const_input_tracer(self, constant_data: Any) -> "BaseTracer": + """Helper function to create a tracer for a constant input. 
+ + Args: + constant_data (Any): The constant to store. + + Returns: + BaseTracer: The BaseTracer for that constant. + """ + def instantiate_output_tracers( self, inputs: Iterable[Union["BaseTracer", Any]], @@ -55,7 +66,7 @@ class BaseTracer(ABC): # For inputs which are actually constant, first convert into a tracer def sanitize(inp): if not isinstance(inp, BaseTracer): - return make_const_input_tracer(self.__class__, inp) + return self._make_const_input_tracer(inp) return inp sanitized_inputs = [sanitize(inp) for inp in inputs] @@ -128,16 +139,3 @@ class BaseTracer(ABC): # the order, we need to do as in __rmul__, ie mostly a copy of __mul__ + # some changes __rmul__ = __mul__ - - -def make_const_input_tracer(tracer_class: Type[BaseTracer], constant_data: Any) -> BaseTracer: - """Helper function to create a tracer for a constant input. - - Args: - tracer_class (Type[BaseTracer]): the class of tracer to create a ConstantInput for - constant_data (Any): the constant - - Returns: - BaseTracer: The BaseTracer for that constant - """ - return tracer_class([], ir.ConstantInput(constant_data), 0) diff --git a/hdk/hnumpy/compile.py b/hdk/hnumpy/compile.py index 1f629b674..04f2fc4e1 100644 --- a/hdk/hnumpy/compile.py +++ b/hdk/hnumpy/compile.py @@ -17,6 +17,7 @@ from ..common.operator_graph import OPGraph from ..common.optimization.topological import fuse_float_operations from ..common.representation import intermediate as ir from ..hnumpy.tracing import trace_numpy_function +from .np_dtypes_helpers import get_base_data_type_for_numpy_or_python_constant_data def compile_numpy_function_into_op_graph( @@ -74,7 +75,9 @@ def compile_numpy_function_into_op_graph( node_bounds = eval_op_graph_bounds_on_dataset(op_graph, dataset) # Update the graph accordingly: after that, we have the compilable graph - op_graph.update_values_with_bounds(node_bounds) + op_graph.update_values_with_bounds( + node_bounds, get_base_data_type_for_numpy_or_python_constant_data + ) # Make sure the graph 
can be lowered to MLIR if not is_graph_values_compatible_with_mlir(op_graph): diff --git a/hdk/hnumpy/np_dtypes_helpers.py b/hdk/hnumpy/np_dtypes_helpers.py index 4ae4d5c11..82d1ebb00 100644 --- a/hdk/hnumpy/np_dtypes_helpers.py +++ b/hdk/hnumpy/np_dtypes_helpers.py @@ -1,15 +1,21 @@ """File to hold code to manage package and numpy dtypes.""" from copy import deepcopy -from typing import Dict, List +from functools import partial +from typing import Any, Callable, Dict, List import numpy from numpy.typing import DTypeLike from ..common.data_types.base import BaseDataType -from ..common.data_types.dtypes_helpers import BASE_DATA_TYPES +from ..common.data_types.dtypes_helpers import ( + BASE_DATA_TYPES, + get_base_data_type_for_python_constant_data, + get_base_value_for_python_constant_data, +) from ..common.data_types.floats import Float from ..common.data_types.integers import Integer +from ..common.data_types.values import BaseValue, ScalarValue NUMPY_TO_HDK_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = { numpy.dtype(numpy.int32): Integer(32, is_signed=True), @@ -92,6 +98,58 @@ def convert_base_data_type_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.d return type_to_return +def get_base_data_type_for_numpy_or_python_constant_data(constant_data: Any) -> BaseDataType: + """Helper function to determine the BaseDataType to hold the input constant data. + + Args: + constant_data (Any): The constant data for which to determine the + corresponding BaseDataType. 
+ + Returns: + BaseDataType: The corresponding BaseDataType + """ + base_dtype: BaseDataType + assert isinstance( + constant_data, (int, float, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES) + ), f"Unsupported constant data of type {type(constant_data)}" + if isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES): + base_dtype = convert_numpy_dtype_to_base_data_type(constant_data) + else: + base_dtype = get_base_data_type_for_python_constant_data(constant_data) + return base_dtype + + +def get_base_value_for_numpy_or_python_constant_data( + constant_data: Any, +) -> Callable[..., BaseValue]: + """Helper function to determine the BaseValue and BaseDataType to hold the input constant data. + + This function is able to handle numpy types. + + Args: + constant_data (Any): The constant data for which to determine the + corresponding BaseValue and BaseDataType. + + Raises: + AssertionError: If `constant_data` is of an unsupported type. + + Returns: + Callable[..., BaseValue]: A partial object that will return the proper BaseValue when called + with `is_encrypted` as keyword argument (forwarded to the BaseValue `__init__` method). 
+ """ + constant_data_value: Callable[..., BaseValue] + assert isinstance( + constant_data, (int, float, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES) + ), f"Unsupported constant data of type {type(constant_data)}" + + base_dtype = get_base_data_type_for_numpy_or_python_constant_data(constant_data) + if isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES): + constant_data_value = partial(ScalarValue, data_type=base_dtype) + else: + constant_data_value = get_base_value_for_python_constant_data(constant_data) + return constant_data_value + + def get_ufunc_numpy_output_dtype( ufunc: numpy.ufunc, input_dtypes: List[BaseDataType], diff --git a/hdk/hnumpy/tracing.py b/hdk/hnumpy/tracing.py index a01fb5bd7..075409c38 100644 --- a/hdk/hnumpy/tracing.py +++ b/hdk/hnumpy/tracing.py @@ -1,5 +1,6 @@ """hnumpy tracing utilities.""" from copy import deepcopy +from functools import partial from typing import Any, Callable, Dict import numpy @@ -7,11 +8,12 @@ from numpy.typing import DTypeLike from ..common.data_types import BaseValue from ..common.operator_graph import OPGraph -from ..common.representation import intermediate as ir +from ..common.representation.intermediate import ArbitraryFunction, ConstantInput from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters from .np_dtypes_helpers import ( SUPPORTED_NUMPY_DTYPES_CLASS_TYPES, convert_numpy_dtype_to_base_data_type, + get_base_value_for_numpy_or_python_constant_data, get_ufunc_numpy_output_dtype, ) @@ -19,6 +21,11 @@ SUPPORTED_TYPES_FOR_TRACING = (int, float, numpy.ndarray) + tuple( SUPPORTED_NUMPY_DTYPES_CLASS_TYPES ) +NPConstantInput = partial( + ConstantInput, + get_base_value_for_data_func=get_base_value_for_numpy_or_python_constant_data, +) + class NPTracer(BaseTracer): """Tracer class for numpy operations.""" @@ -55,7 +62,7 @@ class NPTracer(BaseTracer): normalized_numpy_dtype = numpy.dtype(numpy_dtype) output_dtype = convert_numpy_dtype_to_base_data_type(numpy_dtype) - 
traced_computation = ir.ArbitraryFunction( + traced_computation = ArbitraryFunction( input_base_value=self.output, arbitrary_func=normalized_numpy_dtype.type, output_dtype=output_dtype, @@ -91,6 +98,9 @@ class NPTracer(BaseTracer): other, SUPPORTED_TYPES_FOR_TRACING ) + def _make_const_input_tracer(self, constant_data: Any) -> "NPTracer": + return self.__class__([], NPConstantInput(constant_data), 0) + @staticmethod def _manage_dtypes(ufunc: numpy.ufunc, *input_tracers: "NPTracer"): output_dtypes = get_ufunc_numpy_output_dtype( @@ -111,7 +121,7 @@ class NPTracer(BaseTracer): common_output_dtypes = self._manage_dtypes(numpy.rint, *input_tracers) assert len(common_output_dtypes) == 1 - traced_computation = ir.ArbitraryFunction( + traced_computation = ArbitraryFunction( input_base_value=input_tracers[0].output, arbitrary_func=numpy.rint, output_dtype=common_output_dtypes[0], @@ -133,7 +143,7 @@ class NPTracer(BaseTracer): common_output_dtypes = self._manage_dtypes(numpy.sin, *input_tracers) assert len(common_output_dtypes) == 1 - traced_computation = ir.ArbitraryFunction( + traced_computation = ArbitraryFunction( input_base_value=input_tracers[0].output, arbitrary_func=numpy.sin, output_dtype=common_output_dtypes[0], diff --git a/tests/hnumpy/test_compile.py b/tests/hnumpy/test_compile.py index f03b0e892..8565b7334 100644 --- a/tests/hnumpy/test_compile.py +++ b/tests/hnumpy/test_compile.py @@ -71,6 +71,7 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n "function,input_ranges,list_of_arg_names", [ pytest.param(lambda x: x + 42, ((0, 2),), ["x"]), + pytest.param(lambda x: x + numpy.int32(42), ((0, 2),), ["x"]), pytest.param(lambda x: x * 2, ((0, 2),), ["x"]), pytest.param(lambda x: 8 - x, ((0, 2),), ["x"]), pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]),