mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor: refactor np_dtype_helpers.py
- rename convert_numpy_dtype_to_common_dtype to convert_numpy_dtype_to_base_data_type - change names of constants to be clearer
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""File to hold code to manage package and numpy dtypes."""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import List
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy
|
||||
from numpy.typing import DTypeLike
|
||||
@@ -11,7 +11,7 @@ from ..common.data_types.dtypes_helpers import BASE_DATA_TYPES
|
||||
from ..common.data_types.floats import Float
|
||||
from ..common.data_types.integers import Integer
|
||||
|
||||
NUMPY_TO_HDK_TYPE_MAPPING = {
|
||||
NUMPY_TO_HDK_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = {
|
||||
numpy.dtype(numpy.int32): Integer(32, is_signed=True),
|
||||
numpy.dtype(numpy.int64): Integer(64, is_signed=True),
|
||||
numpy.dtype(numpy.uint32): Integer(32, is_signed=False),
|
||||
@@ -20,13 +20,14 @@ NUMPY_TO_HDK_TYPE_MAPPING = {
|
||||
numpy.dtype(numpy.float64): Float(64),
|
||||
}
|
||||
|
||||
SUPPORTED_NUMPY_TYPES_SET = set(NUMPY_TO_HDK_TYPE_MAPPING.keys())
|
||||
SUPPORTED_NUMPY_DTYPES = tuple(NUMPY_TO_HDK_DTYPE_MAPPING)
|
||||
SUPPORTED_NUMPY_DTYPES_CLASS_TYPES = tuple(dtype.type for dtype in NUMPY_TO_HDK_DTYPE_MAPPING)
|
||||
|
||||
SUPPORTED_TYPE_MSG_STRING = ", ".join(sorted(str(dtype) for dtype in SUPPORTED_NUMPY_TYPES_SET))
|
||||
SUPPORTED_DTYPE_MSG_STRING = ", ".join(sorted(str(dtype) for dtype in SUPPORTED_NUMPY_DTYPES))
|
||||
|
||||
|
||||
def convert_numpy_dtype_to_common_dtype(numpy_dtype: DTypeLike) -> BaseDataType:
|
||||
"""Helper function to get the corresponding type from a numpy dtype.
|
||||
def convert_numpy_dtype_to_base_data_type(numpy_dtype: DTypeLike) -> BaseDataType:
|
||||
"""Helper function to get the corresponding BaseDataType from a numpy dtype.
|
||||
|
||||
Args:
|
||||
numpy_dtype (DTypeLike): Any python object that can be translated to a numpy.dtype
|
||||
@@ -39,20 +40,20 @@ def convert_numpy_dtype_to_common_dtype(numpy_dtype: DTypeLike) -> BaseDataType:
|
||||
"""
|
||||
# Normalize numpy_dtype
|
||||
normalized_numpy_dtype = numpy.dtype(numpy_dtype)
|
||||
corresponding_hdk_dtype = NUMPY_TO_HDK_TYPE_MAPPING.get(normalized_numpy_dtype, None)
|
||||
corresponding_hdk_dtype = NUMPY_TO_HDK_DTYPE_MAPPING.get(normalized_numpy_dtype, None)
|
||||
|
||||
if corresponding_hdk_dtype is None:
|
||||
raise ValueError(
|
||||
f"Unsupported numpy type: {numpy_dtype} ({normalized_numpy_dtype}), "
|
||||
f"supported numpy types: "
|
||||
f"{SUPPORTED_TYPE_MSG_STRING}"
|
||||
f"{SUPPORTED_DTYPE_MSG_STRING}"
|
||||
)
|
||||
|
||||
# deepcopy to avoid having the value from the dict modified
|
||||
return deepcopy(corresponding_hdk_dtype)
|
||||
|
||||
|
||||
def convert_common_dtype_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.dtype:
|
||||
def convert_base_data_type_to_numpy_dtype(common_dtype: BaseDataType) -> numpy.dtype:
|
||||
"""Convert a BaseDataType to corresponding numpy.dtype.
|
||||
|
||||
Args:
|
||||
@@ -109,7 +110,7 @@ def get_ufunc_numpy_output_dtype(
|
||||
len(input_dtypes) == ufunc.nin
|
||||
), f"Expected {ufunc.nin} types, got {len(input_dtypes)}: {input_dtypes}"
|
||||
|
||||
input_numpy_dtypes = [convert_common_dtype_to_numpy_dtype(dtype) for dtype in input_dtypes]
|
||||
input_numpy_dtypes = [convert_base_data_type_to_numpy_dtype(dtype) for dtype in input_dtypes]
|
||||
|
||||
# Store numpy old error settings and ignore all errors in this function
|
||||
# We ignore errors as we may call functions with invalid inputs just to get the proper output
|
||||
|
||||
@@ -10,7 +10,7 @@ from ..common.operator_graph import OPGraph
|
||||
from ..common.representation import intermediate as ir
|
||||
from ..common.tracing import BaseTracer, make_input_tracers, prepare_function_parameters
|
||||
from .np_dtypes_helpers import (
|
||||
convert_numpy_dtype_to_common_dtype,
|
||||
convert_numpy_dtype_to_base_data_type,
|
||||
get_ufunc_numpy_output_dtype,
|
||||
)
|
||||
|
||||
@@ -49,7 +49,7 @@ class NPTracer(BaseTracer):
|
||||
), f"astype currently only supports tracing without **kwargs, got {kwargs}"
|
||||
|
||||
normalized_numpy_dtype = numpy.dtype(numpy_dtype)
|
||||
output_dtype = convert_numpy_dtype_to_common_dtype(numpy_dtype)
|
||||
output_dtype = convert_numpy_dtype_to_base_data_type(numpy_dtype)
|
||||
traced_computation = ir.ArbitraryFunction(
|
||||
input_base_value=self.output,
|
||||
arbitrary_func=normalized_numpy_dtype.type,
|
||||
@@ -87,7 +87,7 @@ class NPTracer(BaseTracer):
|
||||
ufunc, [input_tracer.output.data_type for input_tracer in input_tracers]
|
||||
)
|
||||
common_output_dtypes = [
|
||||
convert_numpy_dtype_to_common_dtype(dtype) for dtype in output_dtypes
|
||||
convert_numpy_dtype_to_base_data_type(dtype) for dtype in output_dtypes
|
||||
]
|
||||
return common_output_dtypes
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ import pytest
|
||||
from hdk.common.data_types.floats import Float
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.hnumpy.np_dtypes_helpers import (
|
||||
convert_common_dtype_to_numpy_dtype,
|
||||
convert_numpy_dtype_to_common_dtype,
|
||||
convert_base_data_type_to_numpy_dtype,
|
||||
convert_numpy_dtype_to_base_data_type,
|
||||
)
|
||||
|
||||
|
||||
@@ -29,9 +29,9 @@ from hdk.hnumpy.np_dtypes_helpers import (
|
||||
pytest.param("complex64", None, marks=pytest.mark.xfail(strict=True, raises=ValueError)),
|
||||
],
|
||||
)
|
||||
def test_convert_numpy_dtype_to_common_dtype(numpy_dtype, expected_common_type):
|
||||
"""Test function for convert_numpy_dtype_to_common_dtype"""
|
||||
assert convert_numpy_dtype_to_common_dtype(numpy_dtype) == expected_common_type
|
||||
def test_convert_numpy_dtype_to_base_data_type(numpy_dtype, expected_common_type):
|
||||
"""Test function for convert_numpy_dtype_to_base_data_type"""
|
||||
assert convert_numpy_dtype_to_base_data_type(numpy_dtype) == expected_common_type
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -54,4 +54,4 @@ def test_convert_numpy_dtype_to_common_dtype(numpy_dtype, expected_common_type):
|
||||
)
|
||||
def test_convert_common_dtype_to_numpy_dtype(common_dtype, expected_numpy_dtype):
|
||||
"""Test function for convert_common_dtype_to_numpy_dtype"""
|
||||
assert expected_numpy_dtype == convert_common_dtype_to_numpy_dtype(common_dtype)
|
||||
assert expected_numpy_dtype == convert_base_data_type_to_numpy_dtype(common_dtype)
|
||||
|
||||
Reference in New Issue
Block a user