diff --git a/concrete/common/data_types/dtypes_helpers.py b/concrete/common/data_types/dtypes_helpers.py index 2a6f17dd0..00482d79e 100644 --- a/concrete/common/data_types/dtypes_helpers.py +++ b/concrete/common/data_types/dtypes_helpers.py @@ -2,18 +2,10 @@ from copy import deepcopy from functools import partial -from typing import Callable, List, Union, cast +from typing import Callable, Union, cast from ..debugging.custom_assert import custom_assert -from ..values import ( - BaseValue, - ClearScalar, - ClearTensor, - EncryptedScalar, - EncryptedTensor, - ScalarValue, - TensorValue, -) +from ..values import BaseValue, ClearTensor, EncryptedTensor, TensorValue from .base import BaseDataType from .floats import Float from .integers import Integer, get_bits_to_represent_value_as_integer @@ -24,25 +16,25 @@ BASE_DATA_TYPES = INTEGER_TYPES + FLOAT_TYPES def value_is_encrypted_scalar_integer(value_to_check: BaseValue) -> bool: - """Check that a value is an encrypted ScalarValue of type Integer. + """Check that a value is an encrypted scalar of type Integer. Args: value_to_check (BaseValue): The value to check Returns: - bool: True if the passed value_to_check is an encrypted ScalarValue of type Integer + bool: True if the passed value_to_check is an encrypted scalar of type Integer """ return value_is_scalar_integer(value_to_check) and value_to_check.is_encrypted def value_is_encrypted_scalar_unsigned_integer(value_to_check: BaseValue) -> bool: - """Check that a value is an encrypted ScalarValue of type unsigned Integer. + """Check that a value is an encrypted scalar of type unsigned Integer. Args: value_to_check (BaseValue): The value to check Returns: - bool: True if the passed value_to_check is an encrypted ScalarValue of type Integer and + bool: True if the passed value_to_check is an encrypted scalar of type Integer and unsigned """ return ( @@ -52,28 +44,30 @@ def value_is_encrypted_scalar_unsigned_integer(value_to_check: BaseValue) -> boo def value_is_clear_scalar_integer(value_to_check: BaseValue) -> bool: - """Check that a value is a clear ScalarValue of type Integer. + """Check that a value is a clear scalar of type Integer. Args: value_to_check (BaseValue): The value to check Returns: - bool: True if the passed value_to_check is a clear ScalarValue of type Integer + bool: True if the passed value_to_check is a clear scalar of type Integer """ return value_is_scalar_integer(value_to_check) and value_to_check.is_clear def value_is_scalar_integer(value_to_check: BaseValue) -> bool: - """Check that a value is a ScalarValue of type Integer. + """Check that a value is a scalar of type Integer. Args: value_to_check (BaseValue): The value to check Returns: - bool: True if the passed value_to_check is a ScalarValue of type Integer + bool: True if the passed value_to_check is a scalar of type Integer """ - return isinstance(value_to_check, ScalarValue) and isinstance( - value_to_check.dtype, INTEGER_TYPES + return ( + isinstance(value_to_check, TensorValue) + and value_to_check.is_scalar + and isinstance(value_to_check.dtype, INTEGER_TYPES) ) @@ -126,8 +120,10 @@ def value_is_tensor_integer(value_to_check: BaseValue) -> bool: Returns: bool: True if the passed value_to_check is a TensorValue of type Integer """ - return isinstance(value_to_check, TensorValue) and isinstance( - value_to_check.dtype, INTEGER_TYPES + return ( + isinstance(value_to_check, TensorValue) + and not value_to_check.is_scalar + and isinstance(value_to_check.dtype, INTEGER_TYPES) ) @@ -190,43 +186,6 @@ def find_type_to_hold_both_lossy( return type_to_return -def mix_scalar_values_determine_holding_dtype( - value1: ScalarValue, - value2: ScalarValue, -) -> ScalarValue: - """Return mixed ScalarValue with data type able to hold both value1 and value2 dtypes. - - Returns a ScalarValue that would result from computation on both value1 and value2 while - determining the data type able to hold both value1 and value2 data type (this can be lossy - with floats). - - Args: - value1 (ScalarValue): first ScalarValue to mix. - value2 (ScalarValue): second ScalarValue to mix. - - Returns: - ScalarValue: The resulting mixed ScalarValue with data type able to hold both value1 and - value2 dtypes. - """ - - custom_assert( - isinstance(value1, ScalarValue), f"Unsupported value1: {value1}, expected ScalarValue" - ) - custom_assert( - isinstance(value2, ScalarValue), f"Unsupported value2: {value2}, expected ScalarValue" - ) - - holding_type = find_type_to_hold_both_lossy(value1.dtype, value2.dtype) - mixed_value: ScalarValue - - if value1.is_encrypted or value2.is_encrypted: - mixed_value = EncryptedScalar(holding_type) - else: - mixed_value = ClearScalar(holding_type) - - return mixed_value - - def mix_tensor_values_determine_holding_dtype( value1: TensorValue, value2: TensorValue, @@ -284,7 +243,7 @@ def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> value2 (BaseValue): second BaseValue to mix. Raises: - ValueError: raised if the BaseValue is not one of (ScalarValue, TensorValue) + ValueError: raised if the BaseValue is not one of (TensorValue) Returns: BaseValue: The resulting mixed BaseValue with data type able to hold both value1 and value2 @@ -296,8 +255,6 @@ def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> f"Cannot mix values of different types: value 1:{type(value1)}, value2: {type(value2)}", ) - if isinstance(value1, ScalarValue) and isinstance(value2, ScalarValue): - return mix_scalar_values_determine_holding_dtype(value1, value2) if isinstance(value1, TensorValue) and isinstance(value2, TensorValue): return mix_tensor_values_determine_holding_dtype(value1, value2) @@ -306,9 +263,7 @@ def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) -> ) -def get_base_data_type_for_python_constant_data( - constant_data: Union[int, float, List[int], List[float]] -) -> BaseDataType: +def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]) -> BaseDataType: """Determine the BaseDataType to hold the input constant data. Args: @@ -320,28 +275,23 @@ def get_base_data_type_for_python_constant_data( """ constant_data_type: BaseDataType custom_assert( - isinstance(constant_data, (int, float, list)), + isinstance(constant_data, (int, float)), f"Unsupported constant data of type {type(constant_data)}", ) - if isinstance(constant_data, list): - custom_assert(len(constant_data) > 0, "Data type of empty list cannot be detected") - constant_data_type = get_base_data_type_for_python_constant_data(constant_data[0]) - for value in constant_data: - other_data_type = get_base_data_type_for_python_constant_data(value) - constant_data_type = find_type_to_hold_both_lossy(constant_data_type, other_data_type) - elif isinstance(constant_data, int): + if isinstance(constant_data, int): is_signed = constant_data < 0 constant_data_type = Integer( get_bits_to_represent_value_as_integer(constant_data, is_signed), is_signed ) elif isinstance(constant_data, float): constant_data_type = Float(64) + return constant_data_type def get_base_value_for_python_constant_data( - constant_data: Union[int, float, List[int], List[float]] + constant_data: Union[int, float] ) -> Callable[..., BaseValue]: """Wrap the BaseDataType to hold the input constant data in BaseValue partial. @@ -349,8 +299,8 @@ def get_base_value_for_python_constant_data( by calling it with the proper arguments forwarded to the BaseValue `__init__` function Args: - constant_data (Union[int, float, List[int], List[float]]): The constant data - for which to determine the corresponding Value. + constant_data (Union[int, float]): The constant data for which to determine the + corresponding Value. Returns: Callable[..., BaseValue]: A partial object that will return the proper BaseValue when @@ -358,14 +308,8 @@ def get_base_value_for_python_constant_data( method). """ - if isinstance(constant_data, list): - assert len(constant_data) > 0 - constant_shape = (len(constant_data),) - constant_data_type = get_base_data_type_for_python_constant_data(constant_data) - return partial(TensorValue, dtype=constant_data_type, shape=constant_shape) - constant_data_type = get_base_data_type_for_python_constant_data(constant_data) - return partial(ScalarValue, dtype=constant_data_type) + return partial(TensorValue, dtype=constant_data_type, shape=()) def get_type_constructor_for_python_constant_data(constant_data: Union[int, float]): diff --git a/concrete/common/mlir/mlir_converter.py b/concrete/common/mlir/mlir_converter.py index 177c538dd..55af22122 100644 --- a/concrete/common/mlir/mlir_converter.py +++ b/concrete/common/mlir/mlir_converter.py @@ -7,7 +7,6 @@ import zamalang from mlir.dialects import builtin from mlir.ir import Context, InsertionPoint, IntegerType, Location, Module, RankedTensorType from mlir.ir import Type as MLIRType -from mlir.ir import UnrankedTensorType from zamalang.dialects import hlfhe from .. import values @@ -64,10 +63,7 @@ class MLIRConverter: MLIRType: corresponding MLIR type """ element_type = self._get_scalar_integer_type(bit_width, is_encrypted, is_signed) - if len(shape): # randked tensor - return RankedTensorType.get(shape, element_type) - # unranked tensor - return UnrankedTensorType.get(element_type) + return RankedTensorType.get(shape, element_type) def _get_scalar_integer_type( self, bit_width: int, is_encrypted: bool, is_signed: bool diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py index 4ee420e72..da66a20d5 100644 --- a/concrete/common/representation/intermediate.py +++ b/concrete/common/representation/intermediate.py @@ -9,7 +9,7 @@ from loguru import logger from ..data_types.base import BaseDataType from ..data_types.dtypes_helpers import ( get_base_value_for_python_constant_data, - mix_scalar_values_determine_holding_dtype, + mix_values_determine_holding_dtype, ) from ..data_types.integers import Integer from ..debugging.custom_assert import custom_assert @@ -43,7 +43,7 @@ class IntermediateNode(ABC): def _init_binary( self, inputs: Iterable[BaseValue], - mix_values_func: Callable[..., BaseValue] = mix_scalar_values_determine_holding_dtype, + mix_values_func: Callable[..., BaseValue] = mix_values_determine_holding_dtype, **_kwargs, # Required to conform to __init__ typing ) -> None: """__init__ for a binary operation, ie two inputs.""" @@ -221,7 +221,11 @@ class ArbitraryFunction(IntermediateNode): self.arbitrary_func = arbitrary_func self.op_args = op_args if op_args is not None else () self.op_kwargs = op_kwargs if op_kwargs is not None else {} - self.outputs = [input_base_value.__class__(output_dtype, input_base_value.is_encrypted)] + + output = deepcopy(input_base_value) + output.dtype = output_dtype + self.outputs = [output] + self.op_name = op_name if op_name is not None else self.__class__.__name__ def evaluate(self, inputs: Dict[int, Any]) -> Any: diff --git a/concrete/common/values/__init__.py b/concrete/common/values/__init__.py index 34bc45927..4a1e3290c 100644 --- a/concrete/common/values/__init__.py +++ b/concrete/common/values/__init__.py @@ -1,6 +1,5 @@ """Module for value structures.""" -from . import scalars, tensors +from . import tensors from .base import BaseValue -from .scalars import ClearScalar, EncryptedScalar, ScalarValue -from .tensors import ClearTensor, EncryptedTensor, TensorValue +from .tensors import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue diff --git a/concrete/common/values/scalars.py b/concrete/common/values/scalars.py deleted file mode 100644 index 66be4eb26..000000000 --- a/concrete/common/values/scalars.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Module that defines the scalar values in a program.""" - -from typing import Tuple - -from ..data_types.base import BaseDataType -from .base import BaseValue - - -class ScalarValue(BaseValue): - """Class representing a scalar value.""" - - def __eq__(self, other: object) -> bool: - return BaseValue.__eq__(self, other) - - def __str__(self) -> str: # pragma: no cover - encrypted_str = "Encrypted" if self._is_encrypted else "Clear" - return f"{encrypted_str}Scalar<{self.dtype!r}>" - - @property - def shape(self) -> Tuple[int, ...]: - """Return the ScalarValue shape property. - - Returns: - Tuple[int, ...]: The ScalarValue shape which is `()`. - """ - return () - - -def make_clear_scalar(dtype: BaseDataType) -> ScalarValue: - """Create a clear ScalarValue. - - Args: - dtype (BaseDataType): The data type for the value. - - Returns: - ScalarValue: The corresponding ScalarValue. - """ - return ScalarValue(dtype=dtype, is_encrypted=False) - - -def make_encrypted_scalar(dtype: BaseDataType) -> ScalarValue: - """Create an encrypted ScalarValue. - - Args: - dtype (BaseDataType): The data type for the value. - - Returns: - ScalarValue: The corresponding ScalarValue. - """ - return ScalarValue(dtype=dtype, is_encrypted=True) - - -ClearScalar = make_clear_scalar -EncryptedScalar = make_encrypted_scalar diff --git a/concrete/common/values/tensors.py b/concrete/common/values/tensors.py index ded52565e..5a6eb8185 100644 --- a/concrete/common/values/tensors.py +++ b/concrete/common/values/tensors.py @@ -1,7 +1,7 @@ """Module that defines the tensor values in a program.""" from math import prod -from typing import Optional, Tuple +from typing import Tuple from ..data_types.base import BaseDataType from .base import BaseValue @@ -18,13 +18,13 @@ class TensorValue(BaseValue): self, dtype: BaseDataType, is_encrypted: bool, - shape: Optional[Tuple[int, ...]] = None, - ) -> None: + shape: Tuple[int, ...], + ): super().__init__(dtype, is_encrypted) - # Managing tensors as in numpy, no shape or () is treated as a 0-D array of size 1 - self._shape = shape if shape is not None else () + # Managing tensors as in numpy, shape of () means the value is scalar + self._shape = shape self._ndim = len(self._shape) - self._size = prod(self._shape) if self._shape else 1 + self._size = prod(self._shape) if self._shape != () else 1 def __eq__(self, other: object) -> bool: return ( @@ -37,7 +37,9 @@ class TensorValue(BaseValue): def __str__(self) -> str: encrypted_str = "Encrypted" if self._is_encrypted else "Clear" - return f"{encrypted_str}Tensor<{str(self.dtype)}, shape={self.shape}>" + tensor_or_scalar_str = "Scalar" if self.is_scalar else "Tensor" + shape_str = f", shape={self.shape}" if self.shape != () else "" + return f"{encrypted_str}{tensor_or_scalar_str}<{str(self.dtype)}{shape_str}>" @property def shape(self) -> Tuple[int, ...]: @@ -66,10 +68,19 @@ class TensorValue(BaseValue): """ return self._size + @property + def is_scalar(self) -> bool: + """Whether Value is scalar or not. + + Returns: + bool: True if scalar False otherwise + """ + return self.shape == () + def make_clear_tensor( dtype: BaseDataType, - shape: Optional[Tuple[int, ...]] = None, + shape: Tuple[int, ...], ) -> TensorValue: """Create a clear TensorValue. @@ -85,7 +96,7 @@ def make_clear_tensor( def make_encrypted_tensor( dtype: BaseDataType, - shape: Optional[Tuple[int, ...]] = None, + shape: Tuple[int, ...], ) -> TensorValue: """Create an encrypted TensorValue. @@ -101,3 +112,31 @@ def make_encrypted_tensor( ClearTensor = make_clear_tensor EncryptedTensor = make_encrypted_tensor + + +def make_clear_scalar(dtype: BaseDataType) -> TensorValue: + """Create a clear scalar value. + + Args: + dtype (BaseDataType): The data type for the value. + + Returns: + TensorValue: The corresponding TensorValue. + """ + return TensorValue(dtype=dtype, is_encrypted=False, shape=()) + + +def make_encrypted_scalar(dtype: BaseDataType) -> TensorValue: + """Create an encrypted scalar value. + + Args: + dtype (BaseDataType): The data type for the value. + + Returns: + TensorValue: The corresponding TensorValue. + """ + return TensorValue(dtype=dtype, is_encrypted=True, shape=()) + + +ClearScalar = make_clear_scalar +EncryptedScalar = make_encrypted_scalar diff --git a/concrete/numpy/__init__.py b/concrete/numpy/__init__.py index 7e793806f..8adaa8c99 100644 --- a/concrete/numpy/__init__.py +++ b/concrete/numpy/__init__.py @@ -4,13 +4,6 @@ from ..common.compilation import CompilationArtifacts, CompilationConfiguration from ..common.data_types import Float, Float32, Float64, Integer, SignedInteger, UnsignedInteger from ..common.debugging import draw_graph, get_printable_graph from ..common.extensions.table import LookupTable -from ..common.values import ( - ClearScalar, - ClearTensor, - EncryptedScalar, - EncryptedTensor, - ScalarValue, - TensorValue, -) +from ..common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue from .compile import compile_numpy_function, compile_numpy_function_into_op_graph from .tracing import trace_numpy_function diff --git a/concrete/numpy/np_dtypes_helpers.py b/concrete/numpy/np_dtypes_helpers.py index 72d3b52f2..8f61998cc 100644 --- a/concrete/numpy/np_dtypes_helpers.py +++ b/concrete/numpy/np_dtypes_helpers.py @@ -18,7 +18,7 @@ from ..common.data_types.dtypes_helpers import ( from ..common.data_types.floats import Float from ..common.data_types.integers import Integer from ..common.debugging.custom_assert import custom_assert -from ..common.values import BaseValue, ScalarValue, TensorValue +from ..common.values import BaseValue, TensorValue NUMPY_TO_COMMON_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = { numpy.dtype(numpy.int32): Integer(32, is_signed=True), @@ -158,10 +158,15 @@ def get_base_value_for_numpy_or_python_constant_data( with `encrypted` as keyword argument (forwarded to the BaseValue `__init__` method). """ constant_data_value: Callable[..., BaseValue] + custom_assert( + not isinstance(constant_data, list), + "Unsupported constant data of type list " + "(if you meant to use a list as an array, please use numpy.array instead)", + ) custom_assert( isinstance( constant_data, - (int, float, list, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES), + (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES), ), f"Unsupported constant data of type {type(constant_data)}", ) @@ -170,7 +175,7 @@ def get_base_value_for_numpy_or_python_constant_data( if isinstance(constant_data, numpy.ndarray): constant_data_value = partial(TensorValue, dtype=base_dtype, shape=constant_data.shape) elif isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES): - constant_data_value = partial(ScalarValue, dtype=base_dtype) + constant_data_value = partial(TensorValue, dtype=base_dtype, shape=()) else: constant_data_value = get_base_value_for_python_constant_data(constant_data) return constant_data_value diff --git a/tests/common/bounds_measurement/test_inputset_eval.py b/tests/common/bounds_measurement/test_inputset_eval.py index d622e9ee7..d977fce93 100644 --- a/tests/common/bounds_measurement/test_inputset_eval.py +++ b/tests/common/bounds_measurement/test_inputset_eval.py @@ -309,14 +309,21 @@ def test_eval_op_graph_bounds_on_non_conformant_inputset_default(capsys): y = ClearTensor(UnsignedInteger(2), (3,)) inputset = [ - ([2, 1, 3, 1], [1, 2, 1, 1]), - ([3, 3, 3], [3, 3, 5]), + (np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])), + (np.array([3, 3, 3]), np.array([3, 3, 5])), ] op_graph = trace_numpy_function(f, {"x": x, "y": y}) configuration = CompilationConfiguration() - eval_op_graph_bounds_on_inputset(op_graph, inputset, compilation_configuration=configuration) + eval_op_graph_bounds_on_inputset( + op_graph, + inputset, + compilation_configuration=configuration, + min_func=numpy_min_func, + max_func=numpy_max_func, + get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data, + ) captured = capsys.readouterr() assert ( @@ -339,14 +346,21 @@ def test_eval_op_graph_bounds_on_non_conformant_inputset_check_all(capsys): y = ClearTensor(UnsignedInteger(2), (3,)) inputset = [ - ([2, 1, 3, 1], [1, 2, 1, 1]), - ([3, 3, 3], [3, 3, 5]), + (np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])), + (np.array([3, 3, 3]), np.array([3, 3, 5])), ] op_graph = trace_numpy_function(f, {"x": x, "y": y}) configuration = CompilationConfiguration(check_every_input_in_inputset=True) - eval_op_graph_bounds_on_inputset(op_graph, inputset, compilation_configuration=configuration) + eval_op_graph_bounds_on_inputset( + op_graph, + inputset, + compilation_configuration=configuration, + min_func=numpy_min_func, + max_func=numpy_max_func, + get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data, + ) captured = capsys.readouterr() assert ( diff --git a/tests/common/data_types/test_values.py b/tests/common/data_types/test_values.py index 6f18a59ec..25c24f488 100644 --- a/tests/common/data_types/test_values.py +++ b/tests/common/data_types/test_values.py @@ -31,7 +31,6 @@ class DummyDtype(BaseDataType): @pytest.mark.parametrize( "shape,expected_shape,expected_ndim,expected_size", [ - (None, (), 0, 1), ((), (), 0, 1), ((3, 256, 256), (3, 256, 256), 3, 196_608), ((1920, 1080, 3), (1920, 1080, 3), 3, 6_220_800), diff --git a/tests/common/mlir/test_mlir_converter.py b/tests/common/mlir/test_mlir_converter.py index 2ab61b91a..305d879ca 100644 --- a/tests/common/mlir/test_mlir_converter.py +++ b/tests/common/mlir/test_mlir_converter.py @@ -256,7 +256,10 @@ def test_mlir_converter_dot_between_vectors(func, args_dict, args_ranges): result_graph = compile_numpy_function_into_op_graph( func, args_dict, - (([data[0]] * n, [data[1]] * n) for data in datagen(*args_ranges)), + ( + (numpy.array([data[0]] * n), numpy.array([data[1]] * n)) + for data in datagen(*args_ranges) + ), ) converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) mlir_result = converter.convert(result_graph) @@ -289,7 +292,6 @@ def test_concrete_clear_integer_to_mlir_type(is_signed): @pytest.mark.parametrize( "shape", [ - None, (5,), (5, 8), (-1, 5), @@ -315,7 +317,6 @@ def test_concrete_clear_tensor_integer_to_mlir_type(is_signed, shape): @pytest.mark.parametrize( "shape", [ - None, (5,), (5, 8), (-1, 5), diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 5f30b9838..8f141ae90 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -176,7 +176,9 @@ def test_compile_and_run_dot_correctness(size, input_range): def data_gen(input_range, size): for _ in range(1000): low, high = input_range - args = [[random.randint(low, high) for _ in range(size)] for __ in range(2)] + args = [ + numpy.array([random.randint(low, high) for _ in range(size)]) for __ in range(2) + ] yield args @@ -303,7 +305,7 @@ def test_compile_function_with_dot(function, params, shape, ref_graph_str): iter_i = itertools.product(range(0, max_for_ij + 1), repeat=repeat) iter_j = itertools.product(range(0, max_for_ij + 1), repeat=repeat) for prod_i, prod_j in itertools.product(iter_i, iter_j): - yield (list(prod_i), list(prod_j)) + yield numpy.array(prod_i), numpy.array(prod_j) max_for_ij = 3 assert len(shape) == 1 diff --git a/tests/numpy/test_np_dtypes_helpers.py b/tests/numpy/test_np_dtypes_helpers.py index d48180657..a10e2b594 100644 --- a/tests/numpy/test_np_dtypes_helpers.py +++ b/tests/numpy/test_np_dtypes_helpers.py @@ -8,6 +8,7 @@ from concrete.common.data_types.integers import Integer from concrete.numpy.np_dtypes_helpers import ( convert_base_data_type_to_numpy_dtype, convert_numpy_dtype_to_base_data_type, + get_base_value_for_numpy_or_python_constant_data, get_type_constructor_for_numpy_or_python_constant_data, ) @@ -76,3 +77,14 @@ def test_get_type_constructor_for_numpy_or_python_constant_data( assert expected_constructor == get_type_constructor_for_numpy_or_python_constant_data( constant_data ) + + +def test_get_base_value_for_numpy_or_python_constant_data_with_list(): + """Test function for get_base_value_for_numpy_or_python_constant_data called with list""" + + with pytest.raises( + AssertionError, + match="Unsupported constant data of type list " + "\\(if you meant to use a list as an array, please use numpy\\.array instead\\)", + ): + get_base_value_for_numpy_or_python_constant_data([1, 2, 3])