refactor: replace scalars with () shaped tensors, disable python list support in inputset

This commit is contained in:
Umut
2021-09-29 13:22:14 +03:00
parent f97682bd23
commit c47dac833b
13 changed files with 134 additions and 180 deletions

View File

@@ -2,18 +2,10 @@
from copy import deepcopy
from functools import partial
from typing import Callable, List, Union, cast
from typing import Callable, Union, cast
from ..debugging.custom_assert import custom_assert
from ..values import (
BaseValue,
ClearScalar,
ClearTensor,
EncryptedScalar,
EncryptedTensor,
ScalarValue,
TensorValue,
)
from ..values import BaseValue, ClearTensor, EncryptedTensor, TensorValue
from .base import BaseDataType
from .floats import Float
from .integers import Integer, get_bits_to_represent_value_as_integer
@@ -24,25 +16,25 @@ BASE_DATA_TYPES = INTEGER_TYPES + FLOAT_TYPES
def value_is_encrypted_scalar_integer(value_to_check: BaseValue) -> bool:
"""Check that a value is an encrypted ScalarValue of type Integer.
"""Check that a value is an encrypted scalar of type Integer.
Args:
value_to_check (BaseValue): The value to check
Returns:
bool: True if the passed value_to_check is an encrypted ScalarValue of type Integer
bool: True if the passed value_to_check is an encrypted scalar of type Integer
"""
return value_is_scalar_integer(value_to_check) and value_to_check.is_encrypted
def value_is_encrypted_scalar_unsigned_integer(value_to_check: BaseValue) -> bool:
"""Check that a value is an encrypted ScalarValue of type unsigned Integer.
"""Check that a value is an encrypted scalar of type unsigned Integer.
Args:
value_to_check (BaseValue): The value to check
Returns:
bool: True if the passed value_to_check is an encrypted ScalarValue of type Integer and
bool: True if the passed value_to_check is an encrypted scalar of type Integer and
unsigned
"""
return (
@@ -52,28 +44,30 @@ def value_is_encrypted_scalar_unsigned_integer(value_to_check: BaseValue) -> boo
def value_is_clear_scalar_integer(value_to_check: BaseValue) -> bool:
"""Check that a value is a clear ScalarValue of type Integer.
"""Check that a value is a clear scalar of type Integer.
Args:
value_to_check (BaseValue): The value to check
Returns:
bool: True if the passed value_to_check is a clear ScalarValue of type Integer
bool: True if the passed value_to_check is a clear scalar of type Integer
"""
return value_is_scalar_integer(value_to_check) and value_to_check.is_clear
def value_is_scalar_integer(value_to_check: BaseValue) -> bool:
"""Check that a value is a ScalarValue of type Integer.
"""Check that a value is a scalar of type Integer.
Args:
value_to_check (BaseValue): The value to check
Returns:
bool: True if the passed value_to_check is a ScalarValue of type Integer
bool: True if the passed value_to_check is a scalar of type Integer
"""
return isinstance(value_to_check, ScalarValue) and isinstance(
value_to_check.dtype, INTEGER_TYPES
return (
isinstance(value_to_check, TensorValue)
and value_to_check.is_scalar
and isinstance(value_to_check.dtype, INTEGER_TYPES)
)
@@ -126,8 +120,10 @@ def value_is_tensor_integer(value_to_check: BaseValue) -> bool:
Returns:
bool: True if the passed value_to_check is a TensorValue of type Integer
"""
return isinstance(value_to_check, TensorValue) and isinstance(
value_to_check.dtype, INTEGER_TYPES
return (
isinstance(value_to_check, TensorValue)
and not value_to_check.is_scalar
and isinstance(value_to_check.dtype, INTEGER_TYPES)
)
@@ -190,43 +186,6 @@ def find_type_to_hold_both_lossy(
return type_to_return
def mix_scalar_values_determine_holding_dtype(
value1: ScalarValue,
value2: ScalarValue,
) -> ScalarValue:
"""Return mixed ScalarValue with data type able to hold both value1 and value2 dtypes.
Returns a ScalarValue that would result from computation on both value1 and value2 while
determining the data type able to hold both value1 and value2 data type (this can be lossy
with floats).
Args:
value1 (ScalarValue): first ScalarValue to mix.
value2 (ScalarValue): second ScalarValue to mix.
Returns:
ScalarValue: The resulting mixed ScalarValue with data type able to hold both value1 and
value2 dtypes.
"""
custom_assert(
isinstance(value1, ScalarValue), f"Unsupported value1: {value1}, expected ScalarValue"
)
custom_assert(
isinstance(value2, ScalarValue), f"Unsupported value2: {value2}, expected ScalarValue"
)
holding_type = find_type_to_hold_both_lossy(value1.dtype, value2.dtype)
mixed_value: ScalarValue
if value1.is_encrypted or value2.is_encrypted:
mixed_value = EncryptedScalar(holding_type)
else:
mixed_value = ClearScalar(holding_type)
return mixed_value
def mix_tensor_values_determine_holding_dtype(
value1: TensorValue,
value2: TensorValue,
@@ -284,7 +243,7 @@ def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) ->
value2 (BaseValue): second BaseValue to mix.
Raises:
ValueError: raised if the BaseValue is not one of (ScalarValue, TensorValue)
ValueError: raised if the BaseValue is not one of (TensorValue)
Returns:
BaseValue: The resulting mixed BaseValue with data type able to hold both value1 and value2
@@ -296,8 +255,6 @@ def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) ->
f"Cannot mix values of different types: value 1:{type(value1)}, value2: {type(value2)}",
)
if isinstance(value1, ScalarValue) and isinstance(value2, ScalarValue):
return mix_scalar_values_determine_holding_dtype(value1, value2)
if isinstance(value1, TensorValue) and isinstance(value2, TensorValue):
return mix_tensor_values_determine_holding_dtype(value1, value2)
@@ -306,9 +263,7 @@ def mix_values_determine_holding_dtype(value1: BaseValue, value2: BaseValue) ->
)
def get_base_data_type_for_python_constant_data(
constant_data: Union[int, float, List[int], List[float]]
) -> BaseDataType:
def get_base_data_type_for_python_constant_data(constant_data: Union[int, float]) -> BaseDataType:
"""Determine the BaseDataType to hold the input constant data.
Args:
@@ -320,28 +275,23 @@ def get_base_data_type_for_python_constant_data(
"""
constant_data_type: BaseDataType
custom_assert(
isinstance(constant_data, (int, float, list)),
isinstance(constant_data, (int, float)),
f"Unsupported constant data of type {type(constant_data)}",
)
if isinstance(constant_data, list):
custom_assert(len(constant_data) > 0, "Data type of empty list cannot be detected")
constant_data_type = get_base_data_type_for_python_constant_data(constant_data[0])
for value in constant_data:
other_data_type = get_base_data_type_for_python_constant_data(value)
constant_data_type = find_type_to_hold_both_lossy(constant_data_type, other_data_type)
elif isinstance(constant_data, int):
if isinstance(constant_data, int):
is_signed = constant_data < 0
constant_data_type = Integer(
get_bits_to_represent_value_as_integer(constant_data, is_signed), is_signed
)
elif isinstance(constant_data, float):
constant_data_type = Float(64)
return constant_data_type
def get_base_value_for_python_constant_data(
constant_data: Union[int, float, List[int], List[float]]
constant_data: Union[int, float]
) -> Callable[..., BaseValue]:
"""Wrap the BaseDataType to hold the input constant data in BaseValue partial.
@@ -349,8 +299,8 @@ def get_base_value_for_python_constant_data(
by calling it with the proper arguments forwarded to the BaseValue `__init__` function
Args:
constant_data (Union[int, float, List[int], List[float]]): The constant data
for which to determine the corresponding Value.
constant_data (Union[int, float]): The constant data for which to determine the
corresponding Value.
Returns:
Callable[..., BaseValue]: A partial object that will return the proper BaseValue when
@@ -358,14 +308,8 @@ def get_base_value_for_python_constant_data(
method).
"""
if isinstance(constant_data, list):
assert len(constant_data) > 0
constant_shape = (len(constant_data),)
constant_data_type = get_base_data_type_for_python_constant_data(constant_data)
return partial(TensorValue, dtype=constant_data_type, shape=constant_shape)
constant_data_type = get_base_data_type_for_python_constant_data(constant_data)
return partial(ScalarValue, dtype=constant_data_type)
return partial(TensorValue, dtype=constant_data_type, shape=())
def get_type_constructor_for_python_constant_data(constant_data: Union[int, float]):

View File

@@ -7,7 +7,6 @@ import zamalang
from mlir.dialects import builtin
from mlir.ir import Context, InsertionPoint, IntegerType, Location, Module, RankedTensorType
from mlir.ir import Type as MLIRType
from mlir.ir import UnrankedTensorType
from zamalang.dialects import hlfhe
from .. import values
@@ -64,10 +63,7 @@ class MLIRConverter:
MLIRType: corresponding MLIR type
"""
element_type = self._get_scalar_integer_type(bit_width, is_encrypted, is_signed)
if len(shape): # randked tensor
return RankedTensorType.get(shape, element_type)
# unranked tensor
return UnrankedTensorType.get(element_type)
return RankedTensorType.get(shape, element_type)
def _get_scalar_integer_type(
self, bit_width: int, is_encrypted: bool, is_signed: bool

View File

@@ -9,7 +9,7 @@ from loguru import logger
from ..data_types.base import BaseDataType
from ..data_types.dtypes_helpers import (
get_base_value_for_python_constant_data,
mix_scalar_values_determine_holding_dtype,
mix_values_determine_holding_dtype,
)
from ..data_types.integers import Integer
from ..debugging.custom_assert import custom_assert
@@ -43,7 +43,7 @@ class IntermediateNode(ABC):
def _init_binary(
self,
inputs: Iterable[BaseValue],
mix_values_func: Callable[..., BaseValue] = mix_scalar_values_determine_holding_dtype,
mix_values_func: Callable[..., BaseValue] = mix_values_determine_holding_dtype,
**_kwargs, # Required to conform to __init__ typing
) -> None:
"""__init__ for a binary operation, ie two inputs."""
@@ -221,7 +221,11 @@ class ArbitraryFunction(IntermediateNode):
self.arbitrary_func = arbitrary_func
self.op_args = op_args if op_args is not None else ()
self.op_kwargs = op_kwargs if op_kwargs is not None else {}
self.outputs = [input_base_value.__class__(output_dtype, input_base_value.is_encrypted)]
output = deepcopy(input_base_value)
output.dtype = output_dtype
self.outputs = [output]
self.op_name = op_name if op_name is not None else self.__class__.__name__
def evaluate(self, inputs: Dict[int, Any]) -> Any:

View File

@@ -1,6 +1,5 @@
"""Module for value structures."""
from . import scalars, tensors
from . import tensors
from .base import BaseValue
from .scalars import ClearScalar, EncryptedScalar, ScalarValue
from .tensors import ClearTensor, EncryptedTensor, TensorValue
from .tensors import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue

View File

@@ -1,54 +0,0 @@
"""Module that defines the scalar values in a program."""
from typing import Tuple
from ..data_types.base import BaseDataType
from .base import BaseValue
class ScalarValue(BaseValue):
"""Class representing a scalar value."""
def __eq__(self, other: object) -> bool:
return BaseValue.__eq__(self, other)
def __str__(self) -> str: # pragma: no cover
encrypted_str = "Encrypted" if self._is_encrypted else "Clear"
return f"{encrypted_str}Scalar<{self.dtype!r}>"
@property
def shape(self) -> Tuple[int, ...]:
"""Return the ScalarValue shape property.
Returns:
Tuple[int, ...]: The ScalarValue shape which is `()`.
"""
return ()
def make_clear_scalar(dtype: BaseDataType) -> ScalarValue:
"""Create a clear ScalarValue.
Args:
dtype (BaseDataType): The data type for the value.
Returns:
ScalarValue: The corresponding ScalarValue.
"""
return ScalarValue(dtype=dtype, is_encrypted=False)
def make_encrypted_scalar(dtype: BaseDataType) -> ScalarValue:
"""Create an encrypted ScalarValue.
Args:
dtype (BaseDataType): The data type for the value.
Returns:
ScalarValue: The corresponding ScalarValue.
"""
return ScalarValue(dtype=dtype, is_encrypted=True)
ClearScalar = make_clear_scalar
EncryptedScalar = make_encrypted_scalar

View File

@@ -1,7 +1,7 @@
"""Module that defines the tensor values in a program."""
from math import prod
from typing import Optional, Tuple
from typing import Tuple
from ..data_types.base import BaseDataType
from .base import BaseValue
@@ -18,13 +18,13 @@ class TensorValue(BaseValue):
self,
dtype: BaseDataType,
is_encrypted: bool,
shape: Optional[Tuple[int, ...]] = None,
) -> None:
shape: Tuple[int, ...],
):
super().__init__(dtype, is_encrypted)
# Managing tensors as in numpy, no shape or () is treated as a 0-D array of size 1
self._shape = shape if shape is not None else ()
# Managing tensors as in numpy, shape of () means the value is scalar
self._shape = shape
self._ndim = len(self._shape)
self._size = prod(self._shape) if self._shape else 1
self._size = prod(self._shape) if self._shape != () else 1
def __eq__(self, other: object) -> bool:
return (
@@ -37,7 +37,9 @@ class TensorValue(BaseValue):
def __str__(self) -> str:
encrypted_str = "Encrypted" if self._is_encrypted else "Clear"
return f"{encrypted_str}Tensor<{str(self.dtype)}, shape={self.shape}>"
tensor_or_scalar_str = "Scalar" if self.is_scalar else "Tensor"
shape_str = f", shape={self.shape}" if self.shape != () else ""
return f"{encrypted_str}{tensor_or_scalar_str}<{str(self.dtype)}{shape_str}>"
@property
def shape(self) -> Tuple[int, ...]:
@@ -66,10 +68,19 @@ class TensorValue(BaseValue):
"""
return self._size
@property
def is_scalar(self) -> bool:
"""Whether Value is scalar or not.
Returns:
bool: True if scalar False otherwise
"""
return self.shape == ()
def make_clear_tensor(
dtype: BaseDataType,
shape: Optional[Tuple[int, ...]] = None,
shape: Tuple[int, ...],
) -> TensorValue:
"""Create a clear TensorValue.
@@ -85,7 +96,7 @@ def make_clear_tensor(
def make_encrypted_tensor(
dtype: BaseDataType,
shape: Optional[Tuple[int, ...]] = None,
shape: Tuple[int, ...],
) -> TensorValue:
"""Create an encrypted TensorValue.
@@ -101,3 +112,31 @@ def make_encrypted_tensor(
ClearTensor = make_clear_tensor
EncryptedTensor = make_encrypted_tensor
def make_clear_scalar(dtype: BaseDataType) -> TensorValue:
"""Create a clear scalar value.
Args:
dtype (BaseDataType): The data type for the value.
Returns:
TensorValue: The corresponding TensorValue.
"""
return TensorValue(dtype=dtype, is_encrypted=False, shape=())
def make_encrypted_scalar(dtype: BaseDataType) -> TensorValue:
"""Create an encrypted scalar value.
Args:
dtype (BaseDataType): The data type for the value.
Returns:
TensorValue: The corresponding TensorValue.
"""
return TensorValue(dtype=dtype, is_encrypted=True, shape=())
ClearScalar = make_clear_scalar
EncryptedScalar = make_encrypted_scalar

View File

@@ -4,13 +4,6 @@ from ..common.compilation import CompilationArtifacts, CompilationConfiguration
from ..common.data_types import Float, Float32, Float64, Integer, SignedInteger, UnsignedInteger
from ..common.debugging import draw_graph, get_printable_graph
from ..common.extensions.table import LookupTable
from ..common.values import (
ClearScalar,
ClearTensor,
EncryptedScalar,
EncryptedTensor,
ScalarValue,
TensorValue,
)
from ..common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue
from .compile import compile_numpy_function, compile_numpy_function_into_op_graph
from .tracing import trace_numpy_function

View File

@@ -18,7 +18,7 @@ from ..common.data_types.dtypes_helpers import (
from ..common.data_types.floats import Float
from ..common.data_types.integers import Integer
from ..common.debugging.custom_assert import custom_assert
from ..common.values import BaseValue, ScalarValue, TensorValue
from ..common.values import BaseValue, TensorValue
NUMPY_TO_COMMON_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = {
numpy.dtype(numpy.int32): Integer(32, is_signed=True),
@@ -158,10 +158,15 @@ def get_base_value_for_numpy_or_python_constant_data(
with `encrypted` as keyword argument (forwarded to the BaseValue `__init__` method).
"""
constant_data_value: Callable[..., BaseValue]
custom_assert(
not isinstance(constant_data, list),
"Unsupported constant data of type list "
"(if you meant to use a list as an array, please use numpy.array instead)",
)
custom_assert(
isinstance(
constant_data,
(int, float, list, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES),
(int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES),
),
f"Unsupported constant data of type {type(constant_data)}",
)
@@ -170,7 +175,7 @@ def get_base_value_for_numpy_or_python_constant_data(
if isinstance(constant_data, numpy.ndarray):
constant_data_value = partial(TensorValue, dtype=base_dtype, shape=constant_data.shape)
elif isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES):
constant_data_value = partial(ScalarValue, dtype=base_dtype)
constant_data_value = partial(TensorValue, dtype=base_dtype, shape=())
else:
constant_data_value = get_base_value_for_python_constant_data(constant_data)
return constant_data_value

View File

@@ -309,14 +309,21 @@ def test_eval_op_graph_bounds_on_non_conformant_inputset_default(capsys):
y = ClearTensor(UnsignedInteger(2), (3,))
inputset = [
([2, 1, 3, 1], [1, 2, 1, 1]),
([3, 3, 3], [3, 3, 5]),
(np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])),
(np.array([3, 3, 3]), np.array([3, 3, 5])),
]
op_graph = trace_numpy_function(f, {"x": x, "y": y})
configuration = CompilationConfiguration()
eval_op_graph_bounds_on_inputset(op_graph, inputset, compilation_configuration=configuration)
eval_op_graph_bounds_on_inputset(
op_graph,
inputset,
compilation_configuration=configuration,
min_func=numpy_min_func,
max_func=numpy_max_func,
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
)
captured = capsys.readouterr()
assert (
@@ -339,14 +346,21 @@ def test_eval_op_graph_bounds_on_non_conformant_inputset_check_all(capsys):
y = ClearTensor(UnsignedInteger(2), (3,))
inputset = [
([2, 1, 3, 1], [1, 2, 1, 1]),
([3, 3, 3], [3, 3, 5]),
(np.array([2, 1, 3, 1]), np.array([1, 2, 1, 1])),
(np.array([3, 3, 3]), np.array([3, 3, 5])),
]
op_graph = trace_numpy_function(f, {"x": x, "y": y})
configuration = CompilationConfiguration(check_every_input_in_inputset=True)
eval_op_graph_bounds_on_inputset(op_graph, inputset, compilation_configuration=configuration)
eval_op_graph_bounds_on_inputset(
op_graph,
inputset,
compilation_configuration=configuration,
min_func=numpy_min_func,
max_func=numpy_max_func,
get_base_value_for_constant_data_func=get_base_value_for_numpy_or_python_constant_data,
)
captured = capsys.readouterr()
assert (

View File

@@ -31,7 +31,6 @@ class DummyDtype(BaseDataType):
@pytest.mark.parametrize(
"shape,expected_shape,expected_ndim,expected_size",
[
(None, (), 0, 1),
((), (), 0, 1),
((3, 256, 256), (3, 256, 256), 3, 196_608),
((1920, 1080, 3), (1920, 1080, 3), 3, 6_220_800),

View File

@@ -256,7 +256,10 @@ def test_mlir_converter_dot_between_vectors(func, args_dict, args_ranges):
result_graph = compile_numpy_function_into_op_graph(
func,
args_dict,
(([data[0]] * n, [data[1]] * n) for data in datagen(*args_ranges)),
(
(numpy.array([data[0]] * n), numpy.array([data[1]] * n))
for data in datagen(*args_ranges)
),
)
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
mlir_result = converter.convert(result_graph)
@@ -289,7 +292,6 @@ def test_concrete_clear_integer_to_mlir_type(is_signed):
@pytest.mark.parametrize(
"shape",
[
None,
(5,),
(5, 8),
(-1, 5),
@@ -315,7 +317,6 @@ def test_concrete_clear_tensor_integer_to_mlir_type(is_signed, shape):
@pytest.mark.parametrize(
"shape",
[
None,
(5,),
(5, 8),
(-1, 5),

View File

@@ -176,7 +176,9 @@ def test_compile_and_run_dot_correctness(size, input_range):
def data_gen(input_range, size):
for _ in range(1000):
low, high = input_range
args = [[random.randint(low, high) for _ in range(size)] for __ in range(2)]
args = [
numpy.array([random.randint(low, high) for _ in range(size)]) for __ in range(2)
]
yield args
@@ -303,7 +305,7 @@ def test_compile_function_with_dot(function, params, shape, ref_graph_str):
iter_i = itertools.product(range(0, max_for_ij + 1), repeat=repeat)
iter_j = itertools.product(range(0, max_for_ij + 1), repeat=repeat)
for prod_i, prod_j in itertools.product(iter_i, iter_j):
yield (list(prod_i), list(prod_j))
yield numpy.array(prod_i), numpy.array(prod_j)
max_for_ij = 3
assert len(shape) == 1

View File

@@ -8,6 +8,7 @@ from concrete.common.data_types.integers import Integer
from concrete.numpy.np_dtypes_helpers import (
convert_base_data_type_to_numpy_dtype,
convert_numpy_dtype_to_base_data_type,
get_base_value_for_numpy_or_python_constant_data,
get_type_constructor_for_numpy_or_python_constant_data,
)
@@ -76,3 +77,14 @@ def test_get_type_constructor_for_numpy_or_python_constant_data(
assert expected_constructor == get_type_constructor_for_numpy_or_python_constant_data(
constant_data
)
def test_get_base_value_for_numpy_or_python_constant_data_with_list():
"""Test function for get_base_value_for_numpy_or_python_constant_data called with list"""
with pytest.raises(
AssertionError,
match="Unsupported constant data of type list "
"\\(if you meant to use a list as an array, please use numpy\\.array instead\\)",
):
get_base_value_for_numpy_or_python_constant_data([1, 2, 3])