mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: implement bitwise and comparison operators
This commit is contained in:
@@ -96,6 +96,21 @@ class GraphConverter:
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only assignment to encrypted tensors are supported"
|
||||
|
||||
elif name in ["bitwise_and", "bitwise_or", "bitwise_xor", "left_shift", "right_shift"]:
|
||||
assert_that(len(inputs) == 2)
|
||||
if all(value.is_encrypted for value in node.inputs):
|
||||
pred_nodes = graph.ordered_preds_of(node)
|
||||
if (
|
||||
name in ["left_shift", "right_shift"]
|
||||
and cast(Integer, pred_nodes[1].output.dtype).bit_width > 4
|
||||
):
|
||||
return "only up to 4-bit shifts are supported"
|
||||
|
||||
for pred_node in pred_nodes:
|
||||
assert isinstance(pred_node.output.dtype, Integer)
|
||||
if pred_node.output.dtype.is_signed:
|
||||
return "only unsigned bitwise operations are supported"
|
||||
|
||||
elif name == "broadcast_to":
|
||||
assert_that(len(inputs) == 1)
|
||||
if not inputs[0].is_encrypted:
|
||||
@@ -115,6 +130,9 @@ class GraphConverter:
|
||||
if inputs[0].is_encrypted and inputs[1].is_encrypted:
|
||||
return "only dot product between encrypted and clear is supported"
|
||||
|
||||
elif name in ["equal", "greater", "greater_equal", "less", "less_equal", "not_equal"]:
|
||||
assert_that(len(inputs) == 2)
|
||||
|
||||
elif name == "expand_dims":
|
||||
assert_that(len(inputs) == 1)
|
||||
|
||||
@@ -251,6 +269,20 @@ class GraphConverter:
|
||||
current_node_bit_width = (
|
||||
dtype.bit_width - 1 if node.output.is_clear else dtype.bit_width
|
||||
)
|
||||
if (
|
||||
all(value.is_encrypted for value in node.inputs)
|
||||
and node.operation == Operation.Generic
|
||||
and node.properties["name"]
|
||||
in [
|
||||
"greater",
|
||||
"greater_equal",
|
||||
"less",
|
||||
"less_equal",
|
||||
]
|
||||
):
|
||||
# implementation of these operators require at least 4 bits
|
||||
current_node_bit_width = max(current_node_bit_width, 4)
|
||||
|
||||
if max_bit_width < current_node_bit_width:
|
||||
max_bit_width = current_node_bit_width
|
||||
max_bit_width_node = node
|
||||
@@ -286,7 +318,10 @@ class GraphConverter:
|
||||
+ graph.format(highlighted_nodes=offending_nodes)
|
||||
)
|
||||
|
||||
for node in graph.graph.nodes:
|
||||
for node in nx.topological_sort(graph.graph):
|
||||
assert isinstance(node.output.dtype, Integer)
|
||||
node.properties["original_bit_width"] = node.output.dtype.bit_width
|
||||
|
||||
for value in node.inputs + [node.output]:
|
||||
dtype = value.dtype
|
||||
assert_that(isinstance(dtype, Integer))
|
||||
@@ -335,9 +370,14 @@ class GraphConverter:
|
||||
variable_input_bit_width = variable_input_dtype.bit_width
|
||||
offset_constant_dtype = SignedInteger(variable_input_bit_width + 1)
|
||||
|
||||
offset_constant = Node.constant(abs(variable_input_dtype.min()))
|
||||
offset_constant_value = abs(variable_input_dtype.min())
|
||||
|
||||
offset_constant = Node.constant(offset_constant_value)
|
||||
offset_constant.output.dtype = offset_constant_dtype
|
||||
|
||||
original_bit_width = Integer.that_can_represent(offset_constant_value).bit_width
|
||||
offset_constant.properties["original_bit_width"] = original_bit_width
|
||||
|
||||
add_offset = Node.generic(
|
||||
"add",
|
||||
[variable_input_value, ClearScalar(offset_constant_dtype)],
|
||||
@@ -345,6 +385,9 @@ class GraphConverter:
|
||||
np.add,
|
||||
)
|
||||
|
||||
original_bit_width = variable_input_node.properties["original_bit_width"]
|
||||
add_offset.properties["original_bit_width"] = original_bit_width
|
||||
|
||||
nx_graph.remove_edge(variable_input_node, node)
|
||||
|
||||
nx_graph.add_edge(variable_input_node, add_offset, input_idx=0)
|
||||
@@ -412,6 +455,10 @@ class GraphConverter:
|
||||
kwargs={"newshape": required_value_shape},
|
||||
)
|
||||
|
||||
modified_pred.properties["original_bit_width"] = pred_to_modify.properties[
|
||||
"original_bit_width"
|
||||
]
|
||||
|
||||
nx_graph.add_edge(pred_to_modify, modified_pred, input_idx=0)
|
||||
|
||||
nx_graph.remove_edge(pred_to_modify, node)
|
||||
@@ -448,6 +495,9 @@ class GraphConverter:
|
||||
lambda: np.zeros((), dtype=np.int64),
|
||||
)
|
||||
|
||||
original_bit_width = 1
|
||||
zero.properties["original_bit_width"] = original_bit_width
|
||||
|
||||
new_assigned_pred = Node.generic(
|
||||
"add",
|
||||
[assigned_pred.output, zero.output],
|
||||
@@ -455,6 +505,9 @@ class GraphConverter:
|
||||
np.add,
|
||||
)
|
||||
|
||||
original_bit_width = assigned_pred.properties["original_bit_width"]
|
||||
new_assigned_pred.properties["original_bit_width"] = original_bit_width
|
||||
|
||||
nx_graph.remove_edge(preds[1], node)
|
||||
|
||||
nx_graph.add_edge(preds[1], new_assigned_pred, input_idx=0)
|
||||
@@ -473,7 +526,24 @@ class GraphConverter:
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
OPS_TO_TENSORIZE = ["add", "broadcast_to", "dot", "multiply", "subtract"]
|
||||
OPS_TO_TENSORIZE = [
|
||||
"add",
|
||||
"bitwise_and",
|
||||
"bitwise_or",
|
||||
"bitwise_xor",
|
||||
"broadcast_to",
|
||||
"dot",
|
||||
"equal",
|
||||
"greater",
|
||||
"greater_equal",
|
||||
"left_shift",
|
||||
"less",
|
||||
"less_equal",
|
||||
"multiply",
|
||||
"not_equal",
|
||||
"right_shift",
|
||||
"subtract",
|
||||
]
|
||||
# pylint: enable=invalid-name
|
||||
|
||||
tensorized_scalars: Dict[Node, Node] = {}
|
||||
@@ -490,6 +560,11 @@ class GraphConverter:
|
||||
if not node.inputs[0].is_scalar:
|
||||
continue
|
||||
|
||||
# for bitwise and comparison operators that can have constants
|
||||
# we don't need broadcasting here
|
||||
if node.converted_to_table_lookup:
|
||||
continue
|
||||
|
||||
pred_to_tensorize: Optional[Node] = None
|
||||
pred_to_tensorize_index = 0
|
||||
|
||||
@@ -513,8 +588,14 @@ class GraphConverter:
|
||||
tensorized_value,
|
||||
lambda *args: np.array(args),
|
||||
)
|
||||
nx_graph.add_edge(pred_to_tensorize, tensorized_pred, input_idx=0)
|
||||
|
||||
original_bit_width = pred_to_tensorize.properties["original_bit_width"]
|
||||
tensorized_pred.properties["original_bit_width"] = original_bit_width
|
||||
|
||||
original_shape = ()
|
||||
tensorized_pred.properties["original_shape"] = original_shape
|
||||
|
||||
nx_graph.add_edge(pred_to_tensorize, tensorized_pred, input_idx=0)
|
||||
tensorized_scalars[pred_to_tensorize] = tensorized_pred
|
||||
|
||||
assert tensorized_pred is not None
|
||||
@@ -522,6 +603,10 @@ class GraphConverter:
|
||||
nx_graph.remove_edge(pred_to_tensorize, node)
|
||||
nx_graph.add_edge(tensorized_pred, node, input_idx=pred_to_tensorize_index)
|
||||
|
||||
new_input_value = deepcopy(node.inputs[pred_to_tensorize_index])
|
||||
new_input_value.shape = (1,)
|
||||
node.inputs[pred_to_tensorize_index] = new_input_value
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_signed_inputs(graph: Graph, args: List[Any], ctx: Context) -> List[Any]:
|
||||
"""
|
||||
|
||||
@@ -4,7 +4,9 @@ Declaration of `NodeConverter` class.
|
||||
|
||||
# pylint: disable=no-member,no-name-in-module,too-many-lines
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
import re
|
||||
from enum import IntEnum
|
||||
from typing import Callable, Dict, List, Set, Tuple, cast
|
||||
|
||||
import numpy as np
|
||||
from concrete.lang.dialects import fhe, fhelinalg
|
||||
@@ -34,6 +36,20 @@ from .utils import construct_deduplicated_tables
|
||||
# pylint: enable=no-member,no-name-in-module
|
||||
|
||||
|
||||
class Comparison(IntEnum):
|
||||
"""
|
||||
Comparison enum, to generalize conversion of comparison operators.
|
||||
|
||||
Because comparison result has 3 possibilities, Comparison enum is 2 bits.
|
||||
"""
|
||||
|
||||
EQUAL = 0b00
|
||||
LESS = 0b01
|
||||
GREATER = 0b10
|
||||
|
||||
UNUSED = 0b11
|
||||
|
||||
|
||||
class NodeConverter:
|
||||
"""
|
||||
NodeConverter class, to convert computation graph nodes to their MLIR equivalent.
|
||||
@@ -125,11 +141,12 @@ class NodeConverter:
|
||||
self.all_of_the_inputs_are_tensors = True
|
||||
self.one_of_the_inputs_is_a_tensor = False
|
||||
|
||||
for inp in node.inputs:
|
||||
if not inp.is_encrypted:
|
||||
for pred in graph.ordered_preds_of(node):
|
||||
if not pred.output.is_encrypted:
|
||||
self.all_of_the_inputs_are_encrypted = False
|
||||
|
||||
if inp.is_scalar:
|
||||
shape = pred.properties.get("original_shape", pred.output.shape)
|
||||
if shape == ():
|
||||
self.all_of_the_inputs_are_tensors = False
|
||||
else:
|
||||
self.one_of_the_inputs_is_a_tensor = True
|
||||
@@ -156,20 +173,31 @@ class NodeConverter:
|
||||
"add": self._convert_add,
|
||||
"array": self._convert_array,
|
||||
"assign.static": self._convert_static_assignment,
|
||||
"bitwise_and": self._convert_bitwise_and,
|
||||
"bitwise_or": self._convert_bitwise_or,
|
||||
"bitwise_xor": self._convert_bitwise_xor,
|
||||
"broadcast_to": self._convert_broadcast_to,
|
||||
"concatenate": self._convert_concat,
|
||||
"conv1d": self._convert_conv1d,
|
||||
"conv2d": self._convert_conv2d,
|
||||
"conv3d": self._convert_conv3d,
|
||||
"dot": self._convert_dot,
|
||||
"equal": self._convert_equal,
|
||||
"expand_dims": self._convert_reshape,
|
||||
"greater": self._convert_greater,
|
||||
"greater_equal": self._convert_greater_equal,
|
||||
"index.static": self._convert_static_indexing,
|
||||
"left_shift": self._convert_left_shift,
|
||||
"less": self._convert_less,
|
||||
"less_equal": self._convert_less_equal,
|
||||
"matmul": self._convert_matmul,
|
||||
"maxpool": self._convert_maxpool,
|
||||
"multiply": self._convert_mul,
|
||||
"negative": self._convert_neg,
|
||||
"not_equal": self._convert_not_equal,
|
||||
"ones": self._convert_ones,
|
||||
"reshape": self._convert_reshape,
|
||||
"right_shift": self._convert_right_shift,
|
||||
"squeeze": self._convert_squeeze,
|
||||
"subtract": self._convert_sub,
|
||||
"sum": self._convert_sum,
|
||||
@@ -183,7 +211,7 @@ class NodeConverter:
|
||||
assert_that(self.node.converted_to_table_lookup)
|
||||
return self._convert_tlu()
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
# pylint: disable=no-self-use,too-many-branches,too-many-locals,too-many-statements
|
||||
|
||||
def _convert_add(self) -> OpResult:
|
||||
"""
|
||||
@@ -277,6 +305,39 @@ class NodeConverter:
|
||||
self.from_elements_operations[placeholder_result] = processed_preds
|
||||
return placeholder_result
|
||||
|
||||
def _convert_bitwise_and(self) -> OpResult:
|
||||
"""
|
||||
Convert "bitwise_and" node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
OpResult:
|
||||
in-memory MLIR representation corresponding to `self.node`
|
||||
"""
|
||||
|
||||
return self._convert_bitwise(lambda x, y: x & y)
|
||||
|
||||
def _convert_bitwise_or(self) -> OpResult:
|
||||
"""
|
||||
Convert "bitwise_or" node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
OpResult:
|
||||
in-memory MLIR representation corresponding to `self.node`
|
||||
"""
|
||||
|
||||
return self._convert_bitwise(lambda x, y: x | y)
|
||||
|
||||
def _convert_bitwise_xor(self) -> OpResult:
|
||||
"""
|
||||
Convert "bitwise_xor" node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
OpResult:
|
||||
in-memory MLIR representation corresponding to `self.node`
|
||||
"""
|
||||
|
||||
return self._convert_bitwise(lambda x, y: x ^ y)
|
||||
|
||||
def _convert_concat(self) -> OpResult:
|
||||
"""
|
||||
Convert "concatenate" node to its corresponding MLIR representation.
|
||||
@@ -462,6 +523,76 @@ class NodeConverter:
|
||||
|
||||
return result
|
||||
|
||||
def _convert_equal(self) -> OpResult:
|
||||
"""
|
||||
Convert "equal" node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
OpResult:
|
||||
in-memory MLIR representation corresponding to `self.node`
|
||||
"""
|
||||
|
||||
return self._convert_equality(equals=True)
|
||||
|
||||
def _convert_greater(self) -> OpResult:
|
||||
"""
|
||||
Convert "greater" node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
OpResult:
|
||||
in-memory MLIR representation corresponding to `self.node`
|
||||
"""
|
||||
|
||||
return self._convert_compare(invert_operands=True, accept={Comparison.LESS})
|
||||
|
||||
def _convert_greater_equal(self) -> OpResult:
|
||||
"""
|
||||
Convert "greater_equal" node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
OpResult:
|
||||
in-memory MLIR representation corresponding to `self.node`
|
||||
"""
|
||||
|
||||
return self._convert_compare(
|
||||
invert_operands=True, accept={Comparison.LESS, Comparison.EQUAL}
|
||||
)
|
||||
|
||||
def _convert_left_shift(self) -> OpResult:
|
||||
"""
|
||||
Convert "left_shift" node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
OpResult:
|
||||
in-memory MLIR representation corresponding to `self.node`
|
||||
"""
|
||||
|
||||
return self._convert_shift(orientation="left")
|
||||
|
||||
def _convert_less(self) -> OpResult:
|
||||
"""
|
||||
Convert "less" node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
OpResult:
|
||||
in-memory MLIR representation corresponding to `self.node`
|
||||
"""
|
||||
|
||||
return self._convert_compare(invert_operands=False, accept={Comparison.LESS})
|
||||
|
||||
def _convert_less_equal(self) -> OpResult:
|
||||
"""
|
||||
Convert "less_equal" node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
OpResult:
|
||||
in-memory MLIR representation corresponding to `self.node`
|
||||
"""
|
||||
|
||||
return self._convert_compare(
|
||||
invert_operands=False, accept={Comparison.LESS, Comparison.EQUAL}
|
||||
)
|
||||
|
||||
def _convert_matmul(self) -> OpResult:
|
||||
"""Convert a MatMul node to its corresponding MLIR representation.
|
||||
|
||||
@@ -537,6 +668,17 @@ class NodeConverter:
|
||||
|
||||
return result
|
||||
|
||||
def _convert_not_equal(self) -> OpResult:
|
||||
"""
|
||||
Convert "not_equal" node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
OpResult:
|
||||
in-memory MLIR representation corresponding to `self.node`
|
||||
"""
|
||||
|
||||
return self._convert_equality(equals=False)
|
||||
|
||||
def _convert_ones(self) -> OpResult:
|
||||
"""
|
||||
Convert "ones" node to its corresponding MLIR representation.
|
||||
@@ -699,6 +841,17 @@ class NodeConverter:
|
||||
),
|
||||
).result
|
||||
|
||||
def _convert_right_shift(self) -> OpResult:
|
||||
"""
|
||||
Convert "right_shift" node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
OpResult:
|
||||
in-memory MLIR representation corresponding to `self.node`
|
||||
"""
|
||||
|
||||
return self._convert_shift(orientation="right")
|
||||
|
||||
def _convert_static_assignment(self) -> OpResult:
|
||||
"""
|
||||
Convert "assign.static" node to its corresponding MLIR representation.
|
||||
@@ -1094,7 +1247,19 @@ class NodeConverter:
|
||||
|
||||
return result
|
||||
|
||||
def _create_constant(self, mlir_type: Type, mlir_attribute: Attribute):
|
||||
def _create_add(self, resulting_type, a, b) -> OpResult:
|
||||
if self.one_of_the_inputs_is_a_tensor:
|
||||
return fhelinalg.AddEintOp(resulting_type, a, b).result
|
||||
|
||||
return fhe.AddEintOp(resulting_type, a, b).result
|
||||
|
||||
def _create_add_clear(self, resulting_type, a, b) -> OpResult:
|
||||
if self.one_of_the_inputs_is_a_tensor:
|
||||
return fhelinalg.AddEintIntOp(resulting_type, a, b).result
|
||||
|
||||
return fhe.AddEintIntOp(resulting_type, a, b).result
|
||||
|
||||
def _create_constant(self, mlir_type: Type, mlir_attribute: Attribute) -> OpResult:
|
||||
result = self.constant_cache.get((mlir_type, mlir_attribute))
|
||||
if result is None:
|
||||
# ConstantOp is being decorated, and the init function is supposed to take more
|
||||
@@ -1105,4 +1270,597 @@ class NodeConverter:
|
||||
self.constant_cache[(mlir_type, mlir_attribute)] = result
|
||||
return result
|
||||
|
||||
# pylint: enable=no-self-use
|
||||
def _create_constant_integer(self, bit_width, x) -> OpResult:
|
||||
constant_value = Value(
|
||||
Integer(is_signed=True, bit_width=bit_width + 1),
|
||||
shape=(1,) if self.one_of_the_inputs_is_a_tensor else (),
|
||||
is_encrypted=False,
|
||||
)
|
||||
constant_type = NodeConverter.value_to_mlir_type(self.ctx, constant_value)
|
||||
|
||||
if self.one_of_the_inputs_is_a_tensor:
|
||||
return self._create_constant(
|
||||
constant_type,
|
||||
Attribute.parse(f"dense<{[int(x)]}> : {constant_type}"),
|
||||
)
|
||||
|
||||
return self._create_constant(constant_type, IntegerAttr.get(constant_type, x))
|
||||
|
||||
def _create_mul_clear(self, resulting_type, a, b) -> OpResult:
|
||||
if self.one_of_the_inputs_is_a_tensor:
|
||||
return fhelinalg.MulEintIntOp(resulting_type, a, b).result
|
||||
|
||||
return fhe.MulEintIntOp(resulting_type, a, b).result
|
||||
|
||||
def _create_sub(self, resulting_type, a, b) -> OpResult:
|
||||
if self.one_of_the_inputs_is_a_tensor:
|
||||
return fhelinalg.SubEintOp(resulting_type, a, b).result
|
||||
|
||||
return fhe.SubEintOp(resulting_type, a, b).result
|
||||
|
||||
def _create_tlu(self, resulting_type, pred, lut_values) -> OpResult:
|
||||
resulting_type_str = str(resulting_type)
|
||||
bit_width_search = re.search(r"FHE\.eint<([0-9]+)>", resulting_type_str)
|
||||
|
||||
assert bit_width_search is not None
|
||||
bit_width_str = bit_width_search.group(1)
|
||||
|
||||
bit_width = int(bit_width_str)
|
||||
lut_values += [0] * ((2**bit_width) - len(lut_values))
|
||||
|
||||
lut_element_type = IntegerType.get_signless(64, context=self.ctx)
|
||||
lut_type = RankedTensorType.get((len(lut_values),), lut_element_type)
|
||||
lut_attr = Attribute.parse(f"dense<{str(lut_values)}> : {lut_type}")
|
||||
lut = self._create_constant(resulting_type, lut_attr).result
|
||||
|
||||
if self.one_of_the_inputs_is_a_tensor:
|
||||
return fhelinalg.ApplyLookupTableEintOp(resulting_type, pred, lut).result
|
||||
|
||||
return fhe.ApplyLookupTableEintOp(resulting_type, pred, lut).result
|
||||
|
||||
def _convert_bitwise(self, operation: Callable) -> OpResult:
|
||||
"""
|
||||
Convert `bitwise_{and,or,xor}` node to their corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
OpResult:
|
||||
in-memory MLIR representation corresponding to `self.node`
|
||||
"""
|
||||
|
||||
if not self.all_of_the_inputs_are_encrypted:
|
||||
return self._convert_tlu()
|
||||
|
||||
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
|
||||
|
||||
x = self.preds[0]
|
||||
y = self.preds[1]
|
||||
|
||||
x_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.inputs[0])
|
||||
y_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.inputs[1])
|
||||
|
||||
assert isinstance(self.node.output.dtype, Integer)
|
||||
bit_width = self.node.output.dtype.bit_width
|
||||
|
||||
chunk_size = int(np.floor(bit_width / 2))
|
||||
mask = (2**chunk_size) - 1
|
||||
|
||||
original_bit_width = max(
|
||||
pred_node.properties["original_bit_width"]
|
||||
for pred_node in self.graph.ordered_preds_of(self.node)
|
||||
)
|
||||
|
||||
chunks = []
|
||||
for offset in range(0, original_bit_width, chunk_size):
|
||||
x_lut = [((x >> offset) & mask) << chunk_size for x in range(2**bit_width)]
|
||||
y_lut = [(y >> offset) & mask for y in range(2**bit_width)]
|
||||
|
||||
x_chunk = self._create_tlu(x_type, x, x_lut)
|
||||
y_chunk = self._create_tlu(y_type, y, y_lut)
|
||||
|
||||
packed_x_and_y_chunks = self._create_add(resulting_type, x_chunk, y_chunk)
|
||||
result_chunk = self._create_tlu(
|
||||
resulting_type,
|
||||
packed_x_and_y_chunks,
|
||||
[
|
||||
operation(x, y) << offset
|
||||
for x in range(2**chunk_size)
|
||||
for y in range(2**chunk_size)
|
||||
],
|
||||
)
|
||||
|
||||
chunks.append(result_chunk)
|
||||
|
||||
# add all chunks together in a tree to maximize dataflow parallelization
|
||||
|
||||
# c1 c2 c3 c4
|
||||
# \/ \/
|
||||
# s2 s1
|
||||
# \ /
|
||||
# \ /
|
||||
# s3
|
||||
|
||||
while len(chunks) > 1:
|
||||
a = chunks.pop()
|
||||
b = chunks.pop()
|
||||
|
||||
result = self._create_add(resulting_type, a, b)
|
||||
chunks.insert(0, result)
|
||||
|
||||
return chunks[0]
|
||||
|
||||
def _convert_compare(self, accept: Set[Comparison], invert_operands: bool = False) -> OpResult:
|
||||
if not self.all_of_the_inputs_are_encrypted:
|
||||
return self._convert_tlu()
|
||||
|
||||
inputs = self.node.inputs
|
||||
preds = self.preds
|
||||
|
||||
if invert_operands:
|
||||
inputs = inputs[::-1]
|
||||
preds = preds[::-1]
|
||||
|
||||
x = preds[0]
|
||||
y = preds[1]
|
||||
|
||||
x_is_signed = cast(Integer, inputs[0].dtype).is_signed
|
||||
y_is_signed = cast(Integer, inputs[1].dtype).is_signed
|
||||
|
||||
x_is_unsigned = not x_is_signed
|
||||
y_is_unsigned = not y_is_signed
|
||||
|
||||
pred_nodes = self.graph.ordered_preds_of(self.node)
|
||||
if invert_operands:
|
||||
pred_nodes = list(reversed(pred_nodes))
|
||||
|
||||
x_bit_width, y_bit_width = [node.properties["original_bit_width"] for node in pred_nodes]
|
||||
comparison_bit_width = max(x_bit_width, y_bit_width)
|
||||
|
||||
x_dtype = Integer(is_signed=x_is_signed, bit_width=x_bit_width)
|
||||
y_dtype = Integer(is_signed=y_is_signed, bit_width=y_bit_width)
|
||||
|
||||
x_minus_y_min = x_dtype.min() - y_dtype.max()
|
||||
x_minus_y_max = x_dtype.max() - y_dtype.min()
|
||||
|
||||
x_minus_y_range = [x_minus_y_min, x_minus_y_max]
|
||||
x_minus_y_dtype = Integer.that_can_represent(x_minus_y_range)
|
||||
|
||||
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
|
||||
|
||||
bit_width = cast(Integer, self.node.output.dtype).bit_width
|
||||
signed_offset = 2 ** (bit_width - 1)
|
||||
|
||||
sanitizer = self._create_constant_integer(bit_width, signed_offset)
|
||||
if x_minus_y_dtype.bit_width <= bit_width:
|
||||
x_minus_y = self._create_sub(resulting_type, x, y)
|
||||
sanitized_x_minus_y = self._create_add_clear(resulting_type, x_minus_y, sanitizer)
|
||||
|
||||
accept_less = int(Comparison.LESS in accept)
|
||||
accept_equal = int(Comparison.EQUAL in accept)
|
||||
accept_greater = int(Comparison.GREATER in accept)
|
||||
|
||||
all_cells = 2**bit_width
|
||||
|
||||
less_cells = 2 ** (bit_width - 1)
|
||||
equal_cells = 1
|
||||
greater_cells = less_cells - equal_cells
|
||||
|
||||
assert less_cells + equal_cells + greater_cells == all_cells
|
||||
|
||||
table = (
|
||||
[accept_less] * less_cells
|
||||
+ [accept_equal] * equal_cells
|
||||
+ [accept_greater] * greater_cells
|
||||
)
|
||||
return self._create_tlu(resulting_type, sanitized_x_minus_y, table)
|
||||
|
||||
# Comparison between signed and unsigned is tricky.
|
||||
# To deal with them, we add -min of the signed number to both operands
|
||||
# such that they are both positive. To avoid overflowing
|
||||
# the unsigned operand this addition is done "virtually"
|
||||
# while constructing one of the luts.
|
||||
|
||||
# A flag ("is_unsigned_greater_than_half") is emitted in MLIR to keep track
|
||||
# if the unsigned operand was greater than the max signed number as it
|
||||
# is needed to determine the result of the comparison.
|
||||
|
||||
# Exemple: to compare x and y where x is an int3 and y and uint3, when y
|
||||
# is greater than 4 we are sure than x will be less than x.
|
||||
|
||||
if inputs[0].shape != self.node.output.shape:
|
||||
x = self._create_add(resulting_type, x, self._convert_zeros())
|
||||
if inputs[1].shape != self.node.output.shape:
|
||||
y = self._create_add(resulting_type, y, self._convert_zeros())
|
||||
|
||||
offset_x_by = 0
|
||||
offset_y_by = 0
|
||||
|
||||
if x_is_signed or y_is_signed:
|
||||
if x_is_signed:
|
||||
x = self._create_add_clear(resulting_type, x, sanitizer)
|
||||
else:
|
||||
offset_x_by = signed_offset
|
||||
|
||||
if y_is_signed:
|
||||
y = self._create_add_clear(resulting_type, y, sanitizer)
|
||||
else:
|
||||
offset_y_by = signed_offset
|
||||
|
||||
def compare(x, y):
|
||||
if x < y:
|
||||
return Comparison.LESS
|
||||
|
||||
if x > y:
|
||||
return Comparison.GREATER
|
||||
|
||||
return Comparison.EQUAL
|
||||
|
||||
chunk_size = int(np.floor(bit_width / 2))
|
||||
carries = self._pack_to_chunk_groups_and_map(
|
||||
resulting_type,
|
||||
comparison_bit_width,
|
||||
chunk_size,
|
||||
x,
|
||||
y,
|
||||
lambda i, x, y: compare(x, y) << (min(i, 1) * 2),
|
||||
x_offset=offset_x_by,
|
||||
y_offset=offset_y_by,
|
||||
)
|
||||
|
||||
# This is the reduction step -- we have an array where the entry i is the
|
||||
# result of comparing the chunks of x and y at position i.
|
||||
|
||||
all_comparisons = [Comparison.EQUAL, Comparison.LESS, Comparison.GREATER, Comparison.UNUSED]
|
||||
pick_first_not_equal_lut = [
|
||||
int(
|
||||
current_comparison
|
||||
if previous_comparison == Comparison.EQUAL
|
||||
else previous_comparison
|
||||
)
|
||||
for current_comparison in all_comparisons
|
||||
for previous_comparison in all_comparisons
|
||||
]
|
||||
|
||||
carry = carries[0]
|
||||
for next_carry in carries[1:]:
|
||||
combined_carries = self._create_add(resulting_type, next_carry, carry)
|
||||
carry = self._create_tlu(resulting_type, combined_carries, pick_first_not_equal_lut)
|
||||
|
||||
if x_is_signed != y_is_signed:
|
||||
carry_bit_width = 2
|
||||
is_less_mask = int(Comparison.LESS)
|
||||
|
||||
is_unsigned_greater_than_half = self._create_tlu(
|
||||
resulting_type,
|
||||
x if x_is_unsigned else y,
|
||||
[int(value >= signed_offset) << carry_bit_width for value in range(2**bit_width)],
|
||||
)
|
||||
packed_carry_and_is_unsigned_greater_than_half = self._create_add(
|
||||
resulting_type,
|
||||
is_unsigned_greater_than_half,
|
||||
carry,
|
||||
)
|
||||
|
||||
# this function is actually converting either
|
||||
# - lhs < rhs
|
||||
# - lhs <= rhs
|
||||
|
||||
# in the implementation, we call
|
||||
# - x = lhs
|
||||
# - y = rhs
|
||||
|
||||
# so if y is unsigned and greater than half
|
||||
# - y is definitely bigger than x
|
||||
# - is_unsigned_greater_than_half == 1
|
||||
# - result == (lhs < rhs) == (x < y) == 1
|
||||
|
||||
# so if x is unsigned and greater than half
|
||||
# - x is definitely bigger than y
|
||||
# - is_unsigned_greater_than_half == 1
|
||||
# - result == (lhs < rhs) == (x < y) == 0
|
||||
|
||||
if y_is_unsigned:
|
||||
result_table = [
|
||||
1 if (i >> carry_bit_width) else (i & is_less_mask) for i in range(2**3)
|
||||
]
|
||||
else:
|
||||
result_table = [
|
||||
0 if (i >> carry_bit_width) else (i & is_less_mask) for i in range(2**3)
|
||||
]
|
||||
|
||||
result = self._create_tlu(
|
||||
resulting_type,
|
||||
packed_carry_and_is_unsigned_greater_than_half,
|
||||
result_table,
|
||||
)
|
||||
else:
|
||||
boolean_result_lut = [int(comparison in accept) for comparison in all_comparisons]
|
||||
result = self._create_tlu(resulting_type, carry, boolean_result_lut)
|
||||
|
||||
return result
|
||||
|
||||
def _convert_equality(self, equals: bool) -> OpResult:
|
||||
if not self.all_of_the_inputs_are_encrypted:
|
||||
return self._convert_tlu()
|
||||
|
||||
x = self.preds[0]
|
||||
y = self.preds[1]
|
||||
|
||||
x_is_signed = cast(Integer, self.node.inputs[0].dtype).is_signed
|
||||
y_is_signed = cast(Integer, self.node.inputs[1].dtype).is_signed
|
||||
|
||||
x_is_unsigned = not x_is_signed
|
||||
y_is_unsigned = not y_is_signed
|
||||
|
||||
x_bit_width, y_bit_width = [
|
||||
node.properties["original_bit_width"] for node in self.graph.ordered_preds_of(self.node)
|
||||
]
|
||||
comparison_bit_width = max(x_bit_width, y_bit_width)
|
||||
|
||||
x_dtype = Integer(is_signed=x_is_signed, bit_width=x_bit_width)
|
||||
y_dtype = Integer(is_signed=y_is_signed, bit_width=y_bit_width)
|
||||
|
||||
x_minus_y_min = x_dtype.min() - y_dtype.max()
|
||||
x_minus_y_max = x_dtype.max() - y_dtype.min()
|
||||
|
||||
x_minus_y_range = [x_minus_y_min, x_minus_y_max]
|
||||
x_minus_y_dtype = Integer.that_can_represent(x_minus_y_range)
|
||||
|
||||
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
|
||||
|
||||
bit_width = cast(Integer, self.node.output.dtype).bit_width
|
||||
signed_offset = 2 ** (bit_width - 1)
|
||||
|
||||
sanitizer = self._create_constant_integer(bit_width, signed_offset)
|
||||
if x_minus_y_dtype.bit_width <= bit_width:
|
||||
x_minus_y = self._create_sub(resulting_type, x, y)
|
||||
sanitized_x_minus_y = self._create_add_clear(resulting_type, x_minus_y, sanitizer)
|
||||
|
||||
zero_position = 2 ** (bit_width - 1)
|
||||
if equals:
|
||||
operation_lut = [int(i == zero_position) for i in range(2**bit_width)]
|
||||
else:
|
||||
operation_lut = [int(i != zero_position) for i in range(2**bit_width)]
|
||||
|
||||
return self._create_tlu(resulting_type, sanitized_x_minus_y, operation_lut)
|
||||
|
||||
chunk_size = int(np.floor(bit_width / 2))
|
||||
number_of_chunks = int(np.ceil(bit_width / chunk_size))
|
||||
|
||||
if x_is_signed != y_is_signed:
|
||||
number_of_chunks += 1
|
||||
|
||||
if self.node.inputs[0].shape != self.node.output.shape:
|
||||
x = self._create_add(resulting_type, x, self._convert_zeros())
|
||||
if self.node.inputs[1].shape != self.node.output.shape:
|
||||
y = self._create_add(resulting_type, y, self._convert_zeros())
|
||||
|
||||
greater_than_half_lut = [
|
||||
int(i >= signed_offset) << (number_of_chunks - 1) for i in range(2**bit_width)
|
||||
]
|
||||
if x_is_unsigned and y_is_signed:
|
||||
is_unsigned_greater_than_half = self._create_tlu(
|
||||
resulting_type,
|
||||
x,
|
||||
greater_than_half_lut,
|
||||
)
|
||||
elif x_is_signed and y_is_unsigned:
|
||||
is_unsigned_greater_than_half = self._create_tlu(
|
||||
resulting_type,
|
||||
y,
|
||||
greater_than_half_lut,
|
||||
)
|
||||
else:
|
||||
is_unsigned_greater_than_half = None
|
||||
|
||||
offset_x_by = 0
|
||||
offset_y_by = 0
|
||||
|
||||
if x_is_signed or y_is_signed:
|
||||
if x_is_signed:
|
||||
x = self._create_add_clear(resulting_type, x, sanitizer)
|
||||
else:
|
||||
offset_x_by = signed_offset
|
||||
|
||||
if y_is_signed:
|
||||
y = self._create_add_clear(resulting_type, y, sanitizer)
|
||||
else:
|
||||
offset_y_by = signed_offset
|
||||
|
||||
carries = self._pack_to_chunk_groups_and_map(
|
||||
resulting_type,
|
||||
comparison_bit_width,
|
||||
chunk_size,
|
||||
x,
|
||||
y,
|
||||
lambda _, x, y: int(x != y),
|
||||
x_offset=offset_x_by,
|
||||
y_offset=offset_y_by,
|
||||
)
|
||||
|
||||
if is_unsigned_greater_than_half:
|
||||
carries.append(is_unsigned_greater_than_half)
|
||||
|
||||
while len(carries) > 1:
|
||||
a = carries.pop()
|
||||
b = carries.pop()
|
||||
|
||||
result = self._create_add(resulting_type, a, b)
|
||||
carries.insert(0, result)
|
||||
|
||||
carry = carries[0]
|
||||
|
||||
return self._create_tlu(
|
||||
resulting_type,
|
||||
carry,
|
||||
[int(i == 0 if equals else i != 0) for i in range(2**bit_width)],
|
||||
)
|
||||
|
||||
def _convert_shift(self, orientation: str) -> OpResult:
|
||||
if not self.all_of_the_inputs_are_encrypted:
|
||||
return self._convert_tlu()
|
||||
|
||||
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
|
||||
|
||||
x = self.preds[0]
|
||||
b = self.preds[1]
|
||||
|
||||
b_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.inputs[1])
|
||||
|
||||
assert isinstance(self.node.output.dtype, Integer)
|
||||
bit_width = self.node.output.dtype.bit_width
|
||||
|
||||
pred_nodes = self.graph.ordered_preds_of(self.node)
|
||||
|
||||
x_original_bit_width = pred_nodes[0].properties["original_bit_width"]
|
||||
b_original_bit_width = pred_nodes[1].properties["original_bit_width"]
|
||||
|
||||
if self.node.inputs[0].shape != self.node.output.shape:
|
||||
x = self._create_add(resulting_type, x, self._convert_zeros())
|
||||
|
||||
if x_original_bit_width + b_original_bit_width <= bit_width:
|
||||
shift_multiplier = self._create_constant_integer(bit_width, 2**b_original_bit_width)
|
||||
shifted_x = self._create_mul_clear(resulting_type, x, shift_multiplier)
|
||||
packed_x_and_b = self._create_add(resulting_type, shifted_x, b)
|
||||
return self._create_tlu(
|
||||
resulting_type,
|
||||
packed_x_and_b,
|
||||
[
|
||||
(x << b) if orientation == "left" else (x >> b)
|
||||
for x in range(2**x_original_bit_width)
|
||||
for b in range(2**b_original_bit_width)
|
||||
],
|
||||
)
|
||||
|
||||
# Left_shifts of x << b can be done as follows:
|
||||
# - left shift of x by 8 if b & 0b1000 > 0
|
||||
# - left shift of x by 4 if b & 0b0100 > 0
|
||||
# - left shift of x by 2 if b & 0b0010 > 0
|
||||
# - left shift of x by 1 if b & 0b0001 > 0
|
||||
|
||||
# Encoding this condition is non-trivial -- however,
|
||||
# it can be done using the following trick:
|
||||
|
||||
# y = (b & 0b1000 > 0) * ((x << 8) - x) + x
|
||||
|
||||
# When b & 0b1000, then:
|
||||
# y = 1 * ((x << 8) - x) + x = (x << 8) - x + x = x << 8
|
||||
|
||||
# When b & 0b1000 == 0 then:
|
||||
# y = 0 * ((x << 8) - x) + x = x
|
||||
|
||||
# The same trick can be used for right shift but with:
|
||||
# y = x - (b & 0b1000 > 0) * (x - (x >> 8))
|
||||
|
||||
original_bit_width = self.node.properties["original_bit_width"]
|
||||
chunk_size = min(original_bit_width, bit_width - 1)
|
||||
|
||||
for i in reversed(range(b_original_bit_width)):
|
||||
to_check = 2**i
|
||||
|
||||
should_shift = self._create_tlu(
|
||||
b_type,
|
||||
b,
|
||||
[int((b & to_check) > 0) for b in range(2**bit_width)],
|
||||
)
|
||||
shifted_x = self._create_tlu(
|
||||
resulting_type,
|
||||
x,
|
||||
(
|
||||
[(x << to_check) - x for x in range(2**bit_width)]
|
||||
if orientation == "left"
|
||||
else [x - (x >> to_check) for x in range(2**bit_width)]
|
||||
),
|
||||
)
|
||||
|
||||
chunks = []
|
||||
for offset in range(0, original_bit_width, chunk_size):
|
||||
bits_to_process = min(chunk_size, original_bit_width - offset)
|
||||
right_shift_by = original_bit_width - offset - bits_to_process
|
||||
mask = (2**bits_to_process) - 1
|
||||
|
||||
chunk_x = self._create_tlu(
|
||||
resulting_type,
|
||||
shifted_x,
|
||||
[(((x >> right_shift_by) & mask) << 1) for x in range(2**bit_width)],
|
||||
)
|
||||
packed_chunk_x_and_should_shift = self._create_add(
|
||||
resulting_type, chunk_x, should_shift
|
||||
)
|
||||
|
||||
chunk = self._create_tlu(
|
||||
resulting_type,
|
||||
packed_chunk_x_and_should_shift,
|
||||
[
|
||||
(x << right_shift_by) if b else 0
|
||||
for x in range(2**chunk_size)
|
||||
for b in [0, 1]
|
||||
],
|
||||
)
|
||||
chunks.append(chunk)
|
||||
|
||||
difference = chunks[0]
|
||||
for chunk in chunks[1:]:
|
||||
difference = self._create_add(resulting_type, difference, chunk)
|
||||
|
||||
x = (
|
||||
self._create_add(resulting_type, difference, x)
|
||||
if orientation == "left"
|
||||
else self._create_sub(resulting_type, x, difference)
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
def _pack_to_chunk_groups_and_map(
|
||||
self,
|
||||
resulting_type: Type,
|
||||
bit_width: int,
|
||||
chunk_size: int,
|
||||
x: OpResult,
|
||||
y: OpResult,
|
||||
mapper: Callable,
|
||||
x_offset: int = 0,
|
||||
y_offset: int = 0,
|
||||
) -> List[OpResult]:
|
||||
"""
|
||||
Split x and y into chunks, pack the chunks and map it to another integer.
|
||||
|
||||
Split x and y into chunks of size `chunk_size`.
|
||||
Pack those chunks into an integer and apply `mapper` function to it.
|
||||
Combine those results into a list and return it.
|
||||
|
||||
If `x_offset` (resp. `y_offset`) is provided, execute the function
|
||||
for `x + offset_x_by` (resp. `y + offset_y_by`) instead of `x` and `y`.
|
||||
"""
|
||||
|
||||
result = []
|
||||
for chunk_index, offset in enumerate(range(0, bit_width, chunk_size)):
|
||||
bits_to_process = min(chunk_size, bit_width - offset)
|
||||
right_shift_by = bit_width - offset - bits_to_process
|
||||
mask = (2**bits_to_process) - 1
|
||||
|
||||
chunk_x = self._create_tlu(
|
||||
resulting_type,
|
||||
x,
|
||||
[
|
||||
((((x + x_offset) >> right_shift_by) & mask) << bits_to_process)
|
||||
for x in range(2**bit_width)
|
||||
],
|
||||
)
|
||||
chunk_y = self._create_tlu(
|
||||
resulting_type,
|
||||
y,
|
||||
[((y + y_offset) >> right_shift_by) & mask for y in range(2**bit_width)],
|
||||
)
|
||||
|
||||
packed_chunks = self._create_add(resulting_type, chunk_x, chunk_y)
|
||||
mapped_chunks = self._create_tlu(
|
||||
resulting_type,
|
||||
packed_chunks,
|
||||
[mapper(chunk_index, x, y) for x in range(mask + 1) for y in range(mask + 1)],
|
||||
)
|
||||
|
||||
result.append(mapped_chunks)
|
||||
|
||||
return result
|
||||
|
||||
# pylint: enable=no-self-use,too-many-branches,too-many-locals,too-many-statements
|
||||
|
||||
@@ -363,6 +363,26 @@ class Node:
|
||||
True if the node is converted to a table lookup, False otherwise
|
||||
"""
|
||||
|
||||
if (
|
||||
all(value.is_encrypted for value in self.inputs)
|
||||
and self.operation == Operation.Generic
|
||||
and self.properties["name"]
|
||||
in [
|
||||
"bitwise_and",
|
||||
"bitwise_or",
|
||||
"bitwise_xor",
|
||||
"equal",
|
||||
"greater",
|
||||
"greater_equal",
|
||||
"left_shift",
|
||||
"less",
|
||||
"less_equal",
|
||||
"not_equal",
|
||||
"right_shift",
|
||||
]
|
||||
):
|
||||
return False
|
||||
|
||||
return self.operation == Operation.Generic and self.properties["name"] not in [
|
||||
"add",
|
||||
"array",
|
||||
|
||||
@@ -102,7 +102,7 @@ select = [
|
||||
"PLC", "PLE", "PLR", "PLW", "RUF"
|
||||
]
|
||||
ignore = [
|
||||
"A", "D", "FBT", "T20", "ANN", "N806", "ARG001", "S101", "BLE001", "RUF100",
|
||||
"A", "D", "FBT", "T20", "ANN", "N806", "ARG001", "S101", "BLE001", "RUF100", "ERA001",
|
||||
"RET504", "TID252", "PD011", "I001", "UP015", "C901", "A001", "SIM118", "PGH003"
|
||||
]
|
||||
|
||||
|
||||
@@ -244,14 +244,24 @@ class Helpers:
|
||||
if not isinstance(sample, list):
|
||||
sample = [sample]
|
||||
|
||||
for i in range(retries):
|
||||
expected = function(*sample)
|
||||
actual = circuit.encrypt_run_decrypt(*sample)
|
||||
def sanitize(values):
|
||||
if not isinstance(values, tuple):
|
||||
values = (values,)
|
||||
|
||||
if not isinstance(expected, tuple):
|
||||
expected = (expected,)
|
||||
if not isinstance(actual, tuple):
|
||||
actual = (actual,)
|
||||
result = []
|
||||
for value in values:
|
||||
if isinstance(value, (bool, np.bool_)):
|
||||
value = int(value)
|
||||
elif isinstance(value, np.ndarray) and value.dtype == np.bool_:
|
||||
value = value.astype(np.int64)
|
||||
|
||||
result.append(value)
|
||||
|
||||
return tuple(result)
|
||||
|
||||
for i in range(retries):
|
||||
expected = sanitize(function(*sample))
|
||||
actual = sanitize(circuit.encrypt_run_decrypt(*sample))
|
||||
|
||||
if all(np.array_equal(e, a) for e, a in zip(expected, actual)):
|
||||
break
|
||||
|
||||
62
tests/execution/test_bitwise.py
Normal file
62
tests/execution/test_bitwise.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Tests of execution of bitwise operations.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
import concrete.numpy as cnp
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: x & y,
|
||||
id="x & y",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x | y,
|
||||
id="x | y",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x ^ y,
|
||||
id="x ^ y",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"parameters",
|
||||
[
|
||||
{
|
||||
"x": {"range": [0, 255], "status": "encrypted"},
|
||||
"y": {"range": [0, 255], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted"},
|
||||
"y": {"range": [0, 7], "status": "encrypted", "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted", "shape": (3,)},
|
||||
"y": {"range": [0, 7], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 7], "status": "encrypted", "shape": (3,)},
|
||||
"y": {"range": [0, 7], "status": "encrypted", "shape": (3,)},
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_bitwise(function, parameters, helpers):
|
||||
"""
|
||||
Test bitwise operations between encrypted integers.
|
||||
"""
|
||||
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
161
tests/execution/test_comparison.py
Normal file
161
tests/execution/test_comparison.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
Tests of execution of comparison operations.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
import concrete.numpy as cnp
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: x == y,
|
||||
id="x == y",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x != y,
|
||||
id="x != y",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x < y,
|
||||
id="x < y",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x <= y,
|
||||
id="x <= y",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x > y,
|
||||
id="x > y",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x >= y,
|
||||
id="x >= y",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"parameters",
|
||||
[
|
||||
{
|
||||
"x": {"range": [0, 3], "status": "encrypted"},
|
||||
"y": {"range": [0, 3], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 255], "status": "encrypted"},
|
||||
"y": {"range": [0, 255], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [-128, 127], "status": "encrypted"},
|
||||
"y": {"range": [-128, 127], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [-128, 127], "status": "encrypted"},
|
||||
"y": {"range": [0, 255], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 255], "status": "encrypted"},
|
||||
"y": {"range": [-128, 127], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [-8, 7], "status": "encrypted"},
|
||||
"y": {"range": [-8, 7], "status": "encrypted", "shape": (2,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [-8, 7], "status": "encrypted", "shape": (2,)},
|
||||
"y": {"range": [-8, 7], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [-8, 7], "status": "encrypted", "shape": (2,)},
|
||||
"y": {"range": [-8, 7], "status": "encrypted", "shape": (2,)},
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_comparison(function, parameters, helpers):
|
||||
"""
|
||||
Test comparison operations between encrypted integers.
|
||||
"""
|
||||
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: (x == y) + 200,
|
||||
id="(x == y) + 200",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x != y) + 200,
|
||||
id="(x != y) + 200",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x < y) + 200,
|
||||
id="(x < y) + 200",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x <= y) + 200,
|
||||
id="(x <= y) + 200",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x > y) + 200,
|
||||
id="(x > y) + 200",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (x >= y) + 200,
|
||||
id="(x >= y) + 200",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"parameters",
|
||||
[
|
||||
{
|
||||
"x": {"range": [0, 15], "status": "encrypted"},
|
||||
"y": {"range": [0, 15], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [-8, 7], "status": "encrypted"},
|
||||
"y": {"range": [-8, 7], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 15], "status": "encrypted"},
|
||||
"y": {"range": [0, 15], "status": "encrypted", "shape": (2,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [-8, 7], "status": "encrypted", "shape": (2,)},
|
||||
"y": {"range": [-8, 7], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [-10, 10], "status": "encrypted", "shape": (2,)},
|
||||
"y": {"range": [-10, 10], "status": "encrypted", "shape": (2,)},
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_optimized_comparison(function, parameters, helpers):
|
||||
"""
|
||||
Test comparison operations between encrypted integers with a single TLU.
|
||||
"""
|
||||
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
178
tests/execution/test_shift.py
Normal file
178
tests/execution/test_shift.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
Tests of execution of shift operations.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
import concrete.numpy as cnp
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: x << y,
|
||||
id="x << y",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"parameters",
|
||||
[
|
||||
{
|
||||
"x": {"range": [0, 1], "status": "encrypted"},
|
||||
"y": {"range": [0, 7], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 3], "status": "encrypted"},
|
||||
"y": {"range": [0, 3], "status": "encrypted", "shape": (2,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 3], "status": "encrypted", "shape": (2,)},
|
||||
"y": {"range": [0, 3], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 3], "status": "encrypted", "shape": (2,)},
|
||||
"y": {"range": [0, 3], "status": "encrypted", "shape": (2,)},
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_left_shift(function, parameters, helpers):
|
||||
"""
|
||||
Test left shift between encrypted integers.
|
||||
"""
|
||||
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: x >> y,
|
||||
id="x >> y",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"parameters",
|
||||
[
|
||||
{
|
||||
"x": {"range": [0, 1 << 7], "status": "encrypted"},
|
||||
"y": {"range": [0, 7], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 1 << 4], "status": "encrypted"},
|
||||
"y": {"range": [0, 3], "status": "encrypted", "shape": (2,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 1 << 4], "status": "encrypted", "shape": (2,)},
|
||||
"y": {"range": [0, 3], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 1 << 4], "status": "encrypted", "shape": (2,)},
|
||||
"y": {"range": [0, 3], "status": "encrypted", "shape": (2,)},
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_right_shift(function, parameters, helpers):
|
||||
"""
|
||||
Test right shift between encrypted integers.
|
||||
"""
|
||||
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: x << y,
|
||||
id="x << y",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"parameters",
|
||||
[
|
||||
{
|
||||
"x": {"range": [0, 1], "status": "encrypted"},
|
||||
"y": {"range": [0, 7], "status": "encrypted"},
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_left_shift_coverage(function, parameters, helpers):
|
||||
"""
|
||||
Test left shift between encrypted integers all cases.
|
||||
"""
|
||||
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
for i in range(2):
|
||||
for j in range(8):
|
||||
helpers.check_execution(circuit, function, [i, j])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: x >> y,
|
||||
id="x >> y",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"parameters",
|
||||
[
|
||||
{
|
||||
"x": {"range": [0, 1 << 7], "status": "encrypted"},
|
||||
"y": {"range": [0, 7], "status": "encrypted"},
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_right_shift_coverage(function, parameters, helpers):
|
||||
"""
|
||||
Test right shift between encrypted integers all cases.
|
||||
"""
|
||||
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
helpers.check_execution(circuit, function, [0b11, 0])
|
||||
helpers.check_execution(circuit, function, [0b11, 1])
|
||||
helpers.check_execution(circuit, function, [0b110, 2])
|
||||
helpers.check_execution(circuit, function, [0b1100, 3])
|
||||
helpers.check_execution(circuit, function, [0b11000, 4])
|
||||
helpers.check_execution(circuit, function, [0b110000, 5])
|
||||
helpers.check_execution(circuit, function, [0b110000, 6])
|
||||
helpers.check_execution(circuit, function, [0b1100000, 7])
|
||||
@@ -400,6 +400,40 @@ Subgraphs:
|
||||
%5 = astype(%4, dtype=int_) # EncryptedScalar<uint1>
|
||||
return %5
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x << y,
|
||||
{"x": "encrypted", "y": "encrypted"},
|
||||
[(-1, 1), (-2, 3)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedScalar<int2> ∈ [-2, -1]
|
||||
%1 = y # EncryptedScalar<uint2> ∈ [1, 3]
|
||||
%2 = left_shift(%0, %1) # EncryptedScalar<int5> ∈ [-16, -2]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned bitwise operations are supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x << y,
|
||||
{"x": "encrypted", "y": "encrypted"},
|
||||
[(1, 20), (2, 10)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedScalar<uint2> ∈ [1, 2]
|
||||
%1 = y # EncryptedScalar<uint5> ∈ [10, 20]
|
||||
%2 = left_shift(%0, %1) # EncryptedScalar<uint21> ∈ [2048, 1048576]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only up to 4-bit shifts are supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
],
|
||||
|
||||
Reference in New Issue
Block a user