refactor: rename 'data_type' field of 'BaseValue' to 'dtype'

This commit is contained in:
Umut
2021-09-29 11:32:53 +03:00
parent 77690fed84
commit 36d732b0ae
19 changed files with 79 additions and 81 deletions

View File

@@ -42,7 +42,7 @@ def _check_input_coherency(
base_value = base_value_class(is_encrypted=parameter_base_value.is_encrypted)
if base_value.shape != parameter_base_value.shape or not is_data_type_compatible_with(
base_value.data_type, parameter_base_value.data_type
base_value.dtype, parameter_base_value.dtype
):
warnings.append(
f"expected {str(parameter_base_value)} "

View File

@@ -31,8 +31,8 @@ def ir_nodes_has_integer_input_and_output(node: IntermediateNode) -> bool:
Returns:
bool: True if all input and output values hold Integers
"""
return all(isinstance(x.data_type, Integer) for x in node.inputs) and all(
isinstance(x.data_type, Integer) for x in node.outputs
return all(isinstance(x.dtype, Integer) for x in node.inputs) and all(
isinstance(x.dtype, Integer) for x in node.outputs
)

View File

@@ -47,7 +47,7 @@ def value_is_encrypted_scalar_unsigned_integer(value_to_check: BaseValue) -> boo
"""
return (
value_is_encrypted_scalar_integer(value_to_check)
and not cast(Integer, value_to_check.data_type).is_signed
and not cast(Integer, value_to_check.dtype).is_signed
)
@@ -73,7 +73,7 @@ def value_is_scalar_integer(value_to_check: BaseValue) -> bool:
bool: True if the passed value_to_check is a ScalarValue of type Integer
"""
return isinstance(value_to_check, ScalarValue) and isinstance(
value_to_check.data_type, INTEGER_TYPES
value_to_check.dtype, INTEGER_TYPES
)
@@ -101,7 +101,7 @@ def value_is_encrypted_tensor_unsigned_integer(value_to_check: BaseValue) -> boo
"""
return (
value_is_encrypted_tensor_integer(value_to_check)
and not cast(Integer, value_to_check.data_type).is_signed
and not cast(Integer, value_to_check.dtype).is_signed
)
@@ -127,7 +127,7 @@ def value_is_tensor_integer(value_to_check: BaseValue) -> bool:
bool: True if the passed value_to_check is a TensorValue of type Integer
"""
return isinstance(value_to_check, TensorValue) and isinstance(
value_to_check.data_type, INTEGER_TYPES
value_to_check.dtype, INTEGER_TYPES
)
@@ -216,7 +216,7 @@ def mix_scalar_values_determine_holding_dtype(
isinstance(value2, ScalarValue), f"Unsupported value2: {value2}, expected ScalarValue"
)
holding_type = find_type_to_hold_both_lossy(value1.data_type, value2.data_type)
holding_type = find_type_to_hold_both_lossy(value1.dtype, value2.dtype)
mixed_value: ScalarValue
if value1.is_encrypted or value2.is_encrypted:
@@ -261,13 +261,13 @@ def mix_tensor_values_determine_holding_dtype(
),
)
holding_type = find_type_to_hold_both_lossy(value1.data_type, value2.data_type)
holding_type = find_type_to_hold_both_lossy(value1.dtype, value2.dtype)
shape = value1.shape
if value1.is_encrypted or value2.is_encrypted:
mixed_value = EncryptedTensor(data_type=holding_type, shape=shape)
mixed_value = EncryptedTensor(dtype=holding_type, shape=shape)
else:
mixed_value = ClearTensor(data_type=holding_type, shape=shape)
mixed_value = ClearTensor(dtype=holding_type, shape=shape)
return mixed_value
@@ -362,10 +362,10 @@ def get_base_value_for_python_constant_data(
assert len(constant_data) > 0
constant_shape = (len(constant_data),)
constant_data_type = get_base_data_type_for_python_constant_data(constant_data)
return partial(TensorValue, data_type=constant_data_type, shape=constant_shape)
return partial(TensorValue, dtype=constant_data_type, shape=constant_shape)
constant_data_type = get_base_data_type_for_python_constant_data(constant_data)
return partial(ScalarValue, data_type=constant_data_type)
return partial(ScalarValue, dtype=constant_data_type)
def get_type_constructor_for_python_constant_data(constant_data: Union[int, float]):

View File

@@ -52,7 +52,7 @@ def _add_eint_int(node, preds, ir_to_mlir_node, ctx):
lhs_node, rhs_node = preds
lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
return hlfhe.AddEintIntOp(
hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].data_type.bit_width),
hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width),
lhs,
rhs,
).result
@@ -63,7 +63,7 @@ def _add_eint_eint(node, preds, ir_to_mlir_node, ctx):
lhs_node, rhs_node = preds
lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
return hlfhe.AddEintOp(
hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].data_type.bit_width),
hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width),
lhs,
rhs,
).result
@@ -87,7 +87,7 @@ def _sub_int_eint(node, preds, ir_to_mlir_node, ctx):
lhs_node, rhs_node = preds
lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
return hlfhe.SubIntEintOp(
hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].data_type.bit_width),
hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width),
lhs,
rhs,
).result
@@ -116,7 +116,7 @@ def _mul_eint_int(node, preds, ir_to_mlir_node, ctx):
lhs_node, rhs_node = preds
lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
return hlfhe.MulEintIntOp(
hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].data_type.bit_width),
hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width),
lhs,
rhs,
).result
@@ -126,7 +126,7 @@ def constant(node, _, __, ctx):
"""Convert a constant inputs."""
if not value_is_clear_scalar_integer(node.outputs[0]):
raise TypeError("Don't support non-integer constants")
dtype = cast(Integer, node.outputs[0].data_type)
dtype = cast(Integer, node.outputs[0].dtype)
if dtype.is_signed:
raise TypeError("Don't support signed constant integer")
int_type = IntegerType.get_signless(dtype.bit_width, context=ctx)
@@ -145,7 +145,7 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx):
x_node = preds[0]
x = ir_to_mlir_node[x_node]
table = node.get_table()
out_dtype = cast(Integer, node.outputs[0].data_type)
out_dtype = cast(Integer, node.outputs[0].dtype)
# Create table
dense_elem = DenseElementsAttr.get(np.array(table, dtype=np.uint64), context=ctx)
tensor_lut = std_dialect.ConstantOp(
@@ -182,7 +182,7 @@ def dot(node, preds, ir_to_mlir_node, ctx):
lhs_node, rhs_node = rhs_node, lhs_node
lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
return hlfhe.Dot(
hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].data_type.bit_width),
hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width),
lhs,
rhs,
).result

View File

@@ -101,16 +101,14 @@ class MLIRConverter:
corresponding MLIR type
"""
if value_is_encrypted_scalar_unsigned_integer(value):
return self._get_scalar_integer_type(
cast(Integer, value.data_type).bit_width, True, False
)
return self._get_scalar_integer_type(cast(Integer, value.dtype).bit_width, True, False)
if value_is_clear_scalar_integer(value):
dtype = cast(Integer, value.data_type)
dtype = cast(Integer, value.dtype)
return self._get_scalar_integer_type(
dtype.bit_width, is_encrypted=False, is_signed=dtype.is_signed
)
if value_is_encrypted_tensor_unsigned_integer(value):
dtype = cast(Integer, value.data_type)
dtype = cast(Integer, value.dtype)
return self._get_tensor_type(
dtype.bit_width,
is_encrypted=True,
@@ -118,7 +116,7 @@ class MLIRConverter:
shape=cast(values.TensorValue, value).shape,
)
if value_is_clear_tensor_integer(value):
dtype = cast(Integer, value.data_type)
dtype = cast(Integer, value.dtype)
return self._get_tensor_type(
dtype.bit_width,
is_encrypted=False,

View File

@@ -28,7 +28,7 @@ def is_graph_values_compatible_with_mlir(op_graph: OPGraph) -> bool:
"""
return all(
all(
value_is_scalar_integer(out) and not cast(Integer, out.data_type).is_signed
value_is_scalar_integer(out) and not cast(Integer, out.dtype).is_signed
for out in out_node.outputs
)
for out_node in op_graph.output_nodes.values()
@@ -45,11 +45,11 @@ def _set_all_bit_width(op_graph: OPGraph, p: int):
for node in op_graph.graph.nodes:
for value in node.outputs + node.inputs:
if value_is_clear_scalar_integer(value) or value_is_clear_tensor_integer(value):
value.data_type.bit_width = p + 1
value.dtype.bit_width = p + 1
elif value_is_encrypted_scalar_integer(value) or value_is_encrypted_tensor_integer(
value
):
value.data_type.bit_width = p
value.dtype.bit_width = p
def update_bit_width_for_mlir(op_graph: OPGraph):
@@ -63,7 +63,7 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
for node in op_graph.graph.nodes:
for value_out in node.outputs:
if value_is_clear_scalar_integer(value_out) or value_is_clear_tensor_integer(value_out):
current_node_out_bit_width = value_out.data_type.bit_width - 1
current_node_out_bit_width = value_out.dtype.bit_width - 1
else:
assert_true(
@@ -71,7 +71,7 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
or value_is_encrypted_tensor_integer(value_out)
)
current_node_out_bit_width = value_out.data_type.bit_width
current_node_out_bit_width = value_out.dtype.bit_width
max_bit_width = max(max_bit_width, current_node_out_bit_width)
@@ -106,7 +106,7 @@ def extend_direct_lookup_tables(op_graph: OPGraph):
for node in op_graph.graph.nodes:
if isinstance(node, ArbitraryFunction) and node.op_name == "TLU":
table = node.op_kwargs["table"]
bit_width = cast(Integer, node.inputs[0].data_type).bit_width
bit_width = cast(Integer, node.inputs[0].dtype).bit_width
expected_length = 2 ** bit_width
# TODO: remove no cover once the table length workaround is removed

View File

@@ -196,7 +196,7 @@ class OPGraph:
if not isinstance(node, Input):
for output_value in node.outputs:
if isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer):
output_value.data_type = make_integer_to_hold(
output_value.dtype = make_integer_to_hold(
(min_bound, max_bound), force_signed=False
)
else:
@@ -208,8 +208,8 @@ class OPGraph:
f"min_bound: {min_data_type}, max_bound: {max_data_type}"
),
)
output_value.data_type = Float(64)
output_value.data_type.underlying_type_constructor = data_type_constructor
output_value.dtype = Float(64)
output_value.dtype.underlying_type_constructor = data_type_constructor
else:
# Currently variable inputs are only allowed to be integers
custom_assert(
@@ -220,10 +220,10 @@ class OPGraph:
f"max: {max_bound} ({type(max_bound)})"
),
)
node.inputs[0].data_type = make_integer_to_hold(
node.inputs[0].dtype = make_integer_to_hold(
(min_bound, max_bound), force_signed=False
)
node.inputs[0].data_type.underlying_type_constructor = data_type_constructor
node.inputs[0].dtype.underlying_type_constructor = data_type_constructor
node.outputs[0] = deepcopy(node.inputs[0])

View File

@@ -160,7 +160,7 @@ def convert_float_subgraph_to_fused_node(
lambda x, float_op_subgraph, terminal_node: float_op_subgraph.evaluate({0: x})[
terminal_node
],
deepcopy(terminal_node.outputs[0].data_type),
deepcopy(terminal_node.outputs[0].dtype),
op_kwargs={
"float_op_subgraph": float_op_subgraph,
"terminal_node": terminal_node,
@@ -197,13 +197,13 @@ def find_float_subgraph_with_unique_terminal_node(
def is_float_to_single_int_node(node: IntermediateNode) -> bool:
return (
any(isinstance(input_.data_type, Float) for input_ in node.inputs)
any(isinstance(input_.dtype, Float) for input_ in node.inputs)
and len(node.outputs) == 1
and isinstance(node.outputs[0].data_type, Integer)
and isinstance(node.outputs[0].dtype, Integer)
)
def single_int_output_node(node: IntermediateNode) -> bool:
return len(node.outputs) == 1 and isinstance(node.outputs[0].data_type, Integer)
return len(node.outputs) == 1 and isinstance(node.outputs[0].dtype, Integer)
float_subgraphs_terminal_nodes = (
node

View File

@@ -242,21 +242,21 @@ class ArbitraryFunction(IntermediateNode):
"""
# Check the input is an unsigned integer to be able to build a table
assert isinstance(
self.inputs[0].data_type, Integer
self.inputs[0].dtype, Integer
), "get_table only works for an unsigned Integer input"
assert not self.inputs[
0
].data_type.is_signed, "get_table only works for an unsigned Integer input"
].dtype.is_signed, "get_table only works for an unsigned Integer input"
type_constructor = self.inputs[0].data_type.underlying_type_constructor
type_constructor = self.inputs[0].dtype.underlying_type_constructor
if type_constructor is None:
logger.info(
f"{self.__class__.__name__} input data type constructor was None, defaulting to int"
)
type_constructor = int
min_input_range = self.inputs[0].data_type.min_value()
max_input_range = self.inputs[0].data_type.max_value() + 1
min_input_range = self.inputs[0].dtype.min_value()
max_input_range = self.inputs[0].dtype.max_value() + 1
table = [
self.evaluate({0: type_constructor(input_value)})

View File

@@ -9,11 +9,11 @@ from ..data_types.base import BaseDataType
class BaseValue(ABC):
"""Abstract base class to represent any kind of value in a program."""
data_type: BaseDataType
dtype: BaseDataType
_is_encrypted: bool
def __init__(self, data_type: BaseDataType, is_encrypted: bool) -> None:
self.data_type = deepcopy(data_type)
def __init__(self, dtype: BaseDataType, is_encrypted: bool) -> None:
self.dtype = deepcopy(dtype)
self._is_encrypted = is_encrypted
def __repr__(self) -> str: # pragma: no cover
@@ -21,7 +21,7 @@ class BaseValue(ABC):
@abstractmethod
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and self.data_type == other.data_type
return isinstance(other, self.__class__) and self.dtype == other.dtype
@property
def is_encrypted(self) -> bool:

View File

@@ -14,7 +14,7 @@ class ScalarValue(BaseValue):
def __str__(self) -> str: # pragma: no cover
encrypted_str = "Encrypted" if self._is_encrypted else "Clear"
return f"{encrypted_str}Scalar<{self.data_type!r}>"
return f"{encrypted_str}Scalar<{self.dtype!r}>"
@property
def shape(self) -> Tuple[int, ...]:
@@ -26,28 +26,28 @@ class ScalarValue(BaseValue):
return ()
def make_clear_scalar(data_type: BaseDataType) -> ScalarValue:
def make_clear_scalar(dtype: BaseDataType) -> ScalarValue:
"""Create a clear ScalarValue.
Args:
data_type (BaseDataType): The data type for the value.
dtype (BaseDataType): The data type for the value.
Returns:
ScalarValue: The corresponding ScalarValue.
"""
return ScalarValue(data_type=data_type, is_encrypted=False)
return ScalarValue(dtype=dtype, is_encrypted=False)
def make_encrypted_scalar(data_type: BaseDataType) -> ScalarValue:
def make_encrypted_scalar(dtype: BaseDataType) -> ScalarValue:
"""Create an encrypted ScalarValue.
Args:
data_type (BaseDataType): The data type for the value.
dtype (BaseDataType): The data type for the value.
Returns:
ScalarValue: The corresponding ScalarValue.
"""
return ScalarValue(data_type=data_type, is_encrypted=True)
return ScalarValue(dtype=dtype, is_encrypted=True)
ClearScalar = make_clear_scalar

View File

@@ -16,11 +16,11 @@ class TensorValue(BaseValue):
def __init__(
self,
data_type: BaseDataType,
dtype: BaseDataType,
is_encrypted: bool,
shape: Optional[Tuple[int, ...]] = None,
) -> None:
super().__init__(data_type, is_encrypted)
super().__init__(dtype, is_encrypted)
# Managing tensors as in numpy, no shape or () is treated as a 0-D array of size 1
self._shape = shape if shape is not None else ()
self._ndim = len(self._shape)
@@ -37,7 +37,7 @@ class TensorValue(BaseValue):
def __str__(self) -> str:
encrypted_str = "Encrypted" if self._is_encrypted else "Clear"
return f"{encrypted_str}Tensor<{str(self.data_type)}, shape={self.shape}>"
return f"{encrypted_str}Tensor<{str(self.dtype)}, shape={self.shape}>"
@property
def shape(self) -> Tuple[int, ...]:
@@ -68,35 +68,35 @@ class TensorValue(BaseValue):
def make_clear_tensor(
data_type: BaseDataType,
dtype: BaseDataType,
shape: Optional[Tuple[int, ...]] = None,
) -> TensorValue:
"""Create a clear TensorValue.
Args:
data_type (BaseDataType): The data type for the tensor.
dtype (BaseDataType): The data type for the tensor.
shape (Optional[Tuple[int, ...]], optional): The tensor shape. Defaults to None.
Returns:
TensorValue: The corresponding TensorValue.
"""
return TensorValue(data_type=data_type, is_encrypted=False, shape=shape)
return TensorValue(dtype=dtype, is_encrypted=False, shape=shape)
def make_encrypted_tensor(
data_type: BaseDataType,
dtype: BaseDataType,
shape: Optional[Tuple[int, ...]] = None,
) -> TensorValue:
"""Create an encrypted TensorValue.
Args:
data_type (BaseDataType): The data type for the tensor.
dtype (BaseDataType): The data type for the tensor.
shape (Optional[Tuple[int, ...]], optional): The tensor shape. Defaults to None.
Returns:
TensorValue: The corresponding TensorValue.
"""
return TensorValue(data_type=data_type, is_encrypted=True, shape=shape)
return TensorValue(dtype=dtype, is_encrypted=True, shape=shape)
ClearTensor = make_clear_tensor

View File

@@ -125,9 +125,9 @@ def _compile_numpy_function_into_op_graph_internal(
# this loop will determine the number of possible inputs of the function
# if a function have a single 3-bit input, for example, `inputset_size_upper_limit` will be 8
for parameter_value in function_parameters.values():
if isinstance(parameter_value.data_type, Integer):
if isinstance(parameter_value.dtype, Integer):
# multiple parameter bit-widths are multiplied as they can be combined into an input
inputset_size_upper_limit *= 2 ** parameter_value.data_type.bit_width
inputset_size_upper_limit *= 2 ** parameter_value.dtype.bit_width
# if the upper limit of the inputset size goes above 10,
# break the loop as we will require at least 10 inputs in this case

View File

@@ -172,9 +172,9 @@ def get_base_value_for_numpy_or_python_constant_data(
base_dtype = get_base_data_type_for_numpy_or_python_constant_data(constant_data)
if isinstance(constant_data, numpy.ndarray):
constant_data_value = partial(TensorValue, data_type=base_dtype, shape=constant_data.shape)
constant_data_value = partial(TensorValue, dtype=base_dtype, shape=constant_data.shape)
elif isinstance(constant_data, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES):
constant_data_value = partial(ScalarValue, data_type=base_dtype)
constant_data_value = partial(ScalarValue, dtype=base_dtype)
else:
constant_data_value = get_base_value_for_python_constant_data(constant_data)
return constant_data_value

View File

@@ -129,7 +129,7 @@ class NPTracer(BaseTracer):
@staticmethod
def _manage_dtypes(ufunc: Union[numpy.ufunc, Callable], *input_tracers: BaseTracer):
output_dtypes = get_numpy_function_output_dtype(
ufunc, [input_tracer.output.data_type for input_tracer in input_tracers]
ufunc, [input_tracer.output.dtype for input_tracer in input_tracers]
)
common_output_dtypes = [
convert_numpy_dtype_to_base_data_type(dtype) for dtype in output_dtypes

View File

@@ -296,7 +296,7 @@ def test_eval_op_graph_bounds_on_inputset_multiple_output(
op_graph.update_values_with_bounds(node_bounds)
for i, output_node in op_graph.output_nodes.items():
assert expected_output_data_type[i] == output_node.outputs[0].data_type
assert expected_output_data_type[i] == output_node.outputs[0].dtype
def test_eval_op_graph_bounds_on_non_conformant_inputset_default(capsys):

View File

@@ -60,26 +60,26 @@ def test_tensor_value(
):
"""Test function for TensorValue"""
tensor_value = tensor_constructor(data_type=data_type, shape=shape)
tensor_value = tensor_constructor(dtype=data_type, shape=shape)
assert expected_is_encrypted == tensor_value.is_encrypted
assert expected_shape == tensor_value.shape
assert expected_ndim == tensor_value.ndim
assert expected_size == tensor_value.size
assert data_type == tensor_value.data_type
assert data_type == tensor_value.dtype
other_tensor = deepcopy(tensor_value)
assert other_tensor == tensor_value
other_tensor_value = deepcopy(other_tensor)
other_tensor_value.data_type = DummyDtype()
other_tensor_value.dtype = DummyDtype()
assert other_tensor_value != tensor_value
other_shape = tuple(val + 1 for val in shape) if shape is not None else ()
other_shape += (2,)
other_tensor_value = tensor_constructor(data_type=data_type, shape=other_shape)
other_tensor_value = tensor_constructor(dtype=data_type, shape=other_shape)
assert other_tensor_value.shape != tensor_value.shape
assert other_tensor_value.ndim != tensor_value.ndim

View File

@@ -46,7 +46,7 @@ def test_check_op_graph_is_integer_program():
assert len(offending_nodes) == 0
op_graph_copy = deepcopy(op_graph)
op_graph_copy.output_nodes[0].outputs[0].data_type = Float64
op_graph_copy.output_nodes[0].outputs[0].dtype = Float64
offending_nodes = []
assert not check_op_graph_is_integer_program(op_graph_copy)
@@ -55,7 +55,7 @@ def test_check_op_graph_is_integer_program():
assert offending_nodes == [op_graph_copy.output_nodes[0]]
op_graph_copy = deepcopy(op_graph)
op_graph_copy.input_nodes[0].inputs[0].data_type = Float64
op_graph_copy.input_nodes[0].inputs[0].dtype = Float64
offending_nodes = []
assert not check_op_graph_is_integer_program(op_graph_copy)
@@ -64,8 +64,8 @@ def test_check_op_graph_is_integer_program():
assert offending_nodes == [op_graph_copy.input_nodes[0]]
op_graph_copy = deepcopy(op_graph)
op_graph_copy.input_nodes[0].inputs[0].data_type = Float64
op_graph_copy.input_nodes[1].inputs[0].data_type = Float64
op_graph_copy.input_nodes[0].inputs[0].dtype = Float64
op_graph_copy.input_nodes[1].inputs[0].dtype = Float64
offending_nodes = []
assert not check_op_graph_is_integer_program(op_graph_copy)

View File

@@ -222,7 +222,7 @@ def test_tracing_astype(
op_graph = tracing.trace_numpy_function(function_to_trace, {"x": input_value})
output_node = op_graph.output_nodes[0]
assert op_graph_expected_output_type == output_node.outputs[0].data_type
assert op_graph_expected_output_type == output_node.outputs[0].dtype
node_results = op_graph.evaluate({0: numpy.array(input_)})
evaluated_output = node_results[output_node]