diff --git a/hdk/common/data_types/floats.py b/hdk/common/data_types/floats.py index 7021886e0..ed9969680 100644 --- a/hdk/common/data_types/floats.py +++ b/hdk/common/data_types/floats.py @@ -11,6 +11,7 @@ class Float(base.BaseDataType): bit_width: int def __init__(self, bit_width: int) -> None: + assert bit_width in (32, 64), "Only 32 and 64 bits floats are supported" self.bit_width = bit_width def __repr__(self) -> str: diff --git a/hdk/common/data_types/integers.py b/hdk/common/data_types/integers.py index f4753f09b..4ebf7e3d2 100644 --- a/hdk/common/data_types/integers.py +++ b/hdk/common/data_types/integers.py @@ -13,6 +13,7 @@ class Integer(base.BaseDataType): is_signed: bool def __init__(self, bit_width: int, is_signed: bool) -> None: + assert bit_width > 0, "bit_width must be > 0" self.bit_width = bit_width self.is_signed = is_signed diff --git a/hdk/hnumpy/np_dtypes_helpers.py b/hdk/hnumpy/np_dtypes_helpers.py index a926b188e..747a9c9c7 100644 --- a/hdk/hnumpy/np_dtypes_helpers.py +++ b/hdk/hnumpy/np_dtypes_helpers.py @@ -1,11 +1,13 @@ """File to hold code to manage package and numpy dtypes""" from copy import deepcopy +from typing import List, Union import numpy from numpy.typing import DTypeLike from ..common.data_types.base import BaseDataType +from ..common.data_types.dtypes_helpers import SUPPORTED_TYPES from ..common.data_types.floats import Float from ..common.data_types.integers import Integer @@ -18,7 +20,9 @@ NUMPY_TO_HDK_TYPE_MAPPING = { numpy.dtype(numpy.float64): Float(64), } -SUPPORTED_NUMPY_TYPES_SET = NUMPY_TO_HDK_TYPE_MAPPING.keys() +SUPPORTED_NUMPY_TYPES_SET = set(NUMPY_TO_HDK_TYPE_MAPPING.keys()) + +SUPPORTED_TYPE_MSG_STRING = ", ".join(sorted(str(dtype) for dtype in SUPPORTED_NUMPY_TYPES_SET)) def convert_numpy_dtype_to_common_dtype(numpy_dtype: DTypeLike) -> BaseDataType: @@ -42,8 +46,91 @@ def convert_numpy_dtype_to_common_dtype(numpy_dtype: DTypeLike) -> BaseDataType: raise ValueError( f"Unsupported numpy type: {numpy_dtype} ({normalized_numpy_dtype}), " f"supported numpy types: " - f"{', '.join(sorted(str(dtype) for dtype in SUPPORTED_NUMPY_TYPES_SET))}" + f"{SUPPORTED_TYPE_MSG_STRING}" ) # deepcopy to avoid having the value from the dict modified return deepcopy(corresponding_hdk_dtype) + + +def convert_common_dtype_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.dtype: + """Convert a BaseDataType to corresponding numpy.dtype + + Args: + common_dtype (BaseDataType): dtype to convert to numpy.dtype + + Returns: + numpy.dtype: The resulting numpy.dtype + """ + assert isinstance( + common_dtype, SUPPORTED_TYPES + ), f"Unsupported common_dtype: {type(common_dtype)}" + type_to_return: numpy.dtype + + if isinstance(common_dtype, Float): + assert common_dtype.bit_width in ( + 32, + 64, + ), "Only converting Float(32) or Float(64) is supported" + type_to_return = ( + numpy.dtype(numpy.float64) + if common_dtype.bit_width == 64 + else numpy.dtype(numpy.float32) + ) + elif isinstance(common_dtype, Integer): + signed = common_dtype.is_signed + if common_dtype.bit_width <= 32: + type_to_return = numpy.dtype(numpy.int32) if signed else numpy.dtype(numpy.uint32) + elif common_dtype.bit_width <= 64: + type_to_return = numpy.dtype(numpy.int64) if signed else numpy.dtype(numpy.uint64) + else: + raise NotImplementedError( + f"Conversion to numpy dtype only supports Integers with bit_width <= 64, " + f"got {common_dtype!r}" + ) + + return type_to_return + + +def get_ufunc_numpy_output_dtype( + ufunc: numpy.ufunc, + input_dtypes: Union[List[numpy.dtype], List[BaseDataType]], +) -> List[numpy.dtype]: + """Function to record the output dtype of a numpy.ufunc given some input types + + Args: + ufunc (numpy.ufunc): The numpy.ufunc whose output types need to be recorded + input_dtypes (Union[List[numpy.dtype], List[BaseDataType]]): Either numpy or common dtypes + in the same order as they will be used with the ufunc inputs + + Returns: + List[numpy.dtype]: The ordered numpy dtypes of the ufunc outputs + """ + assert ( + len(input_dtypes) == ufunc.nin + ), f"Expected {ufunc.nin} types, got {len(input_dtypes)}: {input_dtypes}" + + input_dtypes = [ + numpy.dtype(convert_common_dtype_to_numpy_dtype(dtype)) + if not isinstance(dtype, numpy.dtype) + else dtype + for dtype in input_dtypes + ] + + # Store numpy old error settings and ignore all errors in this function + # We ignore errors as we may call functions with invalid inputs just to get the proper output + # dtypes + old_numpy_err_settings = numpy.seterr(all="ignore") + + dummy_inputs = tuple( + dtype.type(1000.0 * numpy.random.random_sample()) for dtype in input_dtypes + ) + + outputs = ufunc(*dummy_inputs) + if not isinstance(outputs, tuple): + outputs = (outputs,) + + # Restore numpy error settings + numpy.seterr(**old_numpy_err_settings) + + return [output.dtype for output in outputs] diff --git a/tests/hnumpy/test_np_dtypes_helpers.py b/tests/hnumpy/test_np_dtypes_helpers.py index ab2fd4201..269982390 100644 --- a/tests/hnumpy/test_np_dtypes_helpers.py +++ b/tests/hnumpy/test_np_dtypes_helpers.py @@ -5,7 +5,10 @@ import pytest from hdk.common.data_types.floats import Float from hdk.common.data_types.integers import Integer -from hdk.hnumpy.np_dtypes_helpers import convert_numpy_dtype_to_common_dtype +from hdk.hnumpy.np_dtypes_helpers import ( + convert_common_dtype_to_numpy_dtype, + convert_numpy_dtype_to_common_dtype, +) @pytest.mark.parametrize( @@ -29,3 +32,26 @@ from hdk.hnumpy.np_dtypes_helpers import convert_numpy_dtype_to_common_dtype def test_convert_numpy_dtype_to_common_dtype(numpy_dtype, expected_common_type): """Test function for convert_numpy_dtype_to_common_dtype""" assert convert_numpy_dtype_to_common_dtype(numpy_dtype) == expected_common_type + + +@pytest.mark.parametrize( + "common_dtype,expected_numpy_dtype", + [ + pytest.param(Integer(7, is_signed=False), numpy.uint32), + pytest.param(Integer(7, is_signed=True), numpy.int32), + pytest.param(Integer(32, is_signed=True), numpy.int32), + pytest.param(Integer(64, is_signed=True), numpy.int64), + pytest.param(Integer(32, is_signed=False), numpy.uint32), + pytest.param(Integer(64, is_signed=False), numpy.uint64), + pytest.param(Float(32), numpy.float32), + pytest.param(Float(64), numpy.float64), + pytest.param( + Integer(128, is_signed=True), + None, + marks=pytest.mark.xfail(strict=True, raises=NotImplementedError), + ), + ], +) +def test_convert_common_dtype_to_numpy_dtype(common_dtype, expected_numpy_dtype): + """Test function for convert_common_dtype_to_numpy_dtype""" + assert expected_numpy_dtype == convert_common_dtype_to_numpy_dtype(common_dtype)