mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
refactor: replace scalars with () shaped tensors, disable python list support in inputset
This commit is contained in:
@@ -2,18 +2,10 @@
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Callable, List, Union, cast
|
||||
from typing import Callable, Union, cast
|
||||
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..values import (
|
||||
BaseValue,
|
||||
ClearScalar,
|
||||
ClearTensor,
|
||||
EncryptedScalar,
|
||||
EncryptedTensor,
|
||||
ScalarValue,
|
||||
TensorValue,
|
||||
)
|
||||
from ..values import BaseValue, ClearTensor, EncryptedTensor, TensorValue
|
||||
from .base import BaseDataType
|
||||
from .floats import Float
|
||||
from .integers import Integer, get_bits_to_represent_value_as_integer
|
||||
@@ -24,25 +16,25 @@ BASE_DATA_TYPES = INTEGER_TYPES + FLOAT_TYPES
|
||||
|
||||
|
||||
def value_is_encrypted_scalar_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is an encrypted ScalarValue of type Integer.
|
||||
"""Check that a value is an encrypted scalar of type Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is an encrypted ScalarValue of type Integer
|
||||
bool: True if the passed value_to_check is an encrypted scalar of type Integer
|
||||
"""
|
||||
return value_is_scalar_integer(value_to_check) and value_to_check.is_encrypted
|
||||
|
||||
|
||||
def value_is_encrypted_scalar_unsigned_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is an encrypted ScalarValue of type unsigned Integer.
|
||||
"""Check that a value is an encrypted scalar of type unsigned Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is an encrypted ScalarValue of type Integer and
|
||||
bool: True if the passed value_to_check is an encrypted scalar of type Integer and
|
||||
unsigned
|
||||
"""
|
||||
return (
|
||||
@@ -52,28 +44,30 @@ def value_is_encrypted_scalar_unsigned_integer(value_to_check: BaseValue) -> boo
|
||||
|
||||
|
||||
def value_is_clear_scalar_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is a clear ScalarValue of type Integer.
|
||||
"""Check that a value is a clear scalar of type Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is a clear ScalarValue of type Integer
|
||||
bool: True if the passed value_to_check is a clear scalar of type Integer
|
||||
"""
|
||||
return value_is_scalar_integer(value_to_check) and value_to_check.is_clear
|
||||
|
||||
|
||||
def value_is_scalar_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is a ScalarValue of type Integer.
|
||||
"""Check that a value is a scalar of type Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is a ScalarValue of type Integer
|
||||
bool: True if the passed value_to_check is a scalar of type Integer
|
||||
"""
|
||||
return isinstance(value_to_check, ScalarValue) and isinstance(
|
||||
value_to_check.dtype, INTEGER_TYPES
|
||||
return (
|
||||
isinstance(value_to_check, TensorValue)
|
||||
and value_to_check.is_scalar
|
||||
and isinstance(value_to_check.dtype, INTEGER_TYPES)
|
||||
)
|
||||
|
||||
|
||||
@@ -126,8 +120,10 @@ def value_is_tensor_integer(value_to_check: BaseValue) -> bool:
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is a TensorValue of type Integer
|
||||
"""
|
||||
return isinstance(value_to_check, TensorValue) and isinstance(
|
||||
value_to_check.dtype, INTEGER_TYPES
|
||||
return (
|
||||
isinstance(value_to_check, TensorValue)
|
||||
and not value_to_check.is_scalar
|
||||
and isinstance(value_to_check.dtype, INTEGER_TYPES)
|
||||
)
|
||||
|
||||
|
||||
@@ -190,43 +186,6 @@ def find_type_to_hold_both_lossy(
|
||||
return type_to_return
|
||||
|
||||
|
||||
def mix_scalar_values_determine_holding_dtype(
|
||||
value1: ScalarValue,
|
||||
value2: ScalarValue,
|
||||
) -> ScalarValue:
|
||||
"""Return mixed ScalarValue with data type able to hold both value1 and value2 dtypes.
|
||||
|
||||
Returns a ScalarValue that would result from computation on both value1 and value2 while
|
||||
determining the data type able to hold both value1 and value2 data type (this can be lossy
|
||||
with floats).
|
||||
|
||||
Args:
|
||||
value1 (ScalarValue): first ScalarValue to mix.
|
||||
value2 (ScalarValue): second ScalarValue to mix.
|
||||
|
||||
Returns:
|
||||
ScalarValue: The resulting mixed ScalarValue with data type able to hold both value1 and
|
||||
value2 dtypes.
|
||||
"""
|
||||
|
||||
custom_assert(
|
||||
isinstance(value1, ScalarValue), f"Unsupported value1: {value1}, expected ScalarValue"
|
||||
)
|
||||
custom_assert(
|
||||
isinstance(value2, ScalarValue), f"Unsupported value2: {value2}, expected ScalarValue"
|
||||
)
|
||||
|
||||
holding_type = find_type_to_hold_both_lossy(value1.dtype, value2.dtype)
|
||||
mixed_value: ScalarValue
|
||||
|
||||
if value1.is_encrypted or value2.is_encrypted:
|
||||
mixed_value = EncryptedScalar(holding_type)
|
||||
else:
|
||||
mixed_value = ClearScalar(holding_type)
|
||||
|
||||
return mixed_value
|
||||
|
||||
|
||||
def mix_tensor_values_determine_holding_dtype(
|
||||
value1: TensorValue,
|
||||
value2: TensorValue,
|
||||
@@ -284,7 +243,7 @@ def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) ->
|
||||
value2 (BaseValue): second BaseValue to mix.
|
||||
|
||||
Raises:
|
||||
ValueError: raised if the BaseValue is not one of (ScalarValue, TensorValue)
|
||||
ValueError: raised if the BaseValue is not one of (TensorValue)
|
||||
|
||||
Returns:
|
||||
BaseValue: The resulting mixed BaseValue with data type able to hold both value1 and value2
|
||||
@@ -296,8 +255,6 @@ def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) ->
|
||||
f"Cannot mix values of different types: value 1:{type(value1)}, value2: {type(value2)}",
|
||||
)
|
||||
|
||||
if isinstance(value1, ScalarValue) and isinstance(value2, ScalarValue):
|
||||
return mix_scalar_values_determine_holding_dtype(value1, value2)
|
||||
if isinstance(value1, TensorValue) and isinstance(value2, TensorValue):
|
||||
return mix_tensor_values_determine_holding_dtype(value1, value2)
|
||||
|
||||
@@ -306,9 +263,7 @@ def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) ->
|
||||
)
|
||||
|
||||
|
||||
def get_base_data_type_for_python_constant_data(
|
||||
constant_data: Union[int, float, List[int], List[float]]
|
||||
) -> BaseDataType:
|
||||
def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]) -> BaseDataType:
|
||||
"""Determine the BaseDataType to hold the input constant data.
|
||||
|
||||
Args:
|
||||
@@ -320,28 +275,23 @@ def get_base_data_type_for_python_constant_data(
|
||||
"""
|
||||
constant_data_type: BaseDataType
|
||||
custom_assert(
|
||||
isinstance(constant_data, (int, float, list)),
|
||||
isinstance(constant_data, (int, float)),
|
||||
f"Unsupported constant data of type {type(constant_data)}",
|
||||
)
|
||||
|
||||
if isinstance(constant_data, list):
|
||||
custom_assert(len(constant_data) > 0, "Data type of empty list cannot be detected")
|
||||
constant_data_type = get_base_data_type_for_python_constant_data(constant_data[0])
|
||||
for value in constant_data:
|
||||
other_data_type = get_base_data_type_for_python_constant_data(value)
|
||||
constant_data_type = find_type_to_hold_both_lossy(constant_data_type, other_data_type)
|
||||
elif isinstance(constant_data, int):
|
||||
if isinstance(constant_data, int):
|
||||
is_signed = constant_data < 0
|
||||
constant_data_type = Integer(
|
||||
get_bits_to_represent_value_as_integer(constant_data, is_signed), is_signed
|
||||
)
|
||||
elif isinstance(constant_data, float):
|
||||
constant_data_type = Float(64)
|
||||
|
||||
return constant_data_type
|
||||
|
||||
|
||||
def get_base_value_for_python_constant_data(
|
||||
constant_data: Union[int, float, List[int], List[float]]
|
||||
constant_data: Union[int, float]
|
||||
) -> Callable[..., BaseValue]:
|
||||
"""Wrap the BaseDataType to hold the input constant data in BaseValue partial.
|
||||
|
||||
@@ -349,8 +299,8 @@ def get_base_value_for_python_constant_data(
|
||||
by calling it with the proper arguments forwarded to the BaseValue `__init__` function
|
||||
|
||||
Args:
|
||||
constant_data (Union[int, float, List[int], List[float]]): The constant data
|
||||
for which to determine the corresponding Value.
|
||||
constant_data (Union[int, float]): The constant data for which to determine the
|
||||
corresponding Value.
|
||||
|
||||
Returns:
|
||||
Callable[..., BaseValue]: A partial object that will return the proper BaseValue when
|
||||
@@ -358,14 +308,8 @@ def get_base_value_for_python_constant_data(
|
||||
method).
|
||||
"""
|
||||
|
||||
if isinstance(constant_data, list):
|
||||
assert len(constant_data) > 0
|
||||
constant_shape = (len(constant_data),)
|
||||
constant_data_type = get_base_data_type_for_python_constant_data(constant_data)
|
||||
return partial(TensorValue, dtype=constant_data_type, shape=constant_shape)
|
||||
|
||||
constant_data_type = get_base_data_type_for_python_constant_data(constant_data)
|
||||
return partial(ScalarValue, dtype=constant_data_type)
|
||||
return partial(TensorValue, dtype=constant_data_type, shape=())
|
||||
|
||||
|
||||
def get_type_constructor_for_python_constant_data(constant_data: Union[int, float]):
|
||||
|
||||
@@ -7,7 +7,6 @@ import zamalang
|
||||
from mlir.dialects import builtin
|
||||
from mlir.ir import Context, InsertionPoint, IntegerType, Location, Module, RankedTensorType
|
||||
from mlir.ir import Type as MLIRType
|
||||
from mlir.ir import UnrankedTensorType
|
||||
from zamalang.dialects import hlfhe
|
||||
|
||||
from .. import values
|
||||
@@ -64,10 +63,7 @@ class MLIRConverter:
|
||||
MLIRType: corresponding MLIR type
|
||||
"""
|
||||
element_type = self._get_scalar_integer_type(bit_width, is_encrypted, is_signed)
|
||||
if len(shape): # randked tensor
|
||||
return RankedTensorType.get(shape, element_type)
|
||||
# unranked tensor
|
||||
return UnrankedTensorType.get(element_type)
|
||||
return RankedTensorType.get(shape, element_type)
|
||||
|
||||
def _get_scalar_integer_type(
|
||||
self, bit_width: int, is_encrypted: bool, is_signed: bool
|
||||
|
||||
@@ -9,7 +9,7 @@ from loguru import logger
|
||||
from ..data_types.base import BaseDataType
|
||||
from ..data_types.dtypes_helpers import (
|
||||
get_base_value_for_python_constant_data,
|
||||
mix_scalar_values_determine_holding_dtype,
|
||||
mix_values_determine_holding_dtype,
|
||||
)
|
||||
from ..data_types.integers import Integer
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
@@ -43,7 +43,7 @@ class IntermediateNode(ABC):
|
||||
def _init_binary(
|
||||
self,
|
||||
inputs: Iterable[BaseValue],
|
||||
mix_values_func: Callable[..., BaseValue] = mix_scalar_values_determine_holding_dtype,
|
||||
mix_values_func: Callable[..., BaseValue] = mix_values_determine_holding_dtype,
|
||||
**_kwargs, # Required to conform to __init__ typing
|
||||
) -> None:
|
||||
"""__init__ for a binary operation, ie two inputs."""
|
||||
@@ -221,7 +221,11 @@ class ArbitraryFunction(IntermediateNode):
|
||||
self.arbitrary_func = arbitrary_func
|
||||
self.op_args = op_args if op_args is not None else ()
|
||||
self.op_kwargs = op_kwargs if op_kwargs is not None else {}
|
||||
self.outputs = [input_base_value.__class__(output_dtype, input_base_value.is_encrypted)]
|
||||
|
||||
output = deepcopy(input_base_value)
|
||||
output.dtype = output_dtype
|
||||
self.outputs = [output]
|
||||
|
||||
self.op_name = op_name if op_name is not None else self.__class__.__name__
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Module for value structures."""
|
||||
|
||||
from . import scalars, tensors
|
||||
from . import tensors
|
||||
from .base import BaseValue
|
||||
from .scalars import ClearScalar, EncryptedScalar, ScalarValue
|
||||
from .tensors import ClearTensor, EncryptedTensor, TensorValue
|
||||
from .tensors import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
"""Module that defines the scalar values in a program."""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from ..data_types.base import BaseDataType
|
||||
from .base import BaseValue
|
||||
|
||||
|
||||
class ScalarValue(BaseValue):
|
||||
"""Class representing a scalar value."""
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return BaseValue.__eq__(self, other)
|
||||
|
||||
def __str__(self) -> str: # pragma: no cover
|
||||
encrypted_str = "Encrypted" if self._is_encrypted else "Clear"
|
||||
return f"{encrypted_str}Scalar<{self.dtype!r}>"
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[int, ...]:
|
||||
"""Return the ScalarValue shape property.
|
||||
|
||||
Returns:
|
||||
Tuple[int, ...]: The ScalarValue shape which is `()`.
|
||||
"""
|
||||
return ()
|
||||
|
||||
|
||||
def make_clear_scalar(dtype: BaseDataType) -> ScalarValue:
|
||||
"""Create a clear ScalarValue.
|
||||
|
||||
Args:
|
||||
dtype (BaseDataType): The data type for the value.
|
||||
|
||||
Returns:
|
||||
ScalarValue: The corresponding ScalarValue.
|
||||
"""
|
||||
return ScalarValue(dtype=dtype, is_encrypted=False)
|
||||
|
||||
|
||||
def make_encrypted_scalar(dtype: BaseDataType) -> ScalarValue:
|
||||
"""Create an encrypted ScalarValue.
|
||||
|
||||
Args:
|
||||
dtype (BaseDataType): The data type for the value.
|
||||
|
||||
Returns:
|
||||
ScalarValue: The corresponding ScalarValue.
|
||||
"""
|
||||
return ScalarValue(dtype=dtype, is_encrypted=True)
|
||||
|
||||
|
||||
ClearScalar = make_clear_scalar
|
||||
EncryptedScalar = make_encrypted_scalar
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Module that defines the tensor values in a program."""
|
||||
|
||||
from math import prod
|
||||
from typing import Optional, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
from ..data_types.base import BaseDataType
|
||||
from .base import BaseValue
|
||||
@@ -18,13 +18,13 @@ class TensorValue(BaseValue):
|
||||
self,
|
||||
dtype: BaseDataType,
|
||||
is_encrypted: bool,
|
||||
shape: Optional[Tuple[int, ...]] = None,
|
||||
) -> None:
|
||||
shape: Tuple[int, ...],
|
||||
):
|
||||
super().__init__(dtype, is_encrypted)
|
||||
# Managing tensors as in numpy, no shape or () is treated as a 0-D array of size 1
|
||||
self._shape = shape if shape is not None else ()
|
||||
# Managing tensors as in numpy, shape of () means the value is scalar
|
||||
self._shape = shape
|
||||
self._ndim = len(self._shape)
|
||||
self._size = prod(self._shape) if self._shape else 1
|
||||
self._size = prod(self._shape) if self._shape != () else 1
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return (
|
||||
@@ -37,7 +37,9 @@ class TensorValue(BaseValue):
|
||||
|
||||
def __str__(self) -> str:
|
||||
encrypted_str = "Encrypted" if self._is_encrypted else "Clear"
|
||||
return f"{encrypted_str}Tensor<{str(self.dtype)}, shape={self.shape}>"
|
||||
tensor_or_scalar_str = "Scalar" if self.is_scalar else "Tensor"
|
||||
shape_str = f", shape={self.shape}" if self.shape != () else ""
|
||||
return f"{encrypted_str}{tensor_or_scalar_str}<{str(self.dtype)}{shape_str}>"
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[int, ...]:
|
||||
@@ -66,10 +68,19 @@ class TensorValue(BaseValue):
|
||||
"""
|
||||
return self._size
|
||||
|
||||
@property
|
||||
def is_scalar(self) -> bool:
|
||||
"""Whether Value is scalar or not.
|
||||
|
||||
Returns:
|
||||
bool: True if scalar False otherwise
|
||||
"""
|
||||
return self.shape == ()
|
||||
|
||||
|
||||
def make_clear_tensor(
|
||||
dtype: BaseDataType,
|
||||
shape: Optional[Tuple[int, ...]] = None,
|
||||
shape: Tuple[int, ...],
|
||||
) -> TensorValue:
|
||||
"""Create a clear TensorValue.
|
||||
|
||||
@@ -85,7 +96,7 @@ def make_clear_tensor(
|
||||
|
||||
def make_encrypted_tensor(
|
||||
dtype: BaseDataType,
|
||||
shape: Optional[Tuple[int, ...]] = None,
|
||||
shape: Tuple[int, ...],
|
||||
) -> TensorValue:
|
||||
"""Create an encrypted TensorValue.
|
||||
|
||||
@@ -101,3 +112,31 @@ def make_encrypted_tensor(
|
||||
|
||||
ClearTensor = make_clear_tensor
|
||||
EncryptedTensor = make_encrypted_tensor
|
||||
|
||||
|
||||
def make_clear_scalar(dtype: BaseDataType) -> TensorValue:
|
||||
"""Create a clear scalar value.
|
||||
|
||||
Args:
|
||||
dtype (BaseDataType): The data type for the value.
|
||||
|
||||
Returns:
|
||||
TensorValue: The corresponding TensorValue.
|
||||
"""
|
||||
return TensorValue(dtype=dtype, is_encrypted=False, shape=())
|
||||
|
||||
|
||||
def make_encrypted_scalar(dtype: BaseDataType) -> TensorValue:
|
||||
"""Create an encrypted scalar value.
|
||||
|
||||
Args:
|
||||
dtype (BaseDataType): The data type for the value.
|
||||
|
||||
Returns:
|
||||
TensorValue: The corresponding TensorValue.
|
||||
"""
|
||||
return TensorValue(dtype=dtype, is_encrypted=True, shape=())
|
||||
|
||||
|
||||
ClearScalar = make_clear_scalar
|
||||
EncryptedScalar = make_encrypted_scalar
|
||||
|
||||
@@ -4,13 +4,6 @@ from ..common.compilation import CompilationArtifacts, CompilationConfiguration
|
||||
from ..common.data_types import Float, Float32, Float64, Integer, SignedInteger, UnsignedInteger
|
||||
from ..common.debugging import draw_graph, get_printable_graph
|
||||
from ..common.extensions.table import LookupTable
|
||||
from ..common.values import (
|
||||
ClearScalar,
|
||||
ClearTensor,
|
||||
EncryptedScalar,
|
||||
EncryptedTensor,
|
||||
ScalarValue,
|
||||
TensorValue,
|
||||
)
|
||||
from ..common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue
|
||||
from .compile import compile_numpy_function, compile_numpy_function_into_op_graph
|
||||
from .tracing import trace_numpy_function
|
||||
|
||||
@@ -18,7 +18,7 @@ from ..common.data_types.dtypes_helpers import (
|
||||
from ..common.data_types.floats import Float
|
||||
from ..common.data_types.integers import Integer
|
||||
from ..common.debugging.custom_assert import custom_assert
|
||||
from ..common.values import BaseValue, ScalarValue, TensorValue
|
||||
from ..common.values import BaseValue, TensorValue
|
||||
|
||||
NUMPY_TO_COMMON_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = {
|
||||
numpy.dtype(numpy.int32): Integer(32, is_signed=True),
|
||||
@@ -158,10 +158,15 @@ def get_base_value_for_numpy_or_python_constant_data(
|
||||
with `encrypted` as keyword argument (forwarded to the BaseValue `__init__` method).
|
||||
"""
|
||||
constant_data_value: Callable[..., BaseValue]
|
||||
custom_assert(
|
||||
not isinstance(constant_data, list),
|
||||
"Unsupported constant data of type list "
|
||||
"(if you meant to use a list as an array, please use numpy.array instead)",
|
||||
)
|
||||
custom_assert(
|
||||
isinstance(
|
||||
constant_data,
|
||||
(int, float, list, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES),
|
||||
(int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES),
|
||||
),
|
||||
f"Unsupported constant data of type {type(constant_data)}",
|
||||
)
|
||||
@@ -170,7 +175,7 @@ def get_base_value_for_numpy_or_python_constant_data(
|
||||
if isinstance(constant_data, numpy.ndarray):
|
||||
constant_data_value = partial(TensorValue, dtype=base_dtype, shape=constant_data.shape)
|
||||
elif isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES):
|
||||
constant_data_value = partial(ScalarValue, dtype=base_dtype)
|
||||
constant_data_value = partial(TensorValue, dtype=base_dtype, shape=())
|
||||
else:
|
||||
constant_data_value = get_base_value_for_python_constant_data(constant_data)
|
||||
return constant_data_value
|
||||
|
||||
@@ -309,14 +309,21 @@ def test_eval_op_graph_bounds_on_non_conformant_inputset_default(capsys):
|
||||
y = ClearTensor(UnsignedInteger(2), (3,))
|
||||
|
||||
inputset = [
|
||||
([2, 1, 3, 1], [1, 2, 1, 1]),
|
||||
([3, 3, 3], [3, 3, 5]),
|
||||
(np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])),
|
||||
(np.array([3, 3, 3]), np.array([3, 3, 5])),
|
||||
]
|
||||
|
||||
op_graph = trace_numpy_function(f, {"x": x, "y": y})
|
||||
|
||||
configuration = CompilationConfiguration()
|
||||
eval_op_graph_bounds_on_inputset(op_graph, inputset, compilation_configuration=configuration)
|
||||
eval_op_graph_bounds_on_inputset(
|
||||
op_graph,
|
||||
inputset,
|
||||
compilation_configuration=configuration,
|
||||
min_func=numpy_min_func,
|
||||
max_func=numpy_max_func,
|
||||
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert (
|
||||
@@ -339,14 +346,21 @@ def test_eval_op_graph_bounds_on_non_conformant_inputset_check_all(capsys):
|
||||
y = ClearTensor(UnsignedInteger(2), (3,))
|
||||
|
||||
inputset = [
|
||||
([2, 1, 3, 1], [1, 2, 1, 1]),
|
||||
([3, 3, 3], [3, 3, 5]),
|
||||
(np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])),
|
||||
(np.array([3, 3, 3]), np.array([3, 3, 5])),
|
||||
]
|
||||
|
||||
op_graph = trace_numpy_function(f, {"x": x, "y": y})
|
||||
|
||||
configuration = CompilationConfiguration(check_every_input_in_inputset=True)
|
||||
eval_op_graph_bounds_on_inputset(op_graph, inputset, compilation_configuration=configuration)
|
||||
eval_op_graph_bounds_on_inputset(
|
||||
op_graph,
|
||||
inputset,
|
||||
compilation_configuration=configuration,
|
||||
min_func=numpy_min_func,
|
||||
max_func=numpy_max_func,
|
||||
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert (
|
||||
|
||||
@@ -31,7 +31,6 @@ class DummyDtype(BaseDataType):
|
||||
@pytest.mark.parametrize(
|
||||
"shape,expected_shape,expected_ndim,expected_size",
|
||||
[
|
||||
(None, (), 0, 1),
|
||||
((), (), 0, 1),
|
||||
((3, 256, 256), (3, 256, 256), 3, 196_608),
|
||||
((1920, 1080, 3), (1920, 1080, 3), 3, 6_220_800),
|
||||
|
||||
@@ -256,7 +256,10 @@ def test_mlir_converter_dot_between_vectors(func, args_dict, args_ranges):
|
||||
result_graph = compile_numpy_function_into_op_graph(
|
||||
func,
|
||||
args_dict,
|
||||
(([data[0]] * n, [data[1]] * n) for data in datagen(*args_ranges)),
|
||||
(
|
||||
(numpy.array([data[0]] * n), numpy.array([data[1]] * n))
|
||||
for data in datagen(*args_ranges)
|
||||
),
|
||||
)
|
||||
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
mlir_result = converter.convert(result_graph)
|
||||
@@ -289,7 +292,6 @@ def test_concrete_clear_integer_to_mlir_type(is_signed):
|
||||
@pytest.mark.parametrize(
|
||||
"shape",
|
||||
[
|
||||
None,
|
||||
(5,),
|
||||
(5, 8),
|
||||
(-1, 5),
|
||||
@@ -315,7 +317,6 @@ def test_concrete_clear_tensor_integer_to_mlir_type(is_signed, shape):
|
||||
@pytest.mark.parametrize(
|
||||
"shape",
|
||||
[
|
||||
None,
|
||||
(5,),
|
||||
(5, 8),
|
||||
(-1, 5),
|
||||
|
||||
@@ -176,7 +176,9 @@ def test_compile_and_run_dot_correctness(size, input_range):
|
||||
def data_gen(input_range, size):
|
||||
for _ in range(1000):
|
||||
low, high = input_range
|
||||
args = [[random.randint(low, high) for _ in range(size)] for __ in range(2)]
|
||||
args = [
|
||||
numpy.array([random.randint(low, high) for _ in range(size)]) for __ in range(2)
|
||||
]
|
||||
|
||||
yield args
|
||||
|
||||
@@ -303,7 +305,7 @@ def test_compile_function_with_dot(function, params, shape, ref_graph_str):
|
||||
iter_i = itertools.product(range(0, max_for_ij + 1), repeat=repeat)
|
||||
iter_j = itertools.product(range(0, max_for_ij + 1), repeat=repeat)
|
||||
for prod_i, prod_j in itertools.product(iter_i, iter_j):
|
||||
yield (list(prod_i), list(prod_j))
|
||||
yield numpy.array(prod_i), numpy.array(prod_j)
|
||||
|
||||
max_for_ij = 3
|
||||
assert len(shape) == 1
|
||||
|
||||
@@ -8,6 +8,7 @@ from concrete.common.data_types.integers import Integer
|
||||
from concrete.numpy.np_dtypes_helpers import (
|
||||
convert_base_data_type_to_numpy_dtype,
|
||||
convert_numpy_dtype_to_base_data_type,
|
||||
get_base_value_for_numpy_or_python_constant_data,
|
||||
get_type_constructor_for_numpy_or_python_constant_data,
|
||||
)
|
||||
|
||||
@@ -76,3 +77,14 @@ def test_get_type_constructor_for_numpy_or_python_constant_data(
|
||||
assert expected_constructor == get_type_constructor_for_numpy_or_python_constant_data(
|
||||
constant_data
|
||||
)
|
||||
|
||||
|
||||
def test_get_base_value_for_numpy_or_python_constant_data_with_list():
|
||||
"""Test function for get_base_value_for_numpy_or_python_constant_data called with list"""
|
||||
|
||||
with pytest.raises(
|
||||
AssertionError,
|
||||
match="Unsupported constant data of type list "
|
||||
"\\(if you meant to use a list as an array, please use numpy\\.array instead\\)",
|
||||
):
|
||||
get_base_value_for_numpy_or_python_constant_data([1, 2, 3])
|
||||
|
||||
Reference in New Issue
Block a user