diff --git a/hdk/hnumpy/np_dtypes_helpers.py b/hdk/hnumpy/np_dtypes_helpers.py index 1de2d6c4b..4ae4d5c11 100644 --- a/hdk/hnumpy/np_dtypes_helpers.py +++ b/hdk/hnumpy/np_dtypes_helpers.py @@ -1,7 +1,7 @@ """File to hold code to manage package and numpy dtypes.""" from copy import deepcopy -from typing import List +from typing import Dict, List import numpy from numpy.typing import DTypeLike @@ -11,7 +11,7 @@ from ..common.data_types.dtypes_helpers import BASE_DATA_TYPES from ..common.data_types.floats import Float from ..common.data_types.integers import Integer -NUMPY_TO_HDK_TYPE_MAPPING = { +NUMPY_TO_HDK_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = { numpy.dtype(numpy.int32): Integer(32, is_signed=True), numpy.dtype(numpy.int64): Integer(64, is_signed=True), numpy.dtype(numpy.uint32): Integer(32, is_signed=False), @@ -20,13 +20,14 @@ NUMPY_TO_HDK_TYPE_MAPPING = { numpy.dtype(numpy.float64): Float(64), } -SUPPORTED_NUMPY_TYPES_SET = set(NUMPY_TO_HDK_TYPE_MAPPING.keys()) +SUPPORTED_NUMPY_DTYPES = tuple(NUMPY_TO_HDK_DTYPE_MAPPING) +SUPPORTED_NUMPY_DTYPES_CLASS_TYPES = tuple(dtype.type for dtype in NUMPY_TO_HDK_DTYPE_MAPPING) -SUPPORTED_TYPE_MSG_STRING = ", ".join(sorted(str(dtype) for dtype in SUPPORTED_NUMPY_TYPES_SET)) +SUPPORTED_DTYPE_MSG_STRING = ", ".join(sorted(str(dtype) for dtype in SUPPORTED_NUMPY_DTYPES)) -def convert_numpy_dtype_to_common_dtype(numpy_dtype: DTypeLike) -> BaseDataType: - """Helper function to get the corresponding type from a numpy dtype. +def convert_numpy_dtype_to_base_data_type(numpy_dtype: DTypeLike) -> BaseDataType: + """Helper function to get the corresponding BaseDataType from a numpy dtype. Args: numpy_dtype (DTypeLike): Any python object that can be translated to a numpy.dtype @@ -39,20 +40,20 @@ def convert_numpy_dtype_to_common_dtype(numpy_dtype: DTypeLike) -> BaseDataType: """ # Normalize numpy_dtype normalized_numpy_dtype = numpy.dtype(numpy_dtype) - corresponding_hdk_dtype = NUMPY_TO_HDK_TYPE_MAPPING.get(normalized_numpy_dtype, None) + corresponding_hdk_dtype = NUMPY_TO_HDK_DTYPE_MAPPING.get(normalized_numpy_dtype, None) if corresponding_hdk_dtype is None: raise ValueError( f"Unsupported numpy type: {numpy_dtype} ({normalized_numpy_dtype}), " f"supported numpy types: " - f"{SUPPORTED_TYPE_MSG_STRING}" + f"{SUPPORTED_DTYPE_MSG_STRING}" ) # deepcopy to avoid having the value from the dict modified return deepcopy(corresponding_hdk_dtype) -def convert_common_dtype_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.dtype: +def convert_base_data_type_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.dtype: """Convert a BaseDataType to corresponding numpy.dtype. Args: @@ -109,7 +110,7 @@ def get_ufunc_numpy_output_dtype( len(input_dtypes) == ufunc.nin ), f"Expected {ufunc.nin} types, got {len(input_dtypes)}: {input_dtypes}" - input_numpy_dtypes = [convert_common_dtype_to_numpy_dtype(dtype) for dtype in input_dtypes] + input_numpy_dtypes = [convert_base_data_type_to_numpy_dtype(dtype) for dtype in input_dtypes] # Store numpy old error settings and ignore all errors in this function # We ignore errors as we may call functions with invalid inputs just to get the proper output diff --git a/hdk/hnumpy/tracing.py b/hdk/hnumpy/tracing.py index 74416b2c7..ac9251929 100644 --- a/hdk/hnumpy/tracing.py +++ b/hdk/hnumpy/tracing.py @@ -10,7 +10,7 @@ from ..common.operator_graph import OPGraph from ..common.representation import intermediate as ir from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters from .np_dtypes_helpers import ( - convert_numpy_dtype_to_common_dtype, + convert_numpy_dtype_to_base_data_type, get_ufunc_numpy_output_dtype, ) @@ -49,7 +49,7 @@ class NPTracer(BaseTracer): ), f"astype currently only supports tracing without **kwargs, got {kwargs}" normalized_numpy_dtype = numpy.dtype(numpy_dtype) - output_dtype = convert_numpy_dtype_to_common_dtype(numpy_dtype) + output_dtype = convert_numpy_dtype_to_base_data_type(numpy_dtype) traced_computation = ir.ArbitraryFunction( input_base_value=self.output, arbitrary_func=normalized_numpy_dtype.type, @@ -87,7 +87,7 @@ class NPTracer(BaseTracer): ufunc, [input_tracer.output.data_type for input_tracer in input_tracers] ) common_output_dtypes = [ - convert_numpy_dtype_to_common_dtype(dtype) for dtype in output_dtypes + convert_numpy_dtype_to_base_data_type(dtype) for dtype in output_dtypes ] return common_output_dtypes diff --git a/tests/hnumpy/test_np_dtypes_helpers.py b/tests/hnumpy/test_np_dtypes_helpers.py index 269982390..e1433e19c 100644 --- a/tests/hnumpy/test_np_dtypes_helpers.py +++ b/tests/hnumpy/test_np_dtypes_helpers.py @@ -6,8 +6,8 @@ import pytest from hdk.common.data_types.floats import Float from hdk.common.data_types.integers import Integer from hdk.hnumpy.np_dtypes_helpers import ( - convert_common_dtype_to_numpy_dtype, - convert_numpy_dtype_to_common_dtype, + convert_base_data_type_to_numpy_dtype, + convert_numpy_dtype_to_base_data_type, ) @@ -29,9 +29,9 @@ from hdk.hnumpy.np_dtypes_helpers import ( pytest.param("complex64", None, marks=pytest.mark.xfail(strict=True, raises=ValueError)), ], ) -def test_convert_numpy_dtype_to_common_dtype(numpy_dtype, expected_common_type): - """Test function for convert_numpy_dtype_to_common_dtype""" - assert convert_numpy_dtype_to_common_dtype(numpy_dtype) == expected_common_type +def test_convert_numpy_dtype_to_base_data_type(numpy_dtype, expected_common_type): + """Test function for convert_numpy_dtype_to_base_data_type""" + assert convert_numpy_dtype_to_base_data_type(numpy_dtype) == expected_common_type @pytest.mark.parametrize( @@ -54,4 +54,4 @@ def test_convert_numpy_dtype_to_common_dtype(numpy_dtype, expected_common_type): ) def test_convert_common_dtype_to_numpy_dtype(common_dtype, expected_numpy_dtype): """Test function for convert_common_dtype_to_numpy_dtype""" - assert expected_numpy_dtype == convert_common_dtype_to_numpy_dtype(common_dtype) + assert expected_numpy_dtype == convert_base_data_type_to_numpy_dtype(common_dtype)