dev(numpy-dtypes): add some additional numpy dtype helper functions

- add function to convert a type from the project to a numpy dtype
- add function to manage numpy ufunc output dtypes
- add a check for integers to have positive bit_width
- add a check for floats to only accept 32 and 64 bits
This commit is contained in:
Arthur Meyre
2021-08-06 15:04:34 +02:00
parent e296e9667e
commit 19e68589d1
4 changed files with 118 additions and 3 deletions

View File

@@ -11,6 +11,7 @@ class Float(base.BaseDataType):
bit_width: int
def __init__(self, bit_width: int) -> None:
assert bit_width in (32, 64), "Only 32 and 64 bits floats are supported"
self.bit_width = bit_width
def __repr__(self) -> str:

View File

@@ -13,6 +13,7 @@ class Integer(base.BaseDataType):
is_signed: bool
def __init__(self, bit_width: int, is_signed: bool) -> None:
assert bit_width > 0, "bit_width must be > 0"
self.bit_width = bit_width
self.is_signed = is_signed

View File

@@ -1,11 +1,13 @@
"""File to hold code to manage package and numpy dtypes"""
from copy import deepcopy
from typing import List, Union
import numpy
from numpy.typing import DTypeLike
from ..common.data_types.base import BaseDataType
from ..common.data_types.dtypes_helpers import SUPPORTED_TYPES
from ..common.data_types.floats import Float
from ..common.data_types.integers import Integer
@@ -18,7 +20,9 @@ NUMPY_TO_HDK_TYPE_MAPPING = {
numpy.dtype(numpy.float64): Float(64),
}
SUPPORTED_NUMPY_TYPES_SET = NUMPY_TO_HDK_TYPE_MAPPING.keys()
SUPPORTED_NUMPY_TYPES_SET = set(NUMPY_TO_HDK_TYPE_MAPPING.keys())
SUPPORTED_TYPE_MSG_STRING = ", ".join(sorted(str(dtype) for dtype in SUPPORTED_NUMPY_TYPES_SET))
def convert_numpy_dtype_to_common_dtype(numpy_dtype: DTypeLike) -> BaseDataType:
@@ -42,8 +46,91 @@ def convert_numpy_dtype_to_common_dtype(numpy_dtype: DTypeLike) -> BaseDataType:
raise ValueError(
f"Unsupported numpy type: {numpy_dtype} ({normalized_numpy_dtype}), "
f"supported numpy types: "
f"{', '.join(sorted(str(dtype) for dtype in SUPPORTED_NUMPY_TYPES_SET))}"
f"{SUPPORTED_TYPE_MSG_STRING}"
)
# deepcopy to avoid having the value from the dict modified
return deepcopy(corresponding_hdk_dtype)
def convert_common_dtype_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.dtype:
"""Convert a BaseDataType to corresponding numpy.dtype
Args:
common_dtype (BaseDataType): dtype to convert to numpy.dtype
Returns:
numpy.dtype: The resulting numpy.dtype
"""
assert isinstance(
common_dtype, SUPPORTED_TYPES
), f"Unsupported common_dtype: {type(common_dtype)}"
type_to_return: numpy.dtype
if isinstance(common_dtype, Float):
assert common_dtype.bit_width in (
32,
64,
), "Only converting Float(32) or Float(64) is supported"
type_to_return = (
numpy.dtype(numpy.float64)
if common_dtype.bit_width == 64
else numpy.dtype(numpy.float32)
)
elif isinstance(common_dtype, Integer):
signed = common_dtype.is_signed
if common_dtype.bit_width <= 32:
type_to_return = numpy.dtype(numpy.int32) if signed else numpy.dtype(numpy.uint32)
elif common_dtype.bit_width <= 64:
type_to_return = numpy.dtype(numpy.int64) if signed else numpy.dtype(numpy.uint64)
else:
raise NotImplementedError(
f"Conversion to numpy dtype only supports Integers with bit_width <= 64, "
f"got {common_dtype!r}"
)
return type_to_return
def get_ufunc_numpy_output_dtype(
ufunc: numpy.ufunc,
input_dtypes: Union[List[numpy.dtype], List[BaseDataType]],
) -> List[numpy.dtype]:
"""Function to record the output dtype of a numpy.ufunc given some input types
Args:
ufunc (numpy.ufunc): The numpy.ufunc whose output types need to be recorded
input_dtypes (Union[List[numpy.dtype], List[BaseDataType]]): Either numpy or common dtypes
in the same order as they will be used with the ufunc inputs
Returns:
List[numpy.dtype]: The ordered numpy dtypes of the ufunc outputs
"""
assert (
len(input_dtypes) == ufunc.nin
), f"Expected {ufunc.nin} types, got {len(input_dtypes)}: {input_dtypes}"
input_dtypes = [
numpy.dtype(convert_common_dtype_to_numpy_dtype(dtype))
if not isinstance(dtype, numpy.dtype)
else dtype
for dtype in input_dtypes
]
# Store numpy old error settings and ignore all errors in this function
# We ignore errors as we may call functions with invalid inputs just to get the proper output
# dtypes
old_numpy_err_settings = numpy.seterr(all="ignore")
dummy_inputs = tuple(
dtype.type(1000.0 * numpy.random.random_sample()) for dtype in input_dtypes
)
outputs = ufunc(*dummy_inputs)
if not isinstance(outputs, tuple):
outputs = (outputs,)
# Restore numpy error settings
numpy.seterr(**old_numpy_err_settings)
return [output.dtype for output in outputs]

View File

@@ -5,7 +5,10 @@ import pytest
from hdk.common.data_types.floats import Float
from hdk.common.data_types.integers import Integer
from hdk.hnumpy.np_dtypes_helpers import convert_numpy_dtype_to_common_dtype
from hdk.hnumpy.np_dtypes_helpers import (
convert_common_dtype_to_numpy_dtype,
convert_numpy_dtype_to_common_dtype,
)
@pytest.mark.parametrize(
@@ -29,3 +32,26 @@ from hdk.hnumpy.np_dtypes_helpers import convert_numpy_dtype_to_common_dtype
def test_convert_numpy_dtype_to_common_dtype(numpy_dtype, expected_common_type):
"""Test function for convert_numpy_dtype_to_common_dtype"""
assert convert_numpy_dtype_to_common_dtype(numpy_dtype) == expected_common_type
@pytest.mark.parametrize(
"common_dtype,expected_numpy_dtype",
[
pytest.param(Integer(7, is_signed=False), numpy.uint32),
pytest.param(Integer(7, is_signed=True), numpy.int32),
pytest.param(Integer(32, is_signed=True), numpy.int32),
pytest.param(Integer(64, is_signed=True), numpy.int64),
pytest.param(Integer(32, is_signed=False), numpy.uint32),
pytest.param(Integer(64, is_signed=False), numpy.uint64),
pytest.param(Float(32), numpy.float32),
pytest.param(Float(64), numpy.float64),
pytest.param(
Integer(128, is_signed=True),
None,
marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
),
],
)
def test_convert_common_dtype_to_numpy_dtype(common_dtype, expected_numpy_dtype):
"""Test function for convert_common_dtype_to_numpy_dtype"""
assert expected_numpy_dtype == convert_common_dtype_to_numpy_dtype(common_dtype)