refactor(frontend-python): rename Value to ValueDescription

This commit is contained in:
Umut
2023-06-09 14:16:45 +02:00
parent 9c077852bb
commit 8a3e24d204
21 changed files with 167 additions and 145 deletions

View File

@@ -15,7 +15,7 @@ from concrete.compiler import EvaluationKeys, ValueDecrypter, ValueExporter
from ..dtypes.integer import SignedInteger, UnsignedInteger
from ..internal.utils import assert_that
from ..values.value import Value
from ..values import ValueDescription
from .data import Data
from .keys import Keys
from .specs import ClientSpecs
@@ -160,7 +160,7 @@ class Client:
is_encrypted = spec["encryption"] is not None
expected_dtype = SignedInteger(width) if is_signed else UnsignedInteger(width)
expected_value = Value(expected_dtype, shape, is_encrypted)
expected_value = ValueDescription(expected_dtype, shape, is_encrypted)
if is_valid:
expected_min = expected_dtype.min()
expected_max = expected_dtype.max()
@@ -184,7 +184,7 @@ class Client:
sanitized_args[index] = arg
if not is_valid:
actual_value = Value.of(arg, is_encrypted=is_encrypted)
actual_value = ValueDescription.of(arg, is_encrypted=is_encrypted)
message = (
f"Expected argument {index} to be {expected_value} but it's {actual_value}"
)

View File

@@ -15,7 +15,7 @@ from ..extensions import AutoRounder
from ..mlir import GraphConverter
from ..representation import Graph
from ..tracing import Tracer
from ..values import Value
from ..values import ValueDescription
from .artifacts import DebugArtifacts
from .circuit import Circuit
from .configuration import Configuration
@@ -47,12 +47,12 @@ class Compiler:
graph: Optional[Graph]
_is_direct: bool
_parameter_values: Dict[str, Value]
_parameter_values: Dict[str, ValueDescription]
@staticmethod
def assemble(
function: Callable,
parameter_values: Dict[str, Value],
parameter_values: Dict[str, ValueDescription],
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
**kwargs,
@@ -64,7 +64,7 @@ class Compiler:
function (Callable):
function to convert to a circuit
parameter_values (Dict[str, Value]):
parameter_values (Dict[str, ValueDescription]):
parameter values of the function
configuration(Optional[Configuration], default = None):
@@ -197,7 +197,7 @@ class Compiler:
self.artifacts.add_parameter_encryption_status(param, encryption_status)
parameters = {
param: Value.of(arg, is_encrypted=(status == EncryptionStatus.ENCRYPTED))
param: ValueDescription.of(arg, is_encrypted=(status == EncryptionStatus.ENCRYPTED))
for arg, (param, status) in zip(
sample if len(self.parameter_encryption_statuses) > 1 else (sample,),
self.parameter_encryption_statuses.items(),

View File

@@ -7,7 +7,7 @@ from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Unio
from ..representation import Graph
from ..tracing.typing import ScalarAnnotation
from ..values import Value
from ..values import ValueDescription
from .artifacts import DebugArtifacts
from .circuit import Circuit
from .compiler import Compiler, EncryptionStatus
@@ -40,14 +40,14 @@ def circuit(
def decoration(function: Callable):
signature = inspect.signature(function)
parameter_values: Dict[str, Value] = {}
parameter_values: Dict[str, ValueDescription] = {}
for name, details in signature.parameters.items():
if name not in parameters:
continue
annotation = details.annotation
is_value = isinstance(annotation, Value)
is_value = isinstance(annotation, ValueDescription)
is_scalar_annotation = isinstance(annotation, type) and issubclass(
annotation, ScalarAnnotation
)
@@ -61,7 +61,9 @@ def circuit(
raise ValueError(message)
parameter_values[name] = (
annotation if is_value else Value(annotation.dtype, shape=(), is_encrypted=False)
annotation
if is_value
else ValueDescription(annotation.dtype, shape=(), is_encrypted=False)
)
status = EncryptionStatus(parameters[name].lower())

View File

@@ -10,7 +10,7 @@ import numpy as np
from ..dtypes.utils import combine_dtypes
from ..representation import Node
from ..tracing import Tracer
from ..values import Value
from ..values import ValueDescription
def array(values: Any) -> Union[np.ndarray, Tracer]:
@@ -54,7 +54,7 @@ def array(values: Any) -> Union[np.ndarray, Tracer]:
computation = Node.generic(
"array",
[deepcopy(value.output) for value in values],
Value(dtype, shape, is_encrypted),
ValueDescription(dtype, shape, is_encrypted),
lambda *args: np.array(args).reshape(shape),
)
return Tracer(computation, values)

View File

@@ -11,7 +11,7 @@ import torch
from ..internal.utils import assert_that
from ..representation import Node
from ..tracing import Tracer
from ..values import Value
from ..values import ValueDescription
# pylint: disable=too-many-branches,too-many-statements
@@ -284,7 +284,7 @@ def _trace_or_evaluate(
return _evaluate(x, kernel_shape, strides, pads, dilations, ceil_mode == 1)
result = _evaluate(np.zeros(x.shape), kernel_shape, strides, pads, dilations, ceil_mode == 1)
resulting_value = Value.of(result)
resulting_value = ValueDescription.of(result)
resulting_value.is_encrypted = x.output.is_encrypted
resulting_value.dtype = x.output.dtype

View File

@@ -8,7 +8,7 @@ import numpy as np
from ..representation import Node
from ..tracing import Tracer
from ..values import Value
from ..values import ValueDescription
def ones(shape: Union[int, Tuple[int, ...]]) -> Union[np.ndarray, Tracer]:
@@ -35,7 +35,7 @@ def ones(shape: Union[int, Tuple[int, ...]]) -> Union[np.ndarray, Tracer]:
computation = Node.generic(
"ones",
[],
Value.of(numpy_ones, is_encrypted=True),
ValueDescription.of(numpy_ones, is_encrypted=True),
lambda: np.ones(shape, dtype=np.int64),
)
return Tracer(computation, [])

View File

@@ -12,7 +12,7 @@ from ..dtypes import Integer
from ..mlir.utils import MAXIMUM_TLU_BIT_WIDTH
from ..representation import Node
from ..tracing import Tracer
from ..values import Value
from ..values import ValueDescription
local = threading.local()
@@ -103,7 +103,7 @@ class AutoRounder:
rounder.input_min = min(rounder.input_min, adjuster.input_min)
rounder.input_max = max(rounder.input_max, adjuster.input_max)
input_value = Value.of([rounder.input_min, rounder.input_max])
input_value = ValueDescription.of([rounder.input_min, rounder.input_max])
assert isinstance(input_value.dtype, Integer)
rounder.input_bit_width = input_value.dtype.bit_width

View File

@@ -10,7 +10,7 @@ import numpy as np
from ..dtypes import BaseDataType, Float
from ..representation import Node
from ..tracing import ScalarAnnotation, Tracer
from ..values import Value
from ..values import ValueDescription
def univariate(
@@ -55,7 +55,7 @@ def univariate(
sample = dtype(1) if x.output.is_scalar else np.ones(x.output.shape, dtype=dtype)
evaluation = function(sample)
output_value = Value.of(evaluation, is_encrypted=x.output.is_encrypted)
output_value = ValueDescription.of(evaluation, is_encrypted=x.output.is_encrypted)
if output_value.shape != x.output.shape:
message = f"Function {function.__name__} cannot be used with fhe.univariate"
raise ValueError(message)

View File

@@ -8,7 +8,7 @@ import numpy as np
from ..representation import Node
from ..tracing import Tracer
from ..values import Value
from ..values import ValueDescription
def zeros(shape: Union[int, Tuple[int, ...]]) -> Union[np.ndarray, Tracer]:
@@ -35,7 +35,7 @@ def zeros(shape: Union[int, Tuple[int, ...]]) -> Union[np.ndarray, Tracer]:
computation = Node.generic(
"zeros",
[],
Value.of(numpy_zeros, is_encrypted=True),
ValueDescription.of(numpy_zeros, is_encrypted=True),
lambda: np.zeros(shape, dtype=np.int64),
)
return Tracer(computation, [])

View File

@@ -25,7 +25,7 @@ from mlir.ir import Type as MlirType
from ..dtypes import Integer
from ..representation import Graph, Node
from ..values import Value
from ..values import ValueDescription
from .conversion import Conversion, ConversionType
from .processors import GraphProcessor
from .utils import MAXIMUM_TLU_BIT_WIDTH, _FromElementsOp
@@ -90,7 +90,7 @@ class Context:
"""
return ConversionType(RankedTensorType.get(shape, element_type.mlir))
def typeof(self, value: Union[Value, Node]) -> ConversionType:
def typeof(self, value: Union[ValueDescription, Node]) -> ConversionType:
"""
Get type corresponding to a value or a node.
"""
@@ -363,7 +363,7 @@ class Context:
continue
encrypted_element_type = self.typeof(
Value(
ValueDescription(
dtype=Integer(
is_signed=resulting_type.is_signed,
bit_width=resulting_type.bit_width,
@@ -438,7 +438,7 @@ class Context:
if x.is_encrypted and y.is_clear:
encrypted_type = self.typeof(
Value(
ValueDescription(
dtype=Integer(is_signed=x.is_signed, bit_width=x.bit_width),
shape=y.shape,
is_encrypted=True,
@@ -525,7 +525,7 @@ class Context:
if x.original_bit_width + y.original_bit_width <= resulting_type.bit_width:
shifter_type = self.typeof(
Value(
ValueDescription(
dtype=Integer(is_signed=False, bit_width=(x.bit_width + 1)),
shape=(),
is_encrypted=False,
@@ -611,7 +611,7 @@ class Context:
return x
resulting_type = self.typeof(
Value(
ValueDescription(
dtype=Integer(is_signed=x.is_signed, bit_width=x.bit_width),
shape=shape,
is_encrypted=x.is_encrypted,
@@ -637,7 +637,7 @@ class Context:
for x in xs:
if x.is_clear:
encrypted_type = self.typeof(
Value(
ValueDescription(
dtype=Integer(
is_signed=resulting_type.is_signed,
bit_width=resulting_type.bit_width,
@@ -805,7 +805,7 @@ class Context:
y = self.to_signed(y)
signed_resulting_type = self.typeof(
Value(
ValueDescription(
dtype=Integer(is_signed=True, bit_width=resulting_type.bit_width),
shape=resulting_type.shape,
is_encrypted=resulting_type.is_encrypted,
@@ -1058,7 +1058,7 @@ class Context:
intermediate_shape.insert(dimension, 1)
intermediate_type = self.typeof(
Value(
ValueDescription(
dtype=Integer(is_signed=x.is_signed, bit_width=x.bit_width),
shape=tuple(intermediate_shape),
is_encrypted=x.is_encrypted,
@@ -1373,7 +1373,7 @@ class Context:
y = self.to_signed(y)
signed_resulting_type = self.typeof(
Value(
ValueDescription(
dtype=Integer(is_signed=True, bit_width=resulting_type.bit_width),
shape=resulting_type.shape,
is_encrypted=resulting_type.is_encrypted,
@@ -1425,7 +1425,7 @@ class Context:
)
same_signedness_resulting_type = self.typeof(
Value(
ValueDescription(
dtype=Integer(is_signed=x.is_signed, bit_width=x.bit_width),
shape=resulting_type.shape,
is_encrypted=True,
@@ -1495,7 +1495,7 @@ class Context:
y = self.to_signed(y)
signed_resulting_type = self.typeof(
Value(
ValueDescription(
dtype=Integer(is_signed=True, bit_width=resulting_type.bit_width),
shape=resulting_type.shape,
is_encrypted=resulting_type.is_encrypted,
@@ -1584,7 +1584,7 @@ class Context:
assert resulting_type.is_encrypted
one_type = self.typeof(
Value(
ValueDescription(
dtype=Integer(is_signed=False, bit_width=(resulting_type.bit_width + 1)),
shape=(),
is_encrypted=False,
@@ -1605,7 +1605,7 @@ class Context:
return x
resulting_type = self.typeof(
Value(
ValueDescription(
dtype=Integer(is_signed=x.is_signed, bit_width=x.bit_width),
shape=output_shape,
is_encrypted=x.is_encrypted,
@@ -1699,7 +1699,7 @@ class Context:
)
flattened_type = self.typeof(
Value(
ValueDescription(
dtype=Integer(is_signed=x.is_signed, bit_width=x.bit_width),
shape=(int(np.prod(input_shape)),),
is_encrypted=x.is_encrypted,
@@ -1742,7 +1742,7 @@ class Context:
assert x.bit_width > lsbs_to_remove
resulting_type = self.typeof(
Value(
ValueDescription(
dtype=Integer(is_signed=x.is_signed, bit_width=(x.bit_width - lsbs_to_remove)),
shape=x.shape,
is_encrypted=x.is_encrypted,
@@ -1814,7 +1814,7 @@ class Context:
if x.original_bit_width + b.original_bit_width <= bit_width:
shift_multiplier_type = self.typeof(
Value(
ValueDescription(
dtype=Integer(is_signed=False, bit_width=(x.bit_width + 1)),
shape=(),
is_encrypted=False,
@@ -1979,7 +1979,7 @@ class Context:
axes[i] += input_dimensions
same_signedness_resulting_type = self.typeof(
Value(
ValueDescription(
dtype=Integer(is_signed=x.is_signed, bit_width=x.bit_width),
shape=resulting_type.shape,
is_encrypted=True,
@@ -2003,7 +2003,7 @@ class Context:
return x
resulting_type = self.typeof(
Value(
ValueDescription(
dtype=Integer(is_signed=x.is_signed, bit_width=x.bit_width),
shape=(1,),
is_encrypted=x.is_encrypted,
@@ -2039,7 +2039,7 @@ class Context:
return x
resulting_type = self.typeof(
Value(
ValueDescription(
dtype=Integer(is_signed=True, bit_width=x.bit_width),
shape=x.shape,
is_encrypted=True,
@@ -2064,7 +2064,7 @@ class Context:
return x
resulting_type = self.typeof(
Value(
ValueDescription(
dtype=Integer(is_signed=False, bit_width=x.bit_width),
shape=x.shape,
is_encrypted=True,

View File

@@ -476,7 +476,7 @@ class Graph:
def update_with_bounds(self, bounds: Dict[Node, Dict[str, Union[np.integer, np.floating]]]):
"""
Update `Value`s within the `Graph` according to measured bounds.
Update `ValueDescription`s within the `Graph` according to measured bounds.
Args:
bounds (Dict[Node, Dict[str, Union[np.integer, np.floating]]]):

View File

@@ -11,7 +11,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
from ..internal.utils import assert_that
from ..values import Value
from ..values import ValueDescription
from .evaluator import ConstantEvaluator, GenericEvaluator, GenericTupleEvaluator, InputEvaluator
from .operation import Operation
from .utils import KWARGS_IGNORED_IN_FORMATTING, format_constant, format_indexing_element
@@ -22,8 +22,8 @@ class Node:
Node class, to represent computation in a computation graph.
"""
inputs: List[Value]
output: Value
inputs: List[ValueDescription]
output: ValueDescription
operation: Operation
evaluator: Callable
@@ -54,7 +54,7 @@ class Node:
"""
try:
value = Value.of(constant)
value = ValueDescription.of(constant)
except Exception as error:
message = f"Constant {repr(constant)} is not supported"
raise ValueError(message) from error
@@ -65,8 +65,8 @@ class Node:
@staticmethod
def generic(
name: str,
inputs: List[Value],
output: Value,
inputs: List[ValueDescription],
output: ValueDescription,
operation: Callable,
args: Optional[Tuple[Any, ...]] = None,
kwargs: Optional[Dict[str, Any]] = None,
@@ -79,10 +79,10 @@ class Node:
name (str):
name of the operation
inputs (List[Value]):
inputs (List[ValueDescription]):
inputs to the operation
output (Value):
output (ValueDescription):
output of the operation
operation (Callable):
@@ -122,7 +122,7 @@ class Node:
)
@staticmethod
def input(name: str, value: Value) -> "Node":
def input(name: str, value: ValueDescription) -> "Node":
"""
Create an Operation.Input node.
@@ -142,8 +142,8 @@ class Node:
def __init__(
self,
inputs: List[Value],
output: Value,
inputs: List[ValueDescription],
output: ValueDescription,
operation: Operation,
evaluator: Callable,
properties: Optional[Dict[str, Any]] = None,
@@ -198,7 +198,7 @@ class Node:
for arg, input_ in zip(args, self.inputs):
try:
arg_value = Value.of(arg)
arg_value = ValueDescription.of(arg)
except Exception as error:
arg_str = "the argument" if len(args) == 1 else f"argument {repr(arg)}"
message = f"{generic_error_message()} failed because {arg_str} is not valid"

View File

@@ -14,7 +14,7 @@ from ..dtypes import BaseDataType, Float, Integer
from ..internal.utils import assert_that
from ..representation import Graph, Node, Operation
from ..representation.utils import format_indexing_element
from ..values import Value
from ..values import ValueDescription
class Tracer:
@@ -24,7 +24,7 @@ class Tracer:
computation: Node
input_tracers: List["Tracer"]
output: Value
output: ValueDescription
# property to keep track of assignments
last_version: Optional["Tracer"] = None
@@ -34,7 +34,9 @@ class Tracer:
_is_direct: bool = False
@staticmethod
def trace(function: Callable, parameters: Dict[str, Value], is_direct: bool = False) -> Graph:
def trace(
function: Callable, parameters: Dict[str, ValueDescription], is_direct: bool = False
) -> Graph:
"""
Trace `function` and create the `Graph` that represents it.
@@ -42,7 +44,7 @@ class Tracer:
function (Callable):
function to trace
parameters (Dict[str, Value]):
parameters (Dict[str, ValueDescription]):
parameters of function to trace
e.g. parameter x is an EncryptedScalar holding a 7-bit UnsignedInteger
@@ -414,7 +416,7 @@ class Tracer:
for arg in args:
extract_tracers(arg, tracers)
output_value = Value.of(evaluation)
output_value = ValueDescription.of(evaluation)
output_value.is_encrypted = any(tracer.output.is_encrypted for tracer in tracers)
if Tracer._is_direct and isinstance(output_value.dtype, Integer):
@@ -646,7 +648,7 @@ class Tracer:
)
output_value = deepcopy(self.output)
output_value.dtype = Value.of(dtype(0)).dtype # type: ignore
output_value.dtype = ValueDescription.of(dtype(0)).dtype # type: ignore
if np.issubdtype(dtype, np.integer):

View File

@@ -5,7 +5,7 @@ Declaration of type annotation.
from typing import Any
from ..dtypes import Float, SignedInteger, UnsignedInteger
from ..values import Value
from ..values import ValueDescription
from .tracer import ScalarAnnotation, TensorAnnotation
# pylint: disable=function-redefined,invalid-name,too-many-lines,using-constant-test
@@ -1219,4 +1219,4 @@ class tensor(TensorAnnotation): # type: ignore
if not all(isinstance(x, int) for x in shape):
raise ValueError("Tensor annotation shape elements must be 'int'")
return Value(dtype=annotation.dtype, shape=shape, is_encrypted=False)
return ValueDescription(dtype=annotation.dtype, shape=shape, is_encrypted=False)

View File

@@ -4,4 +4,4 @@ Define the available values and their semantics.
from .scalar import ClearScalar, EncryptedScalar
from .tensor import ClearTensor, EncryptedTensor
from .value import Value
from .value_description import ValueDescription

View File

@@ -3,10 +3,10 @@ Declaration of `ClearScalar` and `EncryptedScalar` wrappers.
"""
from ..dtypes import BaseDataType
from .value import Value
from .value_description import ValueDescription
def clear_scalar_builder(dtype: BaseDataType) -> Value:
def clear_scalar_builder(dtype: BaseDataType) -> ValueDescription:
"""
Build a clear scalar value.
@@ -15,17 +15,17 @@ def clear_scalar_builder(dtype: BaseDataType) -> Value:
dtype of the value
Returns:
Value:
clear scalar value with given dtype
ValueDescription:
clear scalar value description with given dtype
"""
return Value(dtype=dtype, shape=(), is_encrypted=False)
return ValueDescription(dtype=dtype, shape=(), is_encrypted=False)
ClearScalar = clear_scalar_builder
def encrypted_scalar_builder(dtype: BaseDataType) -> Value:
def encrypted_scalar_builder(dtype: BaseDataType) -> ValueDescription:
"""
Build an encrypted scalar value.
@@ -34,11 +34,11 @@ def encrypted_scalar_builder(dtype: BaseDataType) -> Value:
dtype of the value
Returns:
Value:
encrypted scalar value with given dtype
ValueDescription:
encrypted scalar value description with given dtype
"""
return Value(dtype=dtype, shape=(), is_encrypted=True)
return ValueDescription(dtype=dtype, shape=(), is_encrypted=True)
EncryptedScalar = encrypted_scalar_builder

View File

@@ -5,10 +5,10 @@ Declaration of `ClearTensor` and `EncryptedTensor` wrappers.
from typing import Tuple
from ..dtypes import BaseDataType
from .value import Value
from .value_description import ValueDescription
def clear_tensor_builder(dtype: BaseDataType, shape: Tuple[int, ...]) -> Value:
def clear_tensor_builder(dtype: BaseDataType, shape: Tuple[int, ...]) -> ValueDescription:
"""
Build a clear tensor value.
@@ -20,17 +20,17 @@ def clear_tensor_builder(dtype: BaseDataType, shape: Tuple[int, ...]) -> Value:
shape of the value
Returns:
Value:
clear tensor value with given dtype and shape
ValueDescription:
clear tensor value description with given dtype and shape
"""
return Value(dtype=dtype, shape=shape, is_encrypted=False)
return ValueDescription(dtype=dtype, shape=shape, is_encrypted=False)
ClearTensor = clear_tensor_builder
def encrypted_tensor_builder(dtype: BaseDataType, shape: Tuple[int, ...]) -> Value:
def encrypted_tensor_builder(dtype: BaseDataType, shape: Tuple[int, ...]) -> ValueDescription:
"""
Build an encrypted tensor value.
@@ -42,11 +42,11 @@ def encrypted_tensor_builder(dtype: BaseDataType, shape: Tuple[int, ...]) -> Val
shape of the value
Returns:
Value:
encrypted tensor value with given dtype and shape
ValueDescription:
encrypted tensor value description with given dtype and shape
"""
return Value(dtype=dtype, shape=shape, is_encrypted=True)
return ValueDescription(dtype=dtype, shape=shape, is_encrypted=True)
EncryptedTensor = encrypted_tensor_builder

View File

@@ -1,5 +1,5 @@
"""
Declaration of `Value` class.
Declaration of `ValueDescription` class.
"""
from typing import Any, Tuple
@@ -9,9 +9,9 @@ import numpy as np
from ..dtypes import BaseDataType, Float, Integer, UnsignedInteger
class Value:
class ValueDescription:
"""
Value class, to combine data type, shape, and encryption status into a single object.
ValueDescription class, to combine data type, shape, and encryption status into a single object.
"""
dtype: BaseDataType
@@ -19,46 +19,46 @@ class Value:
is_encrypted: bool
@staticmethod
def of(value: Any, is_encrypted: bool = False) -> "Value": # pylint: disable=invalid-name
def of(value: Any, is_encrypted: bool = False) -> "ValueDescription":
"""
Get the `Value` that can represent `value`.
Get the `ValueDescription` that can represent `value`.
Args:
value (Any):
value that needs to be represented
is_encrypted (bool, default = False):
whether the resulting `Value` is encrypted or not
whether the resulting `ValueDescription` is encrypted or not
Returns:
Value:
`Value` that can represent `value`
ValueDescription:
`ValueDescription` that can represent `value`
Raises:
ValueError:
if `value` cannot be represented by `Value`
if `value` cannot be represented by `ValueDescription`
"""
# pylint: disable=too-many-branches,too-many-return-statements
if isinstance(value, (bool, np.bool_)):
return Value(dtype=UnsignedInteger(1), shape=(), is_encrypted=is_encrypted)
return ValueDescription(dtype=UnsignedInteger(1), shape=(), is_encrypted=is_encrypted)
if isinstance(value, (int, np.integer)):
return Value(
return ValueDescription(
dtype=Integer.that_can_represent(value),
shape=(),
is_encrypted=is_encrypted,
)
if isinstance(value, (float, np.float64)):
return Value(dtype=Float(64), shape=(), is_encrypted=is_encrypted)
return ValueDescription(dtype=Float(64), shape=(), is_encrypted=is_encrypted)
if isinstance(value, np.float32):
return Value(dtype=Float(32), shape=(), is_encrypted=is_encrypted)
return ValueDescription(dtype=Float(32), shape=(), is_encrypted=is_encrypted)
if isinstance(value, np.float16):
return Value(dtype=Float(16), shape=(), is_encrypted=is_encrypted)
return ValueDescription(dtype=Float(16), shape=(), is_encrypted=is_encrypted)
if isinstance(value, list):
try:
@@ -70,25 +70,33 @@ class Value:
if isinstance(value, np.ndarray):
if np.issubdtype(value.dtype, np.bool_):
return Value(dtype=UnsignedInteger(1), shape=value.shape, is_encrypted=is_encrypted)
return ValueDescription(
dtype=UnsignedInteger(1), shape=value.shape, is_encrypted=is_encrypted
)
if np.issubdtype(value.dtype, np.integer):
return Value(
return ValueDescription(
dtype=Integer.that_can_represent(value),
shape=value.shape,
is_encrypted=is_encrypted,
)
if np.issubdtype(value.dtype, np.float64):
return Value(dtype=Float(64), shape=value.shape, is_encrypted=is_encrypted)
return ValueDescription(
dtype=Float(64), shape=value.shape, is_encrypted=is_encrypted
)
if np.issubdtype(value.dtype, np.float32):
return Value(dtype=Float(32), shape=value.shape, is_encrypted=is_encrypted)
return ValueDescription(
dtype=Float(32), shape=value.shape, is_encrypted=is_encrypted
)
if np.issubdtype(value.dtype, np.float16):
return Value(dtype=Float(16), shape=value.shape, is_encrypted=is_encrypted)
return ValueDescription(
dtype=Float(16), shape=value.shape, is_encrypted=is_encrypted
)
message = f"Value cannot represent {repr(value)}"
message = f"Concrete cannot represent {repr(value)}"
raise ValueError(message)
# pylint: enable=too-many-branches,too-many-return-statements
@@ -100,7 +108,7 @@ class Value:
def __eq__(self, other: object) -> bool:
return (
isinstance(other, Value)
isinstance(other, ValueDescription)
and self.dtype == other.dtype
and self.shape == other.shape
and self.is_encrypted == other.is_encrypted

View File

@@ -410,13 +410,17 @@ def test_direct_graph_integer_range(helpers):
# pylint: disable=import-outside-toplevel
from concrete.fhe.dtypes import Integer
from concrete.fhe.values import Value
from concrete.fhe.values import ValueDescription
# pylint: enable=import-outside-toplevel
circuit = fhe.Compiler.assemble(
lambda x: x,
{"x": Value(dtype=Integer(is_signed=False, bit_width=8), shape=(), is_encrypted=True)},
{
"x": ValueDescription(
dtype=Integer(is_signed=False, bit_width=8), shape=(), is_encrypted=True
)
},
configuration=helpers.configuration(),
)
assert circuit.graph.integer_range() is None

View File

@@ -7,7 +7,7 @@ import pytest
from concrete.fhe.dtypes import UnsignedInteger
from concrete.fhe.representation import Node
from concrete.fhe.values import ClearScalar, EncryptedScalar, EncryptedTensor, Value
from concrete.fhe.values import ClearScalar, EncryptedScalar, EncryptedTensor, ValueDescription
@pytest.mark.parametrize(
@@ -44,8 +44,8 @@ def test_node_bad_constant(constant, expected_error, expected_message):
pytest.param(
Node.generic(
name="add",
inputs=[Value.of(4), Value.of(10, is_encrypted=True)],
output=Value.of(14),
inputs=[ValueDescription.of(4), ValueDescription.of(10, is_encrypted=True)],
output=ValueDescription.of(14),
operation=lambda x, y: x + y,
),
["abc"],
@@ -56,8 +56,8 @@ def test_node_bad_constant(constant, expected_error, expected_message):
pytest.param(
Node.generic(
name="add",
inputs=[Value.of(4), Value.of(10, is_encrypted=True)],
output=Value.of(14),
inputs=[ValueDescription.of(4), ValueDescription.of(10, is_encrypted=True)],
output=ValueDescription.of(14),
operation=lambda x, y: x + y,
),
["abc", "def"],
@@ -68,8 +68,8 @@ def test_node_bad_constant(constant, expected_error, expected_message):
pytest.param(
Node.generic(
name="add",
inputs=[Value.of([3, 4]), Value.of(10, is_encrypted=True)],
output=Value.of([13, 14]),
inputs=[ValueDescription.of([3, 4]), ValueDescription.of(10, is_encrypted=True)],
output=ValueDescription.of([13, 14]),
operation=lambda x, y: x + y,
),
[[1, 2, 3, 4], 10],
@@ -81,7 +81,7 @@ def test_node_bad_constant(constant, expected_error, expected_message):
Node.generic(
name="unknown",
inputs=[],
output=Value.of(10),
output=ValueDescription.of(10),
operation=lambda: "abc",
),
[],
@@ -93,7 +93,7 @@ def test_node_bad_constant(constant, expected_error, expected_message):
Node.generic(
name="unknown",
inputs=[],
output=Value.of(10),
output=ValueDescription.of(10),
operation=lambda: np.array(["abc", "def"]),
),
[],
@@ -106,7 +106,7 @@ def test_node_bad_constant(constant, expected_error, expected_message):
Node.generic(
name="unknown",
inputs=[],
output=Value.of(10),
output=ValueDescription.of(10),
operation=lambda: [1, (), 3],
),
[],
@@ -118,7 +118,7 @@ def test_node_bad_constant(constant, expected_error, expected_message):
Node.generic(
name="unknown",
inputs=[],
output=Value.of(10),
output=ValueDescription.of(10),
operation=lambda: [1, 2, 3],
),
[],

View File

@@ -6,7 +6,13 @@ import numpy as np
import pytest
from concrete.fhe.dtypes import Float, SignedInteger, UnsignedInteger
from concrete.fhe.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, Value
from concrete.fhe.values import (
ClearScalar,
ClearTensor,
EncryptedScalar,
EncryptedTensor,
ValueDescription,
)
@pytest.mark.parametrize(
@@ -15,87 +21,87 @@ from concrete.fhe.values import ClearScalar, ClearTensor, EncryptedScalar, Encry
pytest.param(
True,
True,
Value(dtype=UnsignedInteger(1), shape=(), is_encrypted=True),
ValueDescription(dtype=UnsignedInteger(1), shape=(), is_encrypted=True),
),
pytest.param(
True,
False,
Value(dtype=UnsignedInteger(1), shape=(), is_encrypted=False),
ValueDescription(dtype=UnsignedInteger(1), shape=(), is_encrypted=False),
),
pytest.param(
False,
True,
Value(dtype=UnsignedInteger(1), shape=(), is_encrypted=True),
ValueDescription(dtype=UnsignedInteger(1), shape=(), is_encrypted=True),
),
pytest.param(
False,
False,
Value(dtype=UnsignedInteger(1), shape=(), is_encrypted=False),
ValueDescription(dtype=UnsignedInteger(1), shape=(), is_encrypted=False),
),
pytest.param(
0,
False,
Value(dtype=UnsignedInteger(1), shape=(), is_encrypted=False),
ValueDescription(dtype=UnsignedInteger(1), shape=(), is_encrypted=False),
),
pytest.param(
np.int32(0),
True,
Value(dtype=UnsignedInteger(1), shape=(), is_encrypted=True),
ValueDescription(dtype=UnsignedInteger(1), shape=(), is_encrypted=True),
),
pytest.param(
0.0,
False,
Value(dtype=Float(64), shape=(), is_encrypted=False),
ValueDescription(dtype=Float(64), shape=(), is_encrypted=False),
),
pytest.param(
np.float64(0.0),
True,
Value(dtype=Float(64), shape=(), is_encrypted=True),
ValueDescription(dtype=Float(64), shape=(), is_encrypted=True),
),
pytest.param(
np.float32(0.0),
False,
Value(dtype=Float(32), shape=(), is_encrypted=False),
ValueDescription(dtype=Float(32), shape=(), is_encrypted=False),
),
pytest.param(
np.float16(0.0),
True,
Value(dtype=Float(16), shape=(), is_encrypted=True),
ValueDescription(dtype=Float(16), shape=(), is_encrypted=True),
),
pytest.param(
[True, False, True],
False,
Value(dtype=UnsignedInteger(1), shape=(3,), is_encrypted=False),
ValueDescription(dtype=UnsignedInteger(1), shape=(3,), is_encrypted=False),
),
pytest.param(
[True, False, True],
True,
Value(dtype=UnsignedInteger(1), shape=(3,), is_encrypted=True),
ValueDescription(dtype=UnsignedInteger(1), shape=(3,), is_encrypted=True),
),
pytest.param(
[0, 3, 1, 2],
False,
Value(dtype=UnsignedInteger(2), shape=(4,), is_encrypted=False),
ValueDescription(dtype=UnsignedInteger(2), shape=(4,), is_encrypted=False),
),
pytest.param(
np.array([0, 3, 1, 2], dtype=np.int32),
True,
Value(dtype=UnsignedInteger(2), shape=(4,), is_encrypted=True),
ValueDescription(dtype=UnsignedInteger(2), shape=(4,), is_encrypted=True),
),
pytest.param(
np.array([0.2, 3.4, 1.5, 2.0], dtype=np.float64),
False,
Value(dtype=Float(64), shape=(4,), is_encrypted=False),
ValueDescription(dtype=Float(64), shape=(4,), is_encrypted=False),
),
pytest.param(
np.array([0.2, 3.4, 1.5, 2.0], dtype=np.float32),
True,
Value(dtype=Float(32), shape=(4,), is_encrypted=True),
ValueDescription(dtype=Float(32), shape=(4,), is_encrypted=True),
),
pytest.param(
np.array([0.2, 3.4, 1.5, 2.0], dtype=np.float16),
False,
Value(dtype=Float(16), shape=(4,), is_encrypted=False),
ValueDescription(dtype=Float(16), shape=(4,), is_encrypted=False),
),
],
)
@@ -104,7 +110,7 @@ def test_value_of(value, is_encrypted, expected_result):
Test `of` function of `Value` class.
"""
assert Value.of(value, is_encrypted) == expected_result
assert ValueDescription.of(value, is_encrypted) == expected_result
@pytest.mark.parametrize(
@@ -114,13 +120,13 @@ def test_value_of(value, is_encrypted, expected_result):
"abc",
False,
ValueError,
"Value cannot represent 'abc'",
"Concrete cannot represent 'abc'",
),
pytest.param(
[1, (), 3],
False,
ValueError,
"Value cannot represent [1, (), 3]",
"Concrete cannot represent [1, (), 3]",
),
],
)
@@ -130,7 +136,7 @@ def test_value_bad_of(value, is_encrypted, expected_error, expected_message):
"""
with pytest.raises(expected_error) as excinfo:
Value.of(value, is_encrypted)
ValueDescription.of(value, is_encrypted)
assert str(excinfo.value) == expected_message
@@ -140,22 +146,22 @@ def test_value_bad_of(value, is_encrypted, expected_error, expected_message):
[
pytest.param(
ClearScalar(SignedInteger(5)),
Value(dtype=SignedInteger(5), shape=(), is_encrypted=False),
ValueDescription(dtype=SignedInteger(5), shape=(), is_encrypted=False),
True,
),
pytest.param(
ClearTensor(UnsignedInteger(5), shape=(3, 2)),
Value(dtype=UnsignedInteger(5), shape=(3, 2), is_encrypted=False),
ValueDescription(dtype=UnsignedInteger(5), shape=(3, 2), is_encrypted=False),
True,
),
pytest.param(
EncryptedScalar(SignedInteger(5)),
Value(dtype=SignedInteger(5), shape=(), is_encrypted=True),
ValueDescription(dtype=SignedInteger(5), shape=(), is_encrypted=True),
True,
),
pytest.param(
EncryptedTensor(UnsignedInteger(5), shape=(3, 2)),
Value(dtype=UnsignedInteger(5), shape=(3, 2), is_encrypted=True),
ValueDescription(dtype=UnsignedInteger(5), shape=(3, 2), is_encrypted=True),
True,
),
pytest.param(