refactor: refactor np_dtype_helpers.py

- rename convert_numpy_dtype_to_common_dtype to
convert_numpy_dtype_to_base_data_type
- change names of constants to be clearer
This commit is contained in:
Arthur Meyre
2021-08-19 15:41:30 +02:00
parent 4f103e604a
commit a4181afe4d
3 changed files with 20 additions and 19 deletions

View File

@@ -1,7 +1,7 @@
"""File to hold code to manage package and numpy dtypes."""
from copy import deepcopy
from typing import List
from typing import Dict, List
import numpy
from numpy.typing import DTypeLike
@@ -11,7 +11,7 @@ from ..common.data_types.dtypes_helpers import BASE_DATA_TYPES
from ..common.data_types.floats import Float
from ..common.data_types.integers import Integer
NUMPY_TO_HDK_TYPE_MAPPING = {
NUMPY_TO_HDK_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = {
numpy.dtype(numpy.int32): Integer(32, is_signed=True),
numpy.dtype(numpy.int64): Integer(64, is_signed=True),
numpy.dtype(numpy.uint32): Integer(32, is_signed=False),
@@ -20,13 +20,14 @@ NUMPY_TO_HDK_TYPE_MAPPING = {
numpy.dtype(numpy.float64): Float(64),
}
SUPPORTED_NUMPY_TYPES_SET = set(NUMPY_TO_HDK_TYPE_MAPPING.keys())
SUPPORTED_NUMPY_DTYPES = tuple(NUMPY_TO_HDK_DTYPE_MAPPING)
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES = tuple(dtype.type for dtype in NUMPY_TO_HDK_DTYPE_MAPPING)
SUPPORTED_TYPE_MSG_STRING = ", ".join(sorted(str(dtype) for dtype in SUPPORTED_NUMPY_TYPES_SET))
SUPPORTED_DTYPE_MSG_STRING = ", ".join(sorted(str(dtype) for dtype in SUPPORTED_NUMPY_DTYPES))
def convert_numpy_dtype_to_common_dtype(numpy_dtype: DTypeLike) -> BaseDataType:
"""Helper function to get the corresponding type from a numpy dtype.
def convert_numpy_dtype_to_base_data_type(numpy_dtype: DTypeLike) -> BaseDataType:
"""Helper function to get the corresponding BaseDataType from a numpy dtype.
Args:
numpy_dtype (DTypeLike): Any python object that can be translated to a numpy.dtype
@@ -39,20 +40,20 @@ def convert_numpy_dtype_to_common_dtype(numpy_dtype: DTypeLike) -> BaseDataType:
"""
# Normalize numpy_dtype
normalized_numpy_dtype = numpy.dtype(numpy_dtype)
corresponding_hdk_dtype = NUMPY_TO_HDK_TYPE_MAPPING.get(normalized_numpy_dtype, None)
corresponding_hdk_dtype = NUMPY_TO_HDK_DTYPE_MAPPING.get(normalized_numpy_dtype, None)
if corresponding_hdk_dtype is None:
raise ValueError(
f"Unsupported numpy type: {numpy_dtype} ({normalized_numpy_dtype}), "
f"supported numpy types: "
f"{SUPPORTED_TYPE_MSG_STRING}"
f"{SUPPORTED_DTYPE_MSG_STRING}"
)
# deepcopy to avoid having the value from the dict modified
return deepcopy(corresponding_hdk_dtype)
def convert_common_dtype_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.dtype:
def convert_base_data_type_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.dtype:
"""Convert a BaseDataType to corresponding numpy.dtype.
Args:
@@ -109,7 +110,7 @@ def get_ufunc_numpy_output_dtype(
len(input_dtypes) == ufunc.nin
), f"Expected {ufunc.nin} types, got {len(input_dtypes)}: {input_dtypes}"
input_numpy_dtypes = [convert_common_dtype_to_numpy_dtype(dtype) for dtype in input_dtypes]
input_numpy_dtypes = [convert_base_data_type_to_numpy_dtype(dtype) for dtype in input_dtypes]
# Store numpy old error settings and ignore all errors in this function
# We ignore errors as we may call functions with invalid inputs just to get the proper output

View File

@@ -10,7 +10,7 @@ from ..common.operator_graph import OPGraph
from ..common.representation import intermediate as ir
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
from .np_dtypes_helpers import (
convert_numpy_dtype_to_common_dtype,
convert_numpy_dtype_to_base_data_type,
get_ufunc_numpy_output_dtype,
)
@@ -49,7 +49,7 @@ class NPTracer(BaseTracer):
), f"astype currently only supports tracing without **kwargs, got {kwargs}"
normalized_numpy_dtype = numpy.dtype(numpy_dtype)
output_dtype = convert_numpy_dtype_to_common_dtype(numpy_dtype)
output_dtype = convert_numpy_dtype_to_base_data_type(numpy_dtype)
traced_computation = ir.ArbitraryFunction(
input_base_value=self.output,
arbitrary_func=normalized_numpy_dtype.type,
@@ -87,7 +87,7 @@ class NPTracer(BaseTracer):
ufunc, [input_tracer.output.data_type for input_tracer in input_tracers]
)
common_output_dtypes = [
convert_numpy_dtype_to_common_dtype(dtype) for dtype in output_dtypes
convert_numpy_dtype_to_base_data_type(dtype) for dtype in output_dtypes
]
return common_output_dtypes

View File

@@ -6,8 +6,8 @@ import pytest
from hdk.common.data_types.floats import Float
from hdk.common.data_types.integers import Integer
from hdk.hnumpy.np_dtypes_helpers import (
convert_common_dtype_to_numpy_dtype,
convert_numpy_dtype_to_common_dtype,
convert_base_data_type_to_numpy_dtype,
convert_numpy_dtype_to_base_data_type,
)
@@ -29,9 +29,9 @@ from hdk.hnumpy.np_dtypes_helpers import (
pytest.param("complex64", None, marks=pytest.mark.xfail(strict=True, raises=ValueError)),
],
)
def test_convert_numpy_dtype_to_common_dtype(numpy_dtype, expected_common_type):
"""Test function for convert_numpy_dtype_to_common_dtype"""
assert convert_numpy_dtype_to_common_dtype(numpy_dtype) == expected_common_type
def test_convert_numpy_dtype_to_base_data_type(numpy_dtype, expected_common_type):
"""Test function for convert_numpy_dtype_to_base_data_type"""
assert convert_numpy_dtype_to_base_data_type(numpy_dtype) == expected_common_type
@pytest.mark.parametrize(
@@ -54,4 +54,4 @@ def test_convert_numpy_dtype_to_common_dtype(numpy_dtype, expected_common_type):
)
def test_convert_common_dtype_to_numpy_dtype(common_dtype, expected_numpy_dtype):
"""Test function for convert_common_dtype_to_numpy_dtype"""
assert expected_numpy_dtype == convert_common_dtype_to_numpy_dtype(common_dtype)
assert expected_numpy_dtype == convert_base_data_type_to_numpy_dtype(common_dtype)