feat: implement bitwise and comparison operators

This commit is contained in:
poechsel
2023-02-08 13:00:24 +01:00
committed by Umut
parent f3affae84a
commit e126a11fcb
9 changed files with 1327 additions and 19 deletions

View File

@@ -96,6 +96,21 @@ class GraphConverter:
if not inputs[0].is_encrypted:
return "only assignment to encrypted tensors are supported"
elif name in ["bitwise_and", "bitwise_or", "bitwise_xor", "left_shift", "right_shift"]:
assert_that(len(inputs) == 2)
if all(value.is_encrypted for value in node.inputs):
pred_nodes = graph.ordered_preds_of(node)
if (
name in ["left_shift", "right_shift"]
and cast(Integer, pred_nodes[1].output.dtype).bit_width > 4
):
return "only up to 4-bit shifts are supported"
for pred_node in pred_nodes:
assert isinstance(pred_node.output.dtype, Integer)
if pred_node.output.dtype.is_signed:
return "only unsigned bitwise operations are supported"
elif name == "broadcast_to":
assert_that(len(inputs) == 1)
if not inputs[0].is_encrypted:
@@ -115,6 +130,9 @@ class GraphConverter:
if inputs[0].is_encrypted and inputs[1].is_encrypted:
return "only dot product between encrypted and clear is supported"
elif name in ["equal", "greater", "greater_equal", "less", "less_equal", "not_equal"]:
assert_that(len(inputs) == 2)
elif name == "expand_dims":
assert_that(len(inputs) == 1)
@@ -251,6 +269,20 @@ class GraphConverter:
current_node_bit_width = (
dtype.bit_width - 1 if node.output.is_clear else dtype.bit_width
)
if (
all(value.is_encrypted for value in node.inputs)
and node.operation == Operation.Generic
and node.properties["name"]
in [
"greater",
"greater_equal",
"less",
"less_equal",
]
):
# implementation of these operators require at least 4 bits
current_node_bit_width = max(current_node_bit_width, 4)
if max_bit_width < current_node_bit_width:
max_bit_width = current_node_bit_width
max_bit_width_node = node
@@ -286,7 +318,10 @@ class GraphConverter:
+ graph.format(highlighted_nodes=offending_nodes)
)
for node in graph.graph.nodes:
for node in nx.topological_sort(graph.graph):
assert isinstance(node.output.dtype, Integer)
node.properties["original_bit_width"] = node.output.dtype.bit_width
for value in node.inputs + [node.output]:
dtype = value.dtype
assert_that(isinstance(dtype, Integer))
@@ -335,9 +370,14 @@ class GraphConverter:
variable_input_bit_width = variable_input_dtype.bit_width
offset_constant_dtype = SignedInteger(variable_input_bit_width + 1)
offset_constant = Node.constant(abs(variable_input_dtype.min()))
offset_constant_value = abs(variable_input_dtype.min())
offset_constant = Node.constant(offset_constant_value)
offset_constant.output.dtype = offset_constant_dtype
original_bit_width = Integer.that_can_represent(offset_constant_value).bit_width
offset_constant.properties["original_bit_width"] = original_bit_width
add_offset = Node.generic(
"add",
[variable_input_value, ClearScalar(offset_constant_dtype)],
@@ -345,6 +385,9 @@ class GraphConverter:
np.add,
)
original_bit_width = variable_input_node.properties["original_bit_width"]
add_offset.properties["original_bit_width"] = original_bit_width
nx_graph.remove_edge(variable_input_node, node)
nx_graph.add_edge(variable_input_node, add_offset, input_idx=0)
@@ -412,6 +455,10 @@ class GraphConverter:
kwargs={"newshape": required_value_shape},
)
modified_pred.properties["original_bit_width"] = pred_to_modify.properties[
"original_bit_width"
]
nx_graph.add_edge(pred_to_modify, modified_pred, input_idx=0)
nx_graph.remove_edge(pred_to_modify, node)
@@ -448,6 +495,9 @@ class GraphConverter:
lambda: np.zeros((), dtype=np.int64),
)
original_bit_width = 1
zero.properties["original_bit_width"] = original_bit_width
new_assigned_pred = Node.generic(
"add",
[assigned_pred.output, zero.output],
@@ -455,6 +505,9 @@ class GraphConverter:
np.add,
)
original_bit_width = assigned_pred.properties["original_bit_width"]
new_assigned_pred.properties["original_bit_width"] = original_bit_width
nx_graph.remove_edge(preds[1], node)
nx_graph.add_edge(preds[1], new_assigned_pred, input_idx=0)
@@ -473,7 +526,24 @@ class GraphConverter:
"""
# pylint: disable=invalid-name
OPS_TO_TENSORIZE = ["add", "broadcast_to", "dot", "multiply", "subtract"]
OPS_TO_TENSORIZE = [
"add",
"bitwise_and",
"bitwise_or",
"bitwise_xor",
"broadcast_to",
"dot",
"equal",
"greater",
"greater_equal",
"left_shift",
"less",
"less_equal",
"multiply",
"not_equal",
"right_shift",
"subtract",
]
# pylint: enable=invalid-name
tensorized_scalars: Dict[Node, Node] = {}
@@ -490,6 +560,11 @@ class GraphConverter:
if not node.inputs[0].is_scalar:
continue
# for bitwise and comparison operators that can have constants
# we don't need broadcasting here
if node.converted_to_table_lookup:
continue
pred_to_tensorize: Optional[Node] = None
pred_to_tensorize_index = 0
@@ -513,8 +588,14 @@ class GraphConverter:
tensorized_value,
lambda *args: np.array(args),
)
nx_graph.add_edge(pred_to_tensorize, tensorized_pred, input_idx=0)
original_bit_width = pred_to_tensorize.properties["original_bit_width"]
tensorized_pred.properties["original_bit_width"] = original_bit_width
original_shape = ()
tensorized_pred.properties["original_shape"] = original_shape
nx_graph.add_edge(pred_to_tensorize, tensorized_pred, input_idx=0)
tensorized_scalars[pred_to_tensorize] = tensorized_pred
assert tensorized_pred is not None
@@ -522,6 +603,10 @@ class GraphConverter:
nx_graph.remove_edge(pred_to_tensorize, node)
nx_graph.add_edge(tensorized_pred, node, input_idx=pred_to_tensorize_index)
new_input_value = deepcopy(node.inputs[pred_to_tensorize_index])
new_input_value.shape = (1,)
node.inputs[pred_to_tensorize_index] = new_input_value
@staticmethod
def _sanitize_signed_inputs(graph: Graph, args: List[Any], ctx: Context) -> List[Any]:
"""

View File

@@ -4,7 +4,9 @@ Declaration of `NodeConverter` class.
# pylint: disable=no-member,no-name-in-module,too-many-lines
from typing import Dict, List, Tuple
import re
from enum import IntEnum
from typing import Callable, Dict, List, Set, Tuple, cast
import numpy as np
from concrete.lang.dialects import fhe, fhelinalg
@@ -34,6 +36,20 @@ from .utils import construct_deduplicated_tables
# pylint: enable=no-member,no-name-in-module
class Comparison(IntEnum):
"""
Comparison enum, to generalize conversion of comparison operators.
Because comparison result has 3 possibilities, Comparison enum is 2 bits.
"""
EQUAL = 0b00
LESS = 0b01
GREATER = 0b10
UNUSED = 0b11
class NodeConverter:
"""
NodeConverter class, to convert computation graph nodes to their MLIR equivalent.
@@ -125,11 +141,12 @@ class NodeConverter:
self.all_of_the_inputs_are_tensors = True
self.one_of_the_inputs_is_a_tensor = False
for inp in node.inputs:
if not inp.is_encrypted:
for pred in graph.ordered_preds_of(node):
if not pred.output.is_encrypted:
self.all_of_the_inputs_are_encrypted = False
if inp.is_scalar:
shape = pred.properties.get("original_shape", pred.output.shape)
if shape == ():
self.all_of_the_inputs_are_tensors = False
else:
self.one_of_the_inputs_is_a_tensor = True
@@ -156,20 +173,31 @@ class NodeConverter:
"add": self._convert_add,
"array": self._convert_array,
"assign.static": self._convert_static_assignment,
"bitwise_and": self._convert_bitwise_and,
"bitwise_or": self._convert_bitwise_or,
"bitwise_xor": self._convert_bitwise_xor,
"broadcast_to": self._convert_broadcast_to,
"concatenate": self._convert_concat,
"conv1d": self._convert_conv1d,
"conv2d": self._convert_conv2d,
"conv3d": self._convert_conv3d,
"dot": self._convert_dot,
"equal": self._convert_equal,
"expand_dims": self._convert_reshape,
"greater": self._convert_greater,
"greater_equal": self._convert_greater_equal,
"index.static": self._convert_static_indexing,
"left_shift": self._convert_left_shift,
"less": self._convert_less,
"less_equal": self._convert_less_equal,
"matmul": self._convert_matmul,
"maxpool": self._convert_maxpool,
"multiply": self._convert_mul,
"negative": self._convert_neg,
"not_equal": self._convert_not_equal,
"ones": self._convert_ones,
"reshape": self._convert_reshape,
"right_shift": self._convert_right_shift,
"squeeze": self._convert_squeeze,
"subtract": self._convert_sub,
"sum": self._convert_sum,
@@ -183,7 +211,7 @@ class NodeConverter:
assert_that(self.node.converted_to_table_lookup)
return self._convert_tlu()
# pylint: disable=no-self-use
# pylint: disable=no-self-use,too-many-branches,too-many-locals,too-many-statements
def _convert_add(self) -> OpResult:
"""
@@ -277,6 +305,39 @@ class NodeConverter:
self.from_elements_operations[placeholder_result] = processed_preds
return placeholder_result
def _convert_bitwise_and(self) -> OpResult:
"""
Convert "bitwise_and" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
return self._convert_bitwise(lambda x, y: x & y)
def _convert_bitwise_or(self) -> OpResult:
"""
Convert "bitwise_or" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
return self._convert_bitwise(lambda x, y: x | y)
def _convert_bitwise_xor(self) -> OpResult:
"""
Convert "bitwise_xor" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
return self._convert_bitwise(lambda x, y: x ^ y)
def _convert_concat(self) -> OpResult:
"""
Convert "concatenate" node to its corresponding MLIR representation.
@@ -462,6 +523,76 @@ class NodeConverter:
return result
def _convert_equal(self) -> OpResult:
"""
Convert "equal" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
return self._convert_equality(equals=True)
def _convert_greater(self) -> OpResult:
"""
Convert "greater" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
return self._convert_compare(invert_operands=True, accept={Comparison.LESS})
def _convert_greater_equal(self) -> OpResult:
"""
Convert "greater_equal" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
return self._convert_compare(
invert_operands=True, accept={Comparison.LESS, Comparison.EQUAL}
)
def _convert_left_shift(self) -> OpResult:
"""
Convert "left_shift" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
return self._convert_shift(orientation="left")
def _convert_less(self) -> OpResult:
"""
Convert "less" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
return self._convert_compare(invert_operands=False, accept={Comparison.LESS})
def _convert_less_equal(self) -> OpResult:
"""
Convert "less_equal" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
return self._convert_compare(
invert_operands=False, accept={Comparison.LESS, Comparison.EQUAL}
)
def _convert_matmul(self) -> OpResult:
"""Convert a MatMul node to its corresponding MLIR representation.
@@ -537,6 +668,17 @@ class NodeConverter:
return result
def _convert_not_equal(self) -> OpResult:
"""
Convert "not_equal" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
return self._convert_equality(equals=False)
def _convert_ones(self) -> OpResult:
"""
Convert "ones" node to its corresponding MLIR representation.
@@ -699,6 +841,17 @@ class NodeConverter:
),
).result
def _convert_right_shift(self) -> OpResult:
"""
Convert "right_shift" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
return self._convert_shift(orientation="right")
def _convert_static_assignment(self) -> OpResult:
"""
Convert "assign.static" node to its corresponding MLIR representation.
@@ -1094,7 +1247,19 @@ class NodeConverter:
return result
def _create_constant(self, mlir_type: Type, mlir_attribute: Attribute):
def _create_add(self, resulting_type, a, b) -> OpResult:
if self.one_of_the_inputs_is_a_tensor:
return fhelinalg.AddEintOp(resulting_type, a, b).result
return fhe.AddEintOp(resulting_type, a, b).result
def _create_add_clear(self, resulting_type, a, b) -> OpResult:
if self.one_of_the_inputs_is_a_tensor:
return fhelinalg.AddEintIntOp(resulting_type, a, b).result
return fhe.AddEintIntOp(resulting_type, a, b).result
def _create_constant(self, mlir_type: Type, mlir_attribute: Attribute) -> OpResult:
result = self.constant_cache.get((mlir_type, mlir_attribute))
if result is None:
# ConstantOp is being decorated, and the init function is supposed to take more
@@ -1105,4 +1270,597 @@ class NodeConverter:
self.constant_cache[(mlir_type, mlir_attribute)] = result
return result
# pylint: enable=no-self-use
def _create_constant_integer(self, bit_width, x) -> OpResult:
constant_value = Value(
Integer(is_signed=True, bit_width=bit_width + 1),
shape=(1,) if self.one_of_the_inputs_is_a_tensor else (),
is_encrypted=False,
)
constant_type = NodeConverter.value_to_mlir_type(self.ctx, constant_value)
if self.one_of_the_inputs_is_a_tensor:
return self._create_constant(
constant_type,
Attribute.parse(f"dense<{[int(x)]}> : {constant_type}"),
)
return self._create_constant(constant_type, IntegerAttr.get(constant_type, x))
def _create_mul_clear(self, resulting_type, a, b) -> OpResult:
if self.one_of_the_inputs_is_a_tensor:
return fhelinalg.MulEintIntOp(resulting_type, a, b).result
return fhe.MulEintIntOp(resulting_type, a, b).result
def _create_sub(self, resulting_type, a, b) -> OpResult:
if self.one_of_the_inputs_is_a_tensor:
return fhelinalg.SubEintOp(resulting_type, a, b).result
return fhe.SubEintOp(resulting_type, a, b).result
def _create_tlu(self, resulting_type, pred, lut_values) -> OpResult:
resulting_type_str = str(resulting_type)
bit_width_search = re.search(r"FHE\.eint<([0-9]+)>", resulting_type_str)
assert bit_width_search is not None
bit_width_str = bit_width_search.group(1)
bit_width = int(bit_width_str)
lut_values += [0] * ((2**bit_width) - len(lut_values))
lut_element_type = IntegerType.get_signless(64, context=self.ctx)
lut_type = RankedTensorType.get((len(lut_values),), lut_element_type)
lut_attr = Attribute.parse(f"dense<{str(lut_values)}> : {lut_type}")
lut = self._create_constant(resulting_type, lut_attr).result
if self.one_of_the_inputs_is_a_tensor:
return fhelinalg.ApplyLookupTableEintOp(resulting_type, pred, lut).result
return fhe.ApplyLookupTableEintOp(resulting_type, pred, lut).result
def _convert_bitwise(self, operation: Callable) -> OpResult:
"""
Convert `bitwise_{and,or,xor}` node to their corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
if not self.all_of_the_inputs_are_encrypted:
return self._convert_tlu()
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
x = self.preds[0]
y = self.preds[1]
x_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.inputs[0])
y_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.inputs[1])
assert isinstance(self.node.output.dtype, Integer)
bit_width = self.node.output.dtype.bit_width
chunk_size = int(np.floor(bit_width / 2))
mask = (2**chunk_size) - 1
original_bit_width = max(
pred_node.properties["original_bit_width"]
for pred_node in self.graph.ordered_preds_of(self.node)
)
chunks = []
for offset in range(0, original_bit_width, chunk_size):
x_lut = [((x >> offset) & mask) << chunk_size for x in range(2**bit_width)]
y_lut = [(y >> offset) & mask for y in range(2**bit_width)]
x_chunk = self._create_tlu(x_type, x, x_lut)
y_chunk = self._create_tlu(y_type, y, y_lut)
packed_x_and_y_chunks = self._create_add(resulting_type, x_chunk, y_chunk)
result_chunk = self._create_tlu(
resulting_type,
packed_x_and_y_chunks,
[
operation(x, y) << offset
for x in range(2**chunk_size)
for y in range(2**chunk_size)
],
)
chunks.append(result_chunk)
# add all chunks together in a tree to maximize dataflow parallelization
# c1 c2 c3 c4
# \/ \/
# s2 s1
# \ /
# \ /
# s3
while len(chunks) > 1:
a = chunks.pop()
b = chunks.pop()
result = self._create_add(resulting_type, a, b)
chunks.insert(0, result)
return chunks[0]
def _convert_compare(self, accept: Set[Comparison], invert_operands: bool = False) -> OpResult:
if not self.all_of_the_inputs_are_encrypted:
return self._convert_tlu()
inputs = self.node.inputs
preds = self.preds
if invert_operands:
inputs = inputs[::-1]
preds = preds[::-1]
x = preds[0]
y = preds[1]
x_is_signed = cast(Integer, inputs[0].dtype).is_signed
y_is_signed = cast(Integer, inputs[1].dtype).is_signed
x_is_unsigned = not x_is_signed
y_is_unsigned = not y_is_signed
pred_nodes = self.graph.ordered_preds_of(self.node)
if invert_operands:
pred_nodes = list(reversed(pred_nodes))
x_bit_width, y_bit_width = [node.properties["original_bit_width"] for node in pred_nodes]
comparison_bit_width = max(x_bit_width, y_bit_width)
x_dtype = Integer(is_signed=x_is_signed, bit_width=x_bit_width)
y_dtype = Integer(is_signed=y_is_signed, bit_width=y_bit_width)
x_minus_y_min = x_dtype.min() - y_dtype.max()
x_minus_y_max = x_dtype.max() - y_dtype.min()
x_minus_y_range = [x_minus_y_min, x_minus_y_max]
x_minus_y_dtype = Integer.that_can_represent(x_minus_y_range)
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
bit_width = cast(Integer, self.node.output.dtype).bit_width
signed_offset = 2 ** (bit_width - 1)
sanitizer = self._create_constant_integer(bit_width, signed_offset)
if x_minus_y_dtype.bit_width <= bit_width:
x_minus_y = self._create_sub(resulting_type, x, y)
sanitized_x_minus_y = self._create_add_clear(resulting_type, x_minus_y, sanitizer)
accept_less = int(Comparison.LESS in accept)
accept_equal = int(Comparison.EQUAL in accept)
accept_greater = int(Comparison.GREATER in accept)
all_cells = 2**bit_width
less_cells = 2 ** (bit_width - 1)
equal_cells = 1
greater_cells = less_cells - equal_cells
assert less_cells + equal_cells + greater_cells == all_cells
table = (
[accept_less] * less_cells
+ [accept_equal] * equal_cells
+ [accept_greater] * greater_cells
)
return self._create_tlu(resulting_type, sanitized_x_minus_y, table)
# Comparison between signed and unsigned is tricky.
# To deal with them, we add -min of the signed number to both operands
# such that they are both positive. To avoid overflowing
# the unsigned operand this addition is done "virtually"
# while constructing one of the luts.
# A flag ("is_unsigned_greater_than_half") is emitted in MLIR to keep track
# if the unsigned operand was greater than the max signed number as it
# is needed to determine the result of the comparison.
# Exemple: to compare x and y where x is an int3 and y and uint3, when y
# is greater than 4 we are sure than x will be less than x.
if inputs[0].shape != self.node.output.shape:
x = self._create_add(resulting_type, x, self._convert_zeros())
if inputs[1].shape != self.node.output.shape:
y = self._create_add(resulting_type, y, self._convert_zeros())
offset_x_by = 0
offset_y_by = 0
if x_is_signed or y_is_signed:
if x_is_signed:
x = self._create_add_clear(resulting_type, x, sanitizer)
else:
offset_x_by = signed_offset
if y_is_signed:
y = self._create_add_clear(resulting_type, y, sanitizer)
else:
offset_y_by = signed_offset
def compare(x, y):
if x < y:
return Comparison.LESS
if x > y:
return Comparison.GREATER
return Comparison.EQUAL
chunk_size = int(np.floor(bit_width / 2))
carries = self._pack_to_chunk_groups_and_map(
resulting_type,
comparison_bit_width,
chunk_size,
x,
y,
lambda i, x, y: compare(x, y) << (min(i, 1) * 2),
x_offset=offset_x_by,
y_offset=offset_y_by,
)
# This is the reduction step -- we have an array where the entry i is the
# result of comparing the chunks of x and y at position i.
all_comparisons = [Comparison.EQUAL, Comparison.LESS, Comparison.GREATER, Comparison.UNUSED]
pick_first_not_equal_lut = [
int(
current_comparison
if previous_comparison == Comparison.EQUAL
else previous_comparison
)
for current_comparison in all_comparisons
for previous_comparison in all_comparisons
]
carry = carries[0]
for next_carry in carries[1:]:
combined_carries = self._create_add(resulting_type, next_carry, carry)
carry = self._create_tlu(resulting_type, combined_carries, pick_first_not_equal_lut)
if x_is_signed != y_is_signed:
carry_bit_width = 2
is_less_mask = int(Comparison.LESS)
is_unsigned_greater_than_half = self._create_tlu(
resulting_type,
x if x_is_unsigned else y,
[int(value >= signed_offset) << carry_bit_width for value in range(2**bit_width)],
)
packed_carry_and_is_unsigned_greater_than_half = self._create_add(
resulting_type,
is_unsigned_greater_than_half,
carry,
)
# this function is actually converting either
# - lhs < rhs
# - lhs <= rhs
# in the implementation, we call
# - x = lhs
# - y = rhs
# so if y is unsigned and greater than half
# - y is definitely bigger than x
# - is_unsigned_greater_than_half == 1
# - result == (lhs < rhs) == (x < y) == 1
# so if x is unsigned and greater than half
# - x is definitely bigger than y
# - is_unsigned_greater_than_half == 1
# - result == (lhs < rhs) == (x < y) == 0
if y_is_unsigned:
result_table = [
1 if (i >> carry_bit_width) else (i & is_less_mask) for i in range(2**3)
]
else:
result_table = [
0 if (i >> carry_bit_width) else (i & is_less_mask) for i in range(2**3)
]
result = self._create_tlu(
resulting_type,
packed_carry_and_is_unsigned_greater_than_half,
result_table,
)
else:
boolean_result_lut = [int(comparison in accept) for comparison in all_comparisons]
result = self._create_tlu(resulting_type, carry, boolean_result_lut)
return result
def _convert_equality(self, equals: bool) -> OpResult:
if not self.all_of_the_inputs_are_encrypted:
return self._convert_tlu()
x = self.preds[0]
y = self.preds[1]
x_is_signed = cast(Integer, self.node.inputs[0].dtype).is_signed
y_is_signed = cast(Integer, self.node.inputs[1].dtype).is_signed
x_is_unsigned = not x_is_signed
y_is_unsigned = not y_is_signed
x_bit_width, y_bit_width = [
node.properties["original_bit_width"] for node in self.graph.ordered_preds_of(self.node)
]
comparison_bit_width = max(x_bit_width, y_bit_width)
x_dtype = Integer(is_signed=x_is_signed, bit_width=x_bit_width)
y_dtype = Integer(is_signed=y_is_signed, bit_width=y_bit_width)
x_minus_y_min = x_dtype.min() - y_dtype.max()
x_minus_y_max = x_dtype.max() - y_dtype.min()
x_minus_y_range = [x_minus_y_min, x_minus_y_max]
x_minus_y_dtype = Integer.that_can_represent(x_minus_y_range)
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
bit_width = cast(Integer, self.node.output.dtype).bit_width
signed_offset = 2 ** (bit_width - 1)
sanitizer = self._create_constant_integer(bit_width, signed_offset)
if x_minus_y_dtype.bit_width <= bit_width:
x_minus_y = self._create_sub(resulting_type, x, y)
sanitized_x_minus_y = self._create_add_clear(resulting_type, x_minus_y, sanitizer)
zero_position = 2 ** (bit_width - 1)
if equals:
operation_lut = [int(i == zero_position) for i in range(2**bit_width)]
else:
operation_lut = [int(i != zero_position) for i in range(2**bit_width)]
return self._create_tlu(resulting_type, sanitized_x_minus_y, operation_lut)
chunk_size = int(np.floor(bit_width / 2))
number_of_chunks = int(np.ceil(bit_width / chunk_size))
if x_is_signed != y_is_signed:
number_of_chunks += 1
if self.node.inputs[0].shape != self.node.output.shape:
x = self._create_add(resulting_type, x, self._convert_zeros())
if self.node.inputs[1].shape != self.node.output.shape:
y = self._create_add(resulting_type, y, self._convert_zeros())
greater_than_half_lut = [
int(i >= signed_offset) << (number_of_chunks - 1) for i in range(2**bit_width)
]
if x_is_unsigned and y_is_signed:
is_unsigned_greater_than_half = self._create_tlu(
resulting_type,
x,
greater_than_half_lut,
)
elif x_is_signed and y_is_unsigned:
is_unsigned_greater_than_half = self._create_tlu(
resulting_type,
y,
greater_than_half_lut,
)
else:
is_unsigned_greater_than_half = None
offset_x_by = 0
offset_y_by = 0
if x_is_signed or y_is_signed:
if x_is_signed:
x = self._create_add_clear(resulting_type, x, sanitizer)
else:
offset_x_by = signed_offset
if y_is_signed:
y = self._create_add_clear(resulting_type, y, sanitizer)
else:
offset_y_by = signed_offset
carries = self._pack_to_chunk_groups_and_map(
resulting_type,
comparison_bit_width,
chunk_size,
x,
y,
lambda _, x, y: int(x != y),
x_offset=offset_x_by,
y_offset=offset_y_by,
)
if is_unsigned_greater_than_half:
carries.append(is_unsigned_greater_than_half)
while len(carries) > 1:
a = carries.pop()
b = carries.pop()
result = self._create_add(resulting_type, a, b)
carries.insert(0, result)
carry = carries[0]
return self._create_tlu(
resulting_type,
carry,
[int(i == 0 if equals else i != 0) for i in range(2**bit_width)],
)
def _convert_shift(self, orientation: str) -> OpResult:
if not self.all_of_the_inputs_are_encrypted:
return self._convert_tlu()
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
x = self.preds[0]
b = self.preds[1]
b_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.inputs[1])
assert isinstance(self.node.output.dtype, Integer)
bit_width = self.node.output.dtype.bit_width
pred_nodes = self.graph.ordered_preds_of(self.node)
x_original_bit_width = pred_nodes[0].properties["original_bit_width"]
b_original_bit_width = pred_nodes[1].properties["original_bit_width"]
if self.node.inputs[0].shape != self.node.output.shape:
x = self._create_add(resulting_type, x, self._convert_zeros())
if x_original_bit_width + b_original_bit_width <= bit_width:
shift_multiplier = self._create_constant_integer(bit_width, 2**b_original_bit_width)
shifted_x = self._create_mul_clear(resulting_type, x, shift_multiplier)
packed_x_and_b = self._create_add(resulting_type, shifted_x, b)
return self._create_tlu(
resulting_type,
packed_x_and_b,
[
(x << b) if orientation == "left" else (x >> b)
for x in range(2**x_original_bit_width)
for b in range(2**b_original_bit_width)
],
)
# Left_shifts of x << b can be done as follows:
# - left shift of x by 8 if b & 0b1000 > 0
# - left shift of x by 4 if b & 0b0100 > 0
# - left shift of x by 2 if b & 0b0010 > 0
# - left shift of x by 1 if b & 0b0001 > 0
# Encoding this condition is non-trivial -- however,
# it can be done using the following trick:
# y = (b & 0b1000 > 0) * ((x << 8) - x) + x
# When b & 0b1000, then:
# y = 1 * ((x << 8) - x) + x = (x << 8) - x + x = x << 8
# When b & 0b1000 == 0 then:
# y = 0 * ((x << 8) - x) + x = x
# The same trick can be used for right shift but with:
# y = x - (b & 0b1000 > 0) * (x - (x >> 8))
original_bit_width = self.node.properties["original_bit_width"]
chunk_size = min(original_bit_width, bit_width - 1)
for i in reversed(range(b_original_bit_width)):
to_check = 2**i
should_shift = self._create_tlu(
b_type,
b,
[int((b & to_check) > 0) for b in range(2**bit_width)],
)
shifted_x = self._create_tlu(
resulting_type,
x,
(
[(x << to_check) - x for x in range(2**bit_width)]
if orientation == "left"
else [x - (x >> to_check) for x in range(2**bit_width)]
),
)
chunks = []
for offset in range(0, original_bit_width, chunk_size):
bits_to_process = min(chunk_size, original_bit_width - offset)
right_shift_by = original_bit_width - offset - bits_to_process
mask = (2**bits_to_process) - 1
chunk_x = self._create_tlu(
resulting_type,
shifted_x,
[(((x >> right_shift_by) & mask) << 1) for x in range(2**bit_width)],
)
packed_chunk_x_and_should_shift = self._create_add(
resulting_type, chunk_x, should_shift
)
chunk = self._create_tlu(
resulting_type,
packed_chunk_x_and_should_shift,
[
(x << right_shift_by) if b else 0
for x in range(2**chunk_size)
for b in [0, 1]
],
)
chunks.append(chunk)
difference = chunks[0]
for chunk in chunks[1:]:
difference = self._create_add(resulting_type, difference, chunk)
x = (
self._create_add(resulting_type, difference, x)
if orientation == "left"
else self._create_sub(resulting_type, x, difference)
)
return x
def _pack_to_chunk_groups_and_map(
self,
resulting_type: Type,
bit_width: int,
chunk_size: int,
x: OpResult,
y: OpResult,
mapper: Callable,
x_offset: int = 0,
y_offset: int = 0,
) -> List[OpResult]:
"""
Split x and y into chunks, pack the chunks and map it to another integer.
Split x and y into chunks of size `chunk_size`.
Pack those chunks into an integer and apply `mapper` function to it.
Combine those results into a list and return it.
If `x_offset` (resp. `y_offset`) is provided, execute the function
for `x + offset_x_by` (resp. `y + offset_y_by`) instead of `x` and `y`.
"""
result = []
for chunk_index, offset in enumerate(range(0, bit_width, chunk_size)):
bits_to_process = min(chunk_size, bit_width - offset)
right_shift_by = bit_width - offset - bits_to_process
mask = (2**bits_to_process) - 1
chunk_x = self._create_tlu(
resulting_type,
x,
[
((((x + x_offset) >> right_shift_by) & mask) << bits_to_process)
for x in range(2**bit_width)
],
)
chunk_y = self._create_tlu(
resulting_type,
y,
[((y + y_offset) >> right_shift_by) & mask for y in range(2**bit_width)],
)
packed_chunks = self._create_add(resulting_type, chunk_x, chunk_y)
mapped_chunks = self._create_tlu(
resulting_type,
packed_chunks,
[mapper(chunk_index, x, y) for x in range(mask + 1) for y in range(mask + 1)],
)
result.append(mapped_chunks)
return result
# pylint: enable=no-self-use,too-many-branches,too-many-locals,too-many-statements

View File

@@ -363,6 +363,26 @@ class Node:
True if the node is converted to a table lookup, False otherwise
"""
if (
all(value.is_encrypted for value in self.inputs)
and self.operation == Operation.Generic
and self.properties["name"]
in [
"bitwise_and",
"bitwise_or",
"bitwise_xor",
"equal",
"greater",
"greater_equal",
"left_shift",
"less",
"less_equal",
"not_equal",
"right_shift",
]
):
return False
return self.operation == Operation.Generic and self.properties["name"] not in [
"add",
"array",

View File

@@ -102,7 +102,7 @@ select = [
"PLC", "PLE", "PLR", "PLW", "RUF"
]
ignore = [
"A", "D", "FBT", "T20", "ANN", "N806", "ARG001", "S101", "BLE001", "RUF100",
"A", "D", "FBT", "T20", "ANN", "N806", "ARG001", "S101", "BLE001", "RUF100", "ERA001",
"RET504", "TID252", "PD011", "I001", "UP015", "C901", "A001", "SIM118", "PGH003"
]

View File

@@ -244,14 +244,24 @@ class Helpers:
if not isinstance(sample, list):
sample = [sample]
for i in range(retries):
expected = function(*sample)
actual = circuit.encrypt_run_decrypt(*sample)
def sanitize(values):
if not isinstance(values, tuple):
values = (values,)
if not isinstance(expected, tuple):
expected = (expected,)
if not isinstance(actual, tuple):
actual = (actual,)
result = []
for value in values:
if isinstance(value, (bool, np.bool_)):
value = int(value)
elif isinstance(value, np.ndarray) and value.dtype == np.bool_:
value = value.astype(np.int64)
result.append(value)
return tuple(result)
for i in range(retries):
expected = sanitize(function(*sample))
actual = sanitize(circuit.encrypt_run_decrypt(*sample))
if all(np.array_equal(e, a) for e, a in zip(expected, actual)):
break

View File

@@ -0,0 +1,62 @@
"""
Tests of execution of bitwise operations.
"""
import pytest
import concrete.numpy as cnp
@pytest.mark.parametrize(
"function",
[
pytest.param(
lambda x, y: x & y,
id="x & y",
),
pytest.param(
lambda x, y: x | y,
id="x | y",
),
pytest.param(
lambda x, y: x ^ y,
id="x ^ y",
),
],
)
@pytest.mark.parametrize(
"parameters",
[
{
"x": {"range": [0, 255], "status": "encrypted"},
"y": {"range": [0, 255], "status": "encrypted"},
},
{
"x": {"range": [0, 7], "status": "encrypted"},
"y": {"range": [0, 7], "status": "encrypted", "shape": (3,)},
},
{
"x": {"range": [0, 7], "status": "encrypted", "shape": (3,)},
"y": {"range": [0, 7], "status": "encrypted"},
},
{
"x": {"range": [0, 7], "status": "encrypted", "shape": (3,)},
"y": {"range": [0, 7], "status": "encrypted", "shape": (3,)},
},
],
)
def test_bitwise(function, parameters, helpers):
"""
Test bitwise operations between encrypted integers.
"""
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = cnp.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset, configuration)
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample)

View File

@@ -0,0 +1,161 @@
"""
Tests of execution of comparison operations.
"""
import pytest
import concrete.numpy as cnp
@pytest.mark.parametrize(
"function",
[
pytest.param(
lambda x, y: x == y,
id="x == y",
),
pytest.param(
lambda x, y: x != y,
id="x != y",
),
pytest.param(
lambda x, y: x < y,
id="x < y",
),
pytest.param(
lambda x, y: x <= y,
id="x <= y",
),
pytest.param(
lambda x, y: x > y,
id="x > y",
),
pytest.param(
lambda x, y: x >= y,
id="x >= y",
),
],
)
@pytest.mark.parametrize(
"parameters",
[
{
"x": {"range": [0, 3], "status": "encrypted"},
"y": {"range": [0, 3], "status": "encrypted"},
},
{
"x": {"range": [0, 255], "status": "encrypted"},
"y": {"range": [0, 255], "status": "encrypted"},
},
{
"x": {"range": [-128, 127], "status": "encrypted"},
"y": {"range": [-128, 127], "status": "encrypted"},
},
{
"x": {"range": [-128, 127], "status": "encrypted"},
"y": {"range": [0, 255], "status": "encrypted"},
},
{
"x": {"range": [0, 255], "status": "encrypted"},
"y": {"range": [-128, 127], "status": "encrypted"},
},
{
"x": {"range": [-8, 7], "status": "encrypted"},
"y": {"range": [-8, 7], "status": "encrypted", "shape": (2,)},
},
{
"x": {"range": [-8, 7], "status": "encrypted", "shape": (2,)},
"y": {"range": [-8, 7], "status": "encrypted"},
},
{
"x": {"range": [-8, 7], "status": "encrypted", "shape": (2,)},
"y": {"range": [-8, 7], "status": "encrypted", "shape": (2,)},
},
],
)
def test_comparison(function, parameters, helpers):
"""
Test comparison operations between encrypted integers.
"""
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = cnp.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset, configuration)
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample)
@pytest.mark.parametrize(
"function",
[
pytest.param(
lambda x, y: (x == y) + 200,
id="(x == y) + 200",
),
pytest.param(
lambda x, y: (x != y) + 200,
id="(x != y) + 200",
),
pytest.param(
lambda x, y: (x < y) + 200,
id="(x < y) + 200",
),
pytest.param(
lambda x, y: (x <= y) + 200,
id="(x <= y) + 200",
),
pytest.param(
lambda x, y: (x > y) + 200,
id="(x > y) + 200",
),
pytest.param(
lambda x, y: (x >= y) + 200,
id="(x >= y) + 200",
),
],
)
@pytest.mark.parametrize(
"parameters",
[
{
"x": {"range": [0, 15], "status": "encrypted"},
"y": {"range": [0, 15], "status": "encrypted"},
},
{
"x": {"range": [-8, 7], "status": "encrypted"},
"y": {"range": [-8, 7], "status": "encrypted"},
},
{
"x": {"range": [0, 15], "status": "encrypted"},
"y": {"range": [0, 15], "status": "encrypted", "shape": (2,)},
},
{
"x": {"range": [-8, 7], "status": "encrypted", "shape": (2,)},
"y": {"range": [-8, 7], "status": "encrypted"},
},
{
"x": {"range": [-10, 10], "status": "encrypted", "shape": (2,)},
"y": {"range": [-10, 10], "status": "encrypted", "shape": (2,)},
},
],
)
def test_optimized_comparison(function, parameters, helpers):
"""
Test comparison operations between encrypted integers with a single TLU.
"""
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = cnp.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset, configuration)
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample)

View File

@@ -0,0 +1,178 @@
"""
Tests of execution of shift operations.
"""
import pytest
import concrete.numpy as cnp
@pytest.mark.parametrize(
"function",
[
pytest.param(
lambda x, y: x << y,
id="x << y",
),
],
)
@pytest.mark.parametrize(
"parameters",
[
{
"x": {"range": [0, 1], "status": "encrypted"},
"y": {"range": [0, 7], "status": "encrypted"},
},
{
"x": {"range": [0, 3], "status": "encrypted"},
"y": {"range": [0, 3], "status": "encrypted", "shape": (2,)},
},
{
"x": {"range": [0, 3], "status": "encrypted", "shape": (2,)},
"y": {"range": [0, 3], "status": "encrypted"},
},
{
"x": {"range": [0, 3], "status": "encrypted", "shape": (2,)},
"y": {"range": [0, 3], "status": "encrypted", "shape": (2,)},
},
],
)
def test_left_shift(function, parameters, helpers):
"""
Test left shift between encrypted integers.
"""
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = cnp.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset, configuration)
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample)
@pytest.mark.parametrize(
"function",
[
pytest.param(
lambda x, y: x >> y,
id="x >> y",
),
],
)
@pytest.mark.parametrize(
"parameters",
[
{
"x": {"range": [0, 1 << 7], "status": "encrypted"},
"y": {"range": [0, 7], "status": "encrypted"},
},
{
"x": {"range": [0, 1 << 4], "status": "encrypted"},
"y": {"range": [0, 3], "status": "encrypted", "shape": (2,)},
},
{
"x": {"range": [0, 1 << 4], "status": "encrypted", "shape": (2,)},
"y": {"range": [0, 3], "status": "encrypted"},
},
{
"x": {"range": [0, 1 << 4], "status": "encrypted", "shape": (2,)},
"y": {"range": [0, 3], "status": "encrypted", "shape": (2,)},
},
],
)
def test_right_shift(function, parameters, helpers):
"""
Test right shift between encrypted integers.
"""
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = cnp.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset, configuration)
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample)
@pytest.mark.parametrize(
"function",
[
pytest.param(
lambda x, y: x << y,
id="x << y",
),
],
)
@pytest.mark.parametrize(
"parameters",
[
{
"x": {"range": [0, 1], "status": "encrypted"},
"y": {"range": [0, 7], "status": "encrypted"},
},
],
)
def test_left_shift_coverage(function, parameters, helpers):
"""
Test left shift between encrypted integers all cases.
"""
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = cnp.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset, configuration)
for i in range(2):
for j in range(8):
helpers.check_execution(circuit, function, [i, j])
@pytest.mark.parametrize(
"function",
[
pytest.param(
lambda x, y: x >> y,
id="x >> y",
),
],
)
@pytest.mark.parametrize(
"parameters",
[
{
"x": {"range": [0, 1 << 7], "status": "encrypted"},
"y": {"range": [0, 7], "status": "encrypted"},
},
],
)
def test_right_shift_coverage(function, parameters, helpers):
"""
Test right shift between encrypted integers all cases.
"""
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = cnp.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset, configuration)
helpers.check_execution(circuit, function, [0b11, 0])
helpers.check_execution(circuit, function, [0b11, 1])
helpers.check_execution(circuit, function, [0b110, 2])
helpers.check_execution(circuit, function, [0b1100, 3])
helpers.check_execution(circuit, function, [0b11000, 4])
helpers.check_execution(circuit, function, [0b110000, 5])
helpers.check_execution(circuit, function, [0b110000, 6])
helpers.check_execution(circuit, function, [0b1100000, 7])

View File

@@ -400,6 +400,40 @@ Subgraphs:
%5 = astype(%4, dtype=int_) # EncryptedScalar<uint1>
return %5
""", # noqa: E501
),
pytest.param(
lambda x, y: x << y,
{"x": "encrypted", "y": "encrypted"},
[(-1, 1), (-2, 3)],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # EncryptedScalar<int2> ∈ [-2, -1]
%1 = y # EncryptedScalar<uint2> ∈ [1, 3]
%2 = left_shift(%0, %1) # EncryptedScalar<int5> ∈ [-16, -2]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned bitwise operations are supported
return %2
""", # noqa: E501
),
pytest.param(
lambda x, y: x << y,
{"x": "encrypted", "y": "encrypted"},
[(1, 20), (2, 10)],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # EncryptedScalar<uint2> ∈ [1, 2]
%1 = y # EncryptedScalar<uint5> ∈ [10, 20]
%2 = left_shift(%0, %1) # EncryptedScalar<uint21> ∈ [2048, 1048576]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only up to 4-bit shifts are supported
return %2
""", # noqa: E501
),
],