refactor: rename 'data_type' field of 'BaseValue' to 'dtype'
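In user code the change is mechanical: the constructor keyword and the attribute are renamed, nothing else moves. A hedged before/after sketch — the identifiers come from the diff below, but the import paths and the Integer(bit_width, is_signed) signature are assumptions, not verified against the tree:

# Sketch only; import paths and Integer's signature are assumed.
from concrete.common.data_types.integers import Integer  # assumed path
from concrete.common.values import make_encrypted_scalar  # assumed path

# Before this commit: make_encrypted_scalar(data_type=Integer(7, is_signed=False))
value = make_encrypted_scalar(dtype=Integer(7, is_signed=False))
assert value.dtype.bit_width == 7  # attribute renamed from value.data_type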
@@ -42,7 +42,7 @@ def _check_input_coherency(
     base_value = base_value_class(is_encrypted=parameter_base_value.is_encrypted)

     if base_value.shape != parameter_base_value.shape or not is_data_type_compatible_with(
-        base_value.data_type, parameter_base_value.data_type
+        base_value.dtype, parameter_base_value.dtype
     ):
         warnings.append(
             f"expected {str(parameter_base_value)} "
@@ -31,8 +31,8 @@ def ir_nodes_has_integer_input_and_output(node: IntermediateNode) -> bool:
     Returns:
         bool: True if all input and output values hold Integers
     """
-    return all(isinstance(x.data_type, Integer) for x in node.inputs) and all(
-        isinstance(x.data_type, Integer) for x in node.outputs
+    return all(isinstance(x.dtype, Integer) for x in node.inputs) and all(
+        isinstance(x.dtype, Integer) for x in node.outputs
     )
@@ -47,7 +47,7 @@ def value_is_encrypted_scalar_unsigned_integer(value_to_check: BaseValue) -> bool:
     """
     return (
         value_is_encrypted_scalar_integer(value_to_check)
-        and not cast(Integer, value_to_check.data_type).is_signed
+        and not cast(Integer, value_to_check.dtype).is_signed
     )
@@ -73,7 +73,7 @@ def value_is_scalar_integer(value_to_check: BaseValue) -> bool:
         bool: True if the passed value_to_check is a ScalarValue of type Integer
     """
     return isinstance(value_to_check, ScalarValue) and isinstance(
-        value_to_check.data_type, INTEGER_TYPES
+        value_to_check.dtype, INTEGER_TYPES
     )
@@ -101,7 +101,7 @@ def value_is_encrypted_tensor_unsigned_integer(value_to_check: BaseValue) -> bool:
     """
     return (
         value_is_encrypted_tensor_integer(value_to_check)
-        and not cast(Integer, value_to_check.data_type).is_signed
+        and not cast(Integer, value_to_check.dtype).is_signed
     )
@@ -127,7 +127,7 @@ def value_is_tensor_integer(value_to_check: BaseValue) -> bool:
         bool: True if the passed value_to_check is a TensorValue of type Integer
     """
     return isinstance(value_to_check, TensorValue) and isinstance(
-        value_to_check.data_type, INTEGER_TYPES
+        value_to_check.dtype, INTEGER_TYPES
     )
@@ -216,7 +216,7 @@ def mix_scalar_values_determine_holding_dtype(
         isinstance(value2, ScalarValue), f"Unsupported value2: {value2}, expected ScalarValue"
     )

-    holding_type = find_type_to_hold_both_lossy(value1.data_type, value2.data_type)
+    holding_type = find_type_to_hold_both_lossy(value1.dtype, value2.dtype)
     mixed_value: ScalarValue

     if value1.is_encrypted or value2.is_encrypted:
@@ -261,13 +261,13 @@ def mix_tensor_values_determine_holding_dtype(
         ),
     )

-    holding_type = find_type_to_hold_both_lossy(value1.data_type, value2.data_type)
+    holding_type = find_type_to_hold_both_lossy(value1.dtype, value2.dtype)
     shape = value1.shape

     if value1.is_encrypted or value2.is_encrypted:
-        mixed_value = EncryptedTensor(data_type=holding_type, shape=shape)
+        mixed_value = EncryptedTensor(dtype=holding_type, shape=shape)
     else:
-        mixed_value = ClearTensor(data_type=holding_type, shape=shape)
+        mixed_value = ClearTensor(dtype=holding_type, shape=shape)

     return mixed_value
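Worth spelling out the rule in the hunk above: the mixed value is encrypted as soon as either operand is, and its dtype must be wide enough to hold both operand dtypes (lossily). A toy stand-in in plain Python — illustrative only, not the library implementation:

# Toy model of the mixing rule above; illustrative only.
def mix_is_encrypted(enc1: bool, enc2: bool) -> bool:
    # the result is encrypted if either operand is
    return enc1 or enc2

def bits_to_hold_both_unsigned(bits1: int, bits2: int) -> int:
    # for two unsigned widths, the "hold both" width is the larger one
    return max(bits1, bits2)

assert mix_is_encrypted(True, False) is True
assert bits_to_hold_both_unsigned(4, 7) == 7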
@@ -362,10 +362,10 @@ def get_base_value_for_python_constant_data(
         assert len(constant_data) > 0
         constant_shape = (len(constant_data),)
         constant_data_type = get_base_data_type_for_python_constant_data(constant_data)
-        return partial(TensorValue, data_type=constant_data_type, shape=constant_shape)
+        return partial(TensorValue, dtype=constant_data_type, shape=constant_shape)

     constant_data_type = get_base_data_type_for_python_constant_data(constant_data)
-    return partial(ScalarValue, data_type=constant_data_type)
+    return partial(ScalarValue, dtype=constant_data_type)


 def get_type_constructor_for_python_constant_data(constant_data: Union[int, float]):
@@ -52,7 +52,7 @@ def _add_eint_int(node, preds, ir_to_mlir_node, ctx):
     lhs_node, rhs_node = preds
     lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
     return hlfhe.AddEintIntOp(
-        hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].data_type.bit_width),
+        hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width),
         lhs,
         rhs,
     ).result
@@ -63,7 +63,7 @@ def _add_eint_eint(node, preds, ir_to_mlir_node, ctx):
     lhs_node, rhs_node = preds
     lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
     return hlfhe.AddEintOp(
-        hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].data_type.bit_width),
+        hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width),
         lhs,
         rhs,
     ).result
@@ -87,7 +87,7 @@ def _sub_int_eint(node, preds, ir_to_mlir_node, ctx):
     lhs_node, rhs_node = preds
     lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
     return hlfhe.SubIntEintOp(
-        hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].data_type.bit_width),
+        hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width),
         lhs,
         rhs,
     ).result
@@ -116,7 +116,7 @@ def _mul_eint_int(node, preds, ir_to_mlir_node, ctx):
     lhs_node, rhs_node = preds
     lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
     return hlfhe.MulEintIntOp(
-        hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].data_type.bit_width),
+        hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width),
         lhs,
         rhs,
     ).result
@@ -126,7 +126,7 @@ def constant(node, _, __, ctx):
     """Convert a constant inputs."""
     if not value_is_clear_scalar_integer(node.outputs[0]):
         raise TypeError("Don't support non-integer constants")
-    dtype = cast(Integer, node.outputs[0].data_type)
+    dtype = cast(Integer, node.outputs[0].dtype)
     if dtype.is_signed:
         raise TypeError("Don't support signed constant integer")
     int_type = IntegerType.get_signless(dtype.bit_width, context=ctx)
@@ -145,7 +145,7 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx):
     x_node = preds[0]
     x = ir_to_mlir_node[x_node]
     table = node.get_table()
-    out_dtype = cast(Integer, node.outputs[0].data_type)
+    out_dtype = cast(Integer, node.outputs[0].dtype)
     # Create table
     dense_elem = DenseElementsAttr.get(np.array(table, dtype=np.uint64), context=ctx)
     tensor_lut = std_dialect.ConstantOp(
@@ -182,7 +182,7 @@ def dot(node, preds, ir_to_mlir_node, ctx):
         lhs_node, rhs_node = rhs_node, lhs_node
     lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
     return hlfhe.Dot(
-        hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].data_type.bit_width),
+        hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width),
         lhs,
         rhs,
     ).result
@@ -101,16 +101,14 @@ class MLIRConverter:
             corresponding MLIR type
         """
         if value_is_encrypted_scalar_unsigned_integer(value):
-            return self._get_scalar_integer_type(
-                cast(Integer, value.data_type).bit_width, True, False
-            )
+            return self._get_scalar_integer_type(cast(Integer, value.dtype).bit_width, True, False)
         if value_is_clear_scalar_integer(value):
-            dtype = cast(Integer, value.data_type)
+            dtype = cast(Integer, value.dtype)
             return self._get_scalar_integer_type(
                 dtype.bit_width, is_encrypted=False, is_signed=dtype.is_signed
             )
         if value_is_encrypted_tensor_unsigned_integer(value):
-            dtype = cast(Integer, value.data_type)
+            dtype = cast(Integer, value.dtype)
             return self._get_tensor_type(
                 dtype.bit_width,
                 is_encrypted=True,
@@ -118,7 +116,7 @@ class MLIRConverter:
                 shape=cast(values.TensorValue, value).shape,
             )
         if value_is_clear_tensor_integer(value):
-            dtype = cast(Integer, value.data_type)
+            dtype = cast(Integer, value.dtype)
             return self._get_tensor_type(
                 dtype.bit_width,
                 is_encrypted=False,
@@ -28,7 +28,7 @@ def is_graph_values_compatible_with_mlir(op_graph: OPGraph) -> bool:
     """
     return all(
         all(
-            value_is_scalar_integer(out) and not cast(Integer, out.data_type).is_signed
+            value_is_scalar_integer(out) and not cast(Integer, out.dtype).is_signed
             for out in out_node.outputs
         )
         for out_node in op_graph.output_nodes.values()
@@ -45,11 +45,11 @@ def _set_all_bit_width(op_graph: OPGraph, p: int):
     for node in op_graph.graph.nodes:
         for value in node.outputs + node.inputs:
             if value_is_clear_scalar_integer(value) or value_is_clear_tensor_integer(value):
-                value.data_type.bit_width = p + 1
+                value.dtype.bit_width = p + 1
             elif value_is_encrypted_scalar_integer(value) or value_is_encrypted_tensor_integer(
                 value
             ):
-                value.data_type.bit_width = p
+                value.dtype.bit_width = p


 def update_bit_width_for_mlir(op_graph: OPGraph):
@@ -63,7 +63,7 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
     for node in op_graph.graph.nodes:
         for value_out in node.outputs:
             if value_is_clear_scalar_integer(value_out) or value_is_clear_tensor_integer(value_out):
-                current_node_out_bit_width = value_out.data_type.bit_width - 1
+                current_node_out_bit_width = value_out.dtype.bit_width - 1
             else:

                 assert_true(
@@ -71,7 +71,7 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
                     or value_is_encrypted_tensor_integer(value_out)
                 )

-                current_node_out_bit_width = value_out.data_type.bit_width
+                current_node_out_bit_width = value_out.dtype.bit_width

             max_bit_width = max(max_bit_width, current_node_out_bit_width)
@@ -106,7 +106,7 @@ def extend_direct_lookup_tables(op_graph: OPGraph):
     for node in op_graph.graph.nodes:
         if isinstance(node, ArbitraryFunction) and node.op_name == "TLU":
             table = node.op_kwargs["table"]
-            bit_width = cast(Integer, node.inputs[0].data_type).bit_width
+            bit_width = cast(Integer, node.inputs[0].dtype).bit_width
             expected_length = 2 ** bit_width

             # TODO: remove no cover once the table length workaround is removed
@@ -196,7 +196,7 @@ class OPGraph:
         if not isinstance(node, Input):
             for output_value in node.outputs:
                 if isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer):
-                    output_value.data_type = make_integer_to_hold(
+                    output_value.dtype = make_integer_to_hold(
                         (min_bound, max_bound), force_signed=False
                     )
                 else:
@@ -208,8 +208,8 @@ class OPGraph:
                             f"min_bound: {min_data_type}, max_bound: {max_data_type}"
                         ),
                     )
-                    output_value.data_type = Float(64)
-                    output_value.data_type.underlying_type_constructor = data_type_constructor
+                    output_value.dtype = Float(64)
+                    output_value.dtype.underlying_type_constructor = data_type_constructor
         else:
             # Currently variable inputs are only allowed to be integers
             custom_assert(
@@ -220,10 +220,10 @@ class OPGraph:
                     f"max: {max_bound} ({type(max_bound)})"
                 ),
             )
-            node.inputs[0].data_type = make_integer_to_hold(
+            node.inputs[0].dtype = make_integer_to_hold(
                 (min_bound, max_bound), force_signed=False
             )
-            node.inputs[0].data_type.underlying_type_constructor = data_type_constructor
+            node.inputs[0].dtype.underlying_type_constructor = data_type_constructor

             node.outputs[0] = deepcopy(node.inputs[0])
@@ -160,7 +160,7 @@ def convert_float_subgraph_to_fused_node(
         lambda x, float_op_subgraph, terminal_node: float_op_subgraph.evaluate({0: x})[
             terminal_node
         ],
-        deepcopy(terminal_node.outputs[0].data_type),
+        deepcopy(terminal_node.outputs[0].dtype),
         op_kwargs={
             "float_op_subgraph": float_op_subgraph,
             "terminal_node": terminal_node,
@@ -197,13 +197,13 @@ def find_float_subgraph_with_unique_terminal_node(

     def is_float_to_single_int_node(node: IntermediateNode) -> bool:
         return (
-            any(isinstance(input_.data_type, Float) for input_ in node.inputs)
+            any(isinstance(input_.dtype, Float) for input_ in node.inputs)
             and len(node.outputs) == 1
-            and isinstance(node.outputs[0].data_type, Integer)
+            and isinstance(node.outputs[0].dtype, Integer)
         )

     def single_int_output_node(node: IntermediateNode) -> bool:
-        return len(node.outputs) == 1 and isinstance(node.outputs[0].data_type, Integer)
+        return len(node.outputs) == 1 and isinstance(node.outputs[0].dtype, Integer)

     float_subgraphs_terminal_nodes = (
         node
@@ -242,21 +242,21 @@ class ArbitraryFunction(IntermediateNode):
         """
         # Check the input is an unsigned integer to be able to build a table
         assert isinstance(
-            self.inputs[0].data_type, Integer
+            self.inputs[0].dtype, Integer
         ), "get_table only works for an unsigned Integer input"
         assert not self.inputs[
             0
-        ].data_type.is_signed, "get_table only works for an unsigned Integer input"
+        ].dtype.is_signed, "get_table only works for an unsigned Integer input"

-        type_constructor = self.inputs[0].data_type.underlying_type_constructor
+        type_constructor = self.inputs[0].dtype.underlying_type_constructor
         if type_constructor is None:
             logger.info(
                 f"{self.__class__.__name__} input data type constructor was None, defaulting to int"
             )
             type_constructor = int

-        min_input_range = self.inputs[0].data_type.min_value()
-        max_input_range = self.inputs[0].data_type.max_value() + 1
+        min_input_range = self.inputs[0].dtype.min_value()
+        max_input_range = self.inputs[0].dtype.max_value() + 1

         table = [
             self.evaluate({0: type_constructor(input_value)})
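The loop above evaluates the function at every point of the unsigned input dtype's range, from min_value() to max_value() inclusive, to materialize the lookup table. A standalone sketch of the same logic (names are illustrative, not the library API):

# Standalone sketch of get_table's table-building loop; names are illustrative.
def build_table(bit_width, evaluate):
    # an unsigned Integer of width n ranges over [0, 2**n - 1]
    return [evaluate(x) for x in range(2 ** bit_width)]

# e.g. a 3-bit squaring table:
assert build_table(3, lambda x: x * x) == [0, 1, 4, 9, 16, 25, 36, 49]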
@@ -9,11 +9,11 @@ from ..data_types.base import BaseDataType
 class BaseValue(ABC):
     """Abstract base class to represent any kind of value in a program."""

-    data_type: BaseDataType
+    dtype: BaseDataType
     _is_encrypted: bool

-    def __init__(self, data_type: BaseDataType, is_encrypted: bool) -> None:
-        self.data_type = deepcopy(data_type)
+    def __init__(self, dtype: BaseDataType, is_encrypted: bool) -> None:
+        self.dtype = deepcopy(dtype)
         self._is_encrypted = is_encrypted

     def __repr__(self) -> str:  # pragma: no cover
@@ -21,7 +21,7 @@ class BaseValue(ABC):

     @abstractmethod
     def __eq__(self, other: object) -> bool:
-        return isinstance(other, self.__class__) and self.data_type == other.data_type
+        return isinstance(other, self.__class__) and self.dtype == other.dtype

     @property
     def is_encrypted(self) -> bool:
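__eq__ now compares the renamed dtype alongside the concrete class; a quick sketch of the resulting semantics, assuming the data types themselves compare by value and assuming the import paths:

from concrete.common.data_types.integers import Integer  # assumed path
from concrete.common.values import make_clear_scalar  # assumed path

u4 = make_clear_scalar(Integer(4, is_signed=False))
# same class, equal dtype: values compare equal
assert u4 == make_clear_scalar(Integer(4, is_signed=False))
# a differing dtype breaks equality
assert u4 != make_clear_scalar(Integer(5, is_signed=False))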
@@ -14,7 +14,7 @@ class ScalarValue(BaseValue):

     def __str__(self) -> str:  # pragma: no cover
         encrypted_str = "Encrypted" if self._is_encrypted else "Clear"
-        return f"{encrypted_str}Scalar<{self.data_type!r}>"
+        return f"{encrypted_str}Scalar<{self.dtype!r}>"

     @property
     def shape(self) -> Tuple[int, ...]:
@@ -26,28 +26,28 @@ class ScalarValue(BaseValue):
         return ()


-def make_clear_scalar(data_type: BaseDataType) -> ScalarValue:
+def make_clear_scalar(dtype: BaseDataType) -> ScalarValue:
     """Create a clear ScalarValue.

     Args:
-        data_type (BaseDataType): The data type for the value.
+        dtype (BaseDataType): The data type for the value.

     Returns:
         ScalarValue: The corresponding ScalarValue.
     """
-    return ScalarValue(data_type=data_type, is_encrypted=False)
+    return ScalarValue(dtype=dtype, is_encrypted=False)


-def make_encrypted_scalar(data_type: BaseDataType) -> ScalarValue:
+def make_encrypted_scalar(dtype: BaseDataType) -> ScalarValue:
     """Create an encrypted ScalarValue.

     Args:
-        data_type (BaseDataType): The data type for the value.
+        dtype (BaseDataType): The data type for the value.

     Returns:
         ScalarValue: The corresponding ScalarValue.
     """
-    return ScalarValue(data_type=data_type, is_encrypted=True)
+    return ScalarValue(dtype=dtype, is_encrypted=True)


 ClearScalar = make_clear_scalar
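The scalar factories now take dtype as well; a small usage sketch under the same path assumptions:

from concrete.common.data_types.integers import Integer  # assumed path
from concrete.common.values import make_clear_scalar, make_encrypted_scalar  # assumed paths

clear = make_clear_scalar(dtype=Integer(6, is_signed=False))
enc = make_encrypted_scalar(dtype=Integer(6, is_signed=False))
assert not clear.is_encrypted and enc.is_encrypted
assert clear.shape == ()  # ScalarValue.shape is always the empty tuple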
@@ -16,11 +16,11 @@ class TensorValue(BaseValue):

     def __init__(
         self,
-        data_type: BaseDataType,
+        dtype: BaseDataType,
         is_encrypted: bool,
         shape: Optional[Tuple[int, ...]] = None,
     ) -> None:
-        super().__init__(data_type, is_encrypted)
+        super().__init__(dtype, is_encrypted)
         # Managing tensors as in numpy, no shape or () is treated as a 0-D array of size 1
         self._shape = shape if shape is not None else ()
         self._ndim = len(self._shape)
@@ -37,7 +37,7 @@ class TensorValue(BaseValue):

     def __str__(self) -> str:
         encrypted_str = "Encrypted" if self._is_encrypted else "Clear"
-        return f"{encrypted_str}Tensor<{str(self.data_type)}, shape={self.shape}>"
+        return f"{encrypted_str}Tensor<{str(self.dtype)}, shape={self.shape}>"

     @property
     def shape(self) -> Tuple[int, ...]:
@@ -68,35 +68,35 @@ class TensorValue(BaseValue):


 def make_clear_tensor(
-    data_type: BaseDataType,
+    dtype: BaseDataType,
     shape: Optional[Tuple[int, ...]] = None,
 ) -> TensorValue:
     """Create a clear TensorValue.

     Args:
-        data_type (BaseDataType): The data type for the tensor.
+        dtype (BaseDataType): The data type for the tensor.
         shape (Optional[Tuple[int, ...]], optional): The tensor shape. Defaults to None.

     Returns:
         TensorValue: The corresponding TensorValue.
     """
-    return TensorValue(data_type=data_type, is_encrypted=False, shape=shape)
+    return TensorValue(dtype=dtype, is_encrypted=False, shape=shape)


 def make_encrypted_tensor(
-    data_type: BaseDataType,
+    dtype: BaseDataType,
     shape: Optional[Tuple[int, ...]] = None,
 ) -> TensorValue:
     """Create an encrypted TensorValue.

     Args:
-        data_type (BaseDataType): The data type for the tensor.
+        dtype (BaseDataType): The data type for the tensor.
         shape (Optional[Tuple[int, ...]], optional): The tensor shape. Defaults to None.

     Returns:
         TensorValue: The corresponding TensorValue.
     """
-    return TensorValue(data_type=data_type, is_encrypted=True, shape=shape)
+    return TensorValue(dtype=dtype, is_encrypted=True, shape=shape)


 ClearTensor = make_clear_tensor
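Likewise for the tensor factories, where the shape handling mirrors numpy's 0-D arrays; a hedged sketch under the same path assumptions:

from concrete.common.data_types.integers import Integer  # assumed path
from concrete.common.values import make_clear_tensor, make_encrypted_tensor  # assumed paths

enc = make_encrypted_tensor(dtype=Integer(5, is_signed=False), shape=(3, 4))
assert enc.is_encrypted and enc.shape == (3, 4)

# omitting shape yields a 0-D tensor, per the comment in TensorValue.__init__
zero_d = make_clear_tensor(dtype=Integer(5, is_signed=False))
assert zero_d.shape == ()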
@@ -125,9 +125,9 @@ def _compile_numpy_function_into_op_graph_internal(
     # this loop will determine the number of possible inputs of the function
     # if a function have a single 3-bit input, for example, `inputset_size_upper_limit` will be 8
     for parameter_value in function_parameters.values():
-        if isinstance(parameter_value.data_type, Integer):
+        if isinstance(parameter_value.dtype, Integer):
             # multiple parameter bit-widths are multiplied as they can be combined into an input
-            inputset_size_upper_limit *= 2 ** parameter_value.data_type.bit_width
+            inputset_size_upper_limit *= 2 ** parameter_value.dtype.bit_width

             # if the upper limit of the inputset size goes above 10,
             # break the loop as we will require at least 10 inputs in this case
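The upper limit multiplies the input ranges of all integer parameters: for example, one 3-bit and one 2-bit input give 2**3 * 2**2 = 32 possible input combinations. A sketch of that arithmetic with hypothetical widths:

# Illustrative recomputation of inputset_size_upper_limit.
bit_widths = [3, 2]  # hypothetical parameter bit widths
inputset_size_upper_limit = 1
for bw in bit_widths:
    inputset_size_upper_limit *= 2 ** bw  # each n-bit input has 2**n values
assert inputset_size_upper_limit == 32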
@@ -172,9 +172,9 @@ def get_base_value_for_numpy_or_python_constant_data(

     base_dtype = get_base_data_type_for_numpy_or_python_constant_data(constant_data)
     if isinstance(constant_data, numpy.ndarray):
-        constant_data_value = partial(TensorValue, data_type=base_dtype, shape=constant_data.shape)
+        constant_data_value = partial(TensorValue, dtype=base_dtype, shape=constant_data.shape)
     elif isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES):
-        constant_data_value = partial(ScalarValue, data_type=base_dtype)
+        constant_data_value = partial(ScalarValue, dtype=base_dtype)
     else:
         constant_data_value = get_base_value_for_python_constant_data(constant_data)
     return constant_data_value
@@ -129,7 +129,7 @@ class NPTracer(BaseTracer):
     @staticmethod
     def _manage_dtypes(ufunc: Union[numpy.ufunc, Callable], *input_tracers: BaseTracer):
         output_dtypes = get_numpy_function_output_dtype(
-            ufunc, [input_tracer.output.data_type for input_tracer in input_tracers]
+            ufunc, [input_tracer.output.dtype for input_tracer in input_tracers]
         )
         common_output_dtypes = [
             convert_numpy_dtype_to_base_data_type(dtype) for dtype in output_dtypes
@@ -296,7 +296,7 @@ def test_eval_op_graph_bounds_on_inputset_multiple_output(
     op_graph.update_values_with_bounds(node_bounds)

     for i, output_node in op_graph.output_nodes.items():
-        assert expected_output_data_type[i] == output_node.outputs[0].data_type
+        assert expected_output_data_type[i] == output_node.outputs[0].dtype


 def test_eval_op_graph_bounds_on_non_conformant_inputset_default(capsys):
@@ -60,26 +60,26 @@ def test_tensor_value(
 ):
     """Test function for TensorValue"""

-    tensor_value = tensor_constructor(data_type=data_type, shape=shape)
+    tensor_value = tensor_constructor(dtype=data_type, shape=shape)

     assert expected_is_encrypted == tensor_value.is_encrypted
     assert expected_shape == tensor_value.shape
     assert expected_ndim == tensor_value.ndim
     assert expected_size == tensor_value.size

-    assert data_type == tensor_value.data_type
+    assert data_type == tensor_value.dtype

     other_tensor = deepcopy(tensor_value)

     assert other_tensor == tensor_value

     other_tensor_value = deepcopy(other_tensor)
-    other_tensor_value.data_type = DummyDtype()
+    other_tensor_value.dtype = DummyDtype()
     assert other_tensor_value != tensor_value

     other_shape = tuple(val + 1 for val in shape) if shape is not None else ()
     other_shape += (2,)
-    other_tensor_value = tensor_constructor(data_type=data_type, shape=other_shape)
+    other_tensor_value = tensor_constructor(dtype=data_type, shape=other_shape)

     assert other_tensor_value.shape != tensor_value.shape
     assert other_tensor_value.ndim != tensor_value.ndim
@@ -46,7 +46,7 @@ def test_check_op_graph_is_integer_program():
     assert len(offending_nodes) == 0

     op_graph_copy = deepcopy(op_graph)
-    op_graph_copy.output_nodes[0].outputs[0].data_type = Float64
+    op_graph_copy.output_nodes[0].outputs[0].dtype = Float64

     offending_nodes = []
     assert not check_op_graph_is_integer_program(op_graph_copy)
@@ -55,7 +55,7 @@ def test_check_op_graph_is_integer_program():
     assert offending_nodes == [op_graph_copy.output_nodes[0]]

     op_graph_copy = deepcopy(op_graph)
-    op_graph_copy.input_nodes[0].inputs[0].data_type = Float64
+    op_graph_copy.input_nodes[0].inputs[0].dtype = Float64

     offending_nodes = []
     assert not check_op_graph_is_integer_program(op_graph_copy)
@@ -64,8 +64,8 @@ def test_check_op_graph_is_integer_program():
     assert offending_nodes == [op_graph_copy.input_nodes[0]]

     op_graph_copy = deepcopy(op_graph)
-    op_graph_copy.input_nodes[0].inputs[0].data_type = Float64
-    op_graph_copy.input_nodes[1].inputs[0].data_type = Float64
+    op_graph_copy.input_nodes[0].inputs[0].dtype = Float64
+    op_graph_copy.input_nodes[1].inputs[0].dtype = Float64

     offending_nodes = []
     assert not check_op_graph_is_integer_program(op_graph_copy)
@@ -222,7 +222,7 @@ def test_tracing_astype(

     op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value})
     output_node = op_graph.output_nodes[0]
-    assert op_graph_expected_output_type == output_node.outputs[0].data_type
+    assert op_graph_expected_output_type == output_node.outputs[0].dtype

     node_results = op_graph.evaluate({0: numpy.array(input_)})
     evaluated_output = node_results[output_node]