diff --git a/concrete/numpy/mlir/graph_converter.py b/concrete/numpy/mlir/graph_converter.py index 5d5950a90..4c20a7c34 100644 --- a/concrete/numpy/mlir/graph_converter.py +++ b/concrete/numpy/mlir/graph_converter.py @@ -96,6 +96,21 @@ class GraphConverter: if not inputs[0].is_encrypted: return "only assignment to encrypted tensors are supported" + elif name in ["bitwise_and", "bitwise_or", "bitwise_xor", "left_shift", "right_shift"]: + assert_that(len(inputs) == 2) + if all(value.is_encrypted for value in node.inputs): + pred_nodes = graph.ordered_preds_of(node) + if ( + name in ["left_shift", "right_shift"] + and cast(Integer, pred_nodes[1].output.dtype).bit_width > 4 + ): + return "only up to 4-bit shifts are supported" + + for pred_node in pred_nodes: + assert isinstance(pred_node.output.dtype, Integer) + if pred_node.output.dtype.is_signed: + return "only unsigned bitwise operations are supported" + elif name == "broadcast_to": assert_that(len(inputs) == 1) if not inputs[0].is_encrypted: @@ -115,6 +130,9 @@ class GraphConverter: if inputs[0].is_encrypted and inputs[1].is_encrypted: return "only dot product between encrypted and clear is supported" + elif name in ["equal", "greater", "greater_equal", "less", "less_equal", "not_equal"]: + assert_that(len(inputs) == 2) + elif name == "expand_dims": assert_that(len(inputs) == 1) @@ -251,6 +269,20 @@ class GraphConverter: current_node_bit_width = ( dtype.bit_width - 1 if node.output.is_clear else dtype.bit_width ) + if ( + all(value.is_encrypted for value in node.inputs) + and node.operation == Operation.Generic + and node.properties["name"] + in [ + "greater", + "greater_equal", + "less", + "less_equal", + ] + ): + # implementation of these operators require at least 4 bits + current_node_bit_width = max(current_node_bit_width, 4) + if max_bit_width < current_node_bit_width: max_bit_width = current_node_bit_width max_bit_width_node = node @@ -286,7 +318,10 @@ class GraphConverter: + graph.format(highlighted_nodes=offending_nodes) ) - for node in graph.graph.nodes: + for node in nx.topological_sort(graph.graph): + assert isinstance(node.output.dtype, Integer) + node.properties["original_bit_width"] = node.output.dtype.bit_width + for value in node.inputs + [node.output]: dtype = value.dtype assert_that(isinstance(dtype, Integer)) @@ -335,9 +370,14 @@ class GraphConverter: variable_input_bit_width = variable_input_dtype.bit_width offset_constant_dtype = SignedInteger(variable_input_bit_width + 1) - offset_constant = Node.constant(abs(variable_input_dtype.min())) + offset_constant_value = abs(variable_input_dtype.min()) + + offset_constant = Node.constant(offset_constant_value) offset_constant.output.dtype = offset_constant_dtype + original_bit_width = Integer.that_can_represent(offset_constant_value).bit_width + offset_constant.properties["original_bit_width"] = original_bit_width + add_offset = Node.generic( "add", [variable_input_value, ClearScalar(offset_constant_dtype)], @@ -345,6 +385,9 @@ class GraphConverter: np.add, ) + original_bit_width = variable_input_node.properties["original_bit_width"] + add_offset.properties["original_bit_width"] = original_bit_width + nx_graph.remove_edge(variable_input_node, node) nx_graph.add_edge(variable_input_node, add_offset, input_idx=0) @@ -412,6 +455,10 @@ class GraphConverter: kwargs={"newshape": required_value_shape}, ) + modified_pred.properties["original_bit_width"] = pred_to_modify.properties[ + "original_bit_width" + ] + nx_graph.add_edge(pred_to_modify, modified_pred, input_idx=0) nx_graph.remove_edge(pred_to_modify, node) @@ -448,6 +495,9 @@ class GraphConverter: lambda: np.zeros((), dtype=np.int64), ) + original_bit_width = 1 + zero.properties["original_bit_width"] = original_bit_width + new_assigned_pred = Node.generic( "add", [assigned_pred.output, zero.output], @@ -455,6 +505,9 @@ class GraphConverter: np.add, ) + original_bit_width = assigned_pred.properties["original_bit_width"] + new_assigned_pred.properties["original_bit_width"] = original_bit_width + nx_graph.remove_edge(preds[1], node) nx_graph.add_edge(preds[1], new_assigned_pred, input_idx=0) @@ -473,7 +526,24 @@ class GraphConverter: """ # pylint: disable=invalid-name - OPS_TO_TENSORIZE = ["add", "broadcast_to", "dot", "multiply", "subtract"] + OPS_TO_TENSORIZE = [ + "add", + "bitwise_and", + "bitwise_or", + "bitwise_xor", + "broadcast_to", + "dot", + "equal", + "greater", + "greater_equal", + "left_shift", + "less", + "less_equal", + "multiply", + "not_equal", + "right_shift", + "subtract", + ] # pylint: enable=invalid-name tensorized_scalars: Dict[Node, Node] = {} @@ -490,6 +560,11 @@ class GraphConverter: if not node.inputs[0].is_scalar: continue + # for bitwise and comparison operators that can have constants + # we don't need broadcasting here + if node.converted_to_table_lookup: + continue + pred_to_tensorize: Optional[Node] = None pred_to_tensorize_index = 0 @@ -513,8 +588,14 @@ class GraphConverter: tensorized_value, lambda *args: np.array(args), ) - nx_graph.add_edge(pred_to_tensorize, tensorized_pred, input_idx=0) + original_bit_width = pred_to_tensorize.properties["original_bit_width"] + tensorized_pred.properties["original_bit_width"] = original_bit_width + + original_shape = () + tensorized_pred.properties["original_shape"] = original_shape + + nx_graph.add_edge(pred_to_tensorize, tensorized_pred, input_idx=0) tensorized_scalars[pred_to_tensorize] = tensorized_pred assert tensorized_pred is not None @@ -522,6 +603,10 @@ class GraphConverter: nx_graph.remove_edge(pred_to_tensorize, node) nx_graph.add_edge(tensorized_pred, node, input_idx=pred_to_tensorize_index) + new_input_value = deepcopy(node.inputs[pred_to_tensorize_index]) + new_input_value.shape = (1,) + node.inputs[pred_to_tensorize_index] = new_input_value + @staticmethod def _sanitize_signed_inputs(graph: Graph, args: List[Any], ctx: Context) -> List[Any]: """ diff --git a/concrete/numpy/mlir/node_converter.py b/concrete/numpy/mlir/node_converter.py index 54d251486..cce55b58d 100644 --- a/concrete/numpy/mlir/node_converter.py +++ b/concrete/numpy/mlir/node_converter.py @@ -4,7 +4,9 @@ Declaration of `NodeConverter` class. # pylint: disable=no-member,no-name-in-module,too-many-lines -from typing import Dict, List, Tuple +import re +from enum import IntEnum +from typing import Callable, Dict, List, Set, Tuple, cast import numpy as np from concrete.lang.dialects import fhe, fhelinalg @@ -34,6 +36,20 @@ from .utils import construct_deduplicated_tables # pylint: enable=no-member,no-name-in-module +class Comparison(IntEnum): + """ + Comparison enum, to generalize conversion of comparison operators. + + Because comparison result has 3 possibilities, Comparison enum is 2 bits. + """ + + EQUAL = 0b00 + LESS = 0b01 + GREATER = 0b10 + + UNUSED = 0b11 + + class NodeConverter: """ NodeConverter class, to convert computation graph nodes to their MLIR equivalent. @@ -125,11 +141,12 @@ class NodeConverter: self.all_of_the_inputs_are_tensors = True self.one_of_the_inputs_is_a_tensor = False - for inp in node.inputs: - if not inp.is_encrypted: + for pred in graph.ordered_preds_of(node): + if not pred.output.is_encrypted: self.all_of_the_inputs_are_encrypted = False - if inp.is_scalar: + shape = pred.properties.get("original_shape", pred.output.shape) + if shape == (): self.all_of_the_inputs_are_tensors = False else: self.one_of_the_inputs_is_a_tensor = True @@ -156,20 +173,31 @@ class NodeConverter: "add": self._convert_add, "array": self._convert_array, "assign.static": self._convert_static_assignment, + "bitwise_and": self._convert_bitwise_and, + "bitwise_or": self._convert_bitwise_or, + "bitwise_xor": self._convert_bitwise_xor, "broadcast_to": self._convert_broadcast_to, "concatenate": self._convert_concat, "conv1d": self._convert_conv1d, "conv2d": self._convert_conv2d, "conv3d": self._convert_conv3d, "dot": self._convert_dot, + "equal": self._convert_equal, "expand_dims": self._convert_reshape, + "greater": self._convert_greater, + "greater_equal": self._convert_greater_equal, "index.static": self._convert_static_indexing, + "left_shift": self._convert_left_shift, + "less": self._convert_less, + "less_equal": self._convert_less_equal, "matmul": self._convert_matmul, "maxpool": self._convert_maxpool, "multiply": self._convert_mul, "negative": self._convert_neg, + "not_equal": self._convert_not_equal, "ones": self._convert_ones, "reshape": self._convert_reshape, + "right_shift": self._convert_right_shift, "squeeze": self._convert_squeeze, "subtract": self._convert_sub, "sum": self._convert_sum, @@ -183,7 +211,7 @@ class NodeConverter: assert_that(self.node.converted_to_table_lookup) return self._convert_tlu() - # pylint: disable=no-self-use + # pylint: disable=no-self-use,too-many-branches,too-many-locals,too-many-statements def _convert_add(self) -> OpResult: """ @@ -277,6 +305,39 @@ class NodeConverter: self.from_elements_operations[placeholder_result] = processed_preds return placeholder_result + def _convert_bitwise_and(self) -> OpResult: + """ + Convert "bitwise_and" node to its corresponding MLIR representation. + + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + return self._convert_bitwise(lambda x, y: x & y) + + def _convert_bitwise_or(self) -> OpResult: + """ + Convert "bitwise_or" node to its corresponding MLIR representation. + + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + return self._convert_bitwise(lambda x, y: x | y) + + def _convert_bitwise_xor(self) -> OpResult: + """ + Convert "bitwise_xor" node to its corresponding MLIR representation. + + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + return self._convert_bitwise(lambda x, y: x ^ y) + def _convert_concat(self) -> OpResult: """ Convert "concatenate" node to its corresponding MLIR representation. @@ -462,6 +523,76 @@ class NodeConverter: return result + def _convert_equal(self) -> OpResult: + """ + Convert "equal" node to its corresponding MLIR representation. + + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + return self._convert_equality(equals=True) + + def _convert_greater(self) -> OpResult: + """ + Convert "greater" node to its corresponding MLIR representation. + + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + return self._convert_compare(invert_operands=True, accept={Comparison.LESS}) + + def _convert_greater_equal(self) -> OpResult: + """ + Convert "greater_equal" node to its corresponding MLIR representation. + + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + return self._convert_compare( + invert_operands=True, accept={Comparison.LESS, Comparison.EQUAL} + ) + + def _convert_left_shift(self) -> OpResult: + """ + Convert "left_shift" node to its corresponding MLIR representation. + + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + return self._convert_shift(orientation="left") + + def _convert_less(self) -> OpResult: + """ + Convert "less" node to its corresponding MLIR representation. + + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + return self._convert_compare(invert_operands=False, accept={Comparison.LESS}) + + def _convert_less_equal(self) -> OpResult: + """ + Convert "less_equal" node to its corresponding MLIR representation. + + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + return self._convert_compare( + invert_operands=False, accept={Comparison.LESS, Comparison.EQUAL} + ) + def _convert_matmul(self) -> OpResult: """Convert a MatMul node to its corresponding MLIR representation. @@ -537,6 +668,17 @@ class NodeConverter: return result + def _convert_not_equal(self) -> OpResult: + """ + Convert "not_equal" node to its corresponding MLIR representation. + + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + return self._convert_equality(equals=False) + def _convert_ones(self) -> OpResult: """ Convert "ones" node to its corresponding MLIR representation. @@ -699,6 +841,17 @@ class NodeConverter: ), ).result + def _convert_right_shift(self) -> OpResult: + """ + Convert "right_shift" node to its corresponding MLIR representation. + + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + return self._convert_shift(orientation="right") + def _convert_static_assignment(self) -> OpResult: """ Convert "assign.static" node to its corresponding MLIR representation. @@ -1094,7 +1247,19 @@ class NodeConverter: return result - def _create_constant(self, mlir_type: Type, mlir_attribute: Attribute): + def _create_add(self, resulting_type, a, b) -> OpResult: + if self.one_of_the_inputs_is_a_tensor: + return fhelinalg.AddEintOp(resulting_type, a, b).result + + return fhe.AddEintOp(resulting_type, a, b).result + + def _create_add_clear(self, resulting_type, a, b) -> OpResult: + if self.one_of_the_inputs_is_a_tensor: + return fhelinalg.AddEintIntOp(resulting_type, a, b).result + + return fhe.AddEintIntOp(resulting_type, a, b).result + + def _create_constant(self, mlir_type: Type, mlir_attribute: Attribute) -> OpResult: result = self.constant_cache.get((mlir_type, mlir_attribute)) if result is None: # ConstantOp is being decorated, and the init function is supposed to take more @@ -1105,4 +1270,597 @@ class NodeConverter: self.constant_cache[(mlir_type, mlir_attribute)] = result return result - # pylint: enable=no-self-use + def _create_constant_integer(self, bit_width, x) -> OpResult: + constant_value = Value( + Integer(is_signed=True, bit_width=bit_width + 1), + shape=(1,) if self.one_of_the_inputs_is_a_tensor else (), + is_encrypted=False, + ) + constant_type = NodeConverter.value_to_mlir_type(self.ctx, constant_value) + + if self.one_of_the_inputs_is_a_tensor: + return self._create_constant( + constant_type, + Attribute.parse(f"dense<{[int(x)]}> : {constant_type}"), + ) + + return self._create_constant(constant_type, IntegerAttr.get(constant_type, x)) + + def _create_mul_clear(self, resulting_type, a, b) -> OpResult: + if self.one_of_the_inputs_is_a_tensor: + return fhelinalg.MulEintIntOp(resulting_type, a, b).result + + return fhe.MulEintIntOp(resulting_type, a, b).result + + def _create_sub(self, resulting_type, a, b) -> OpResult: + if self.one_of_the_inputs_is_a_tensor: + return fhelinalg.SubEintOp(resulting_type, a, b).result + + return fhe.SubEintOp(resulting_type, a, b).result + + def _create_tlu(self, resulting_type, pred, lut_values) -> OpResult: + resulting_type_str = str(resulting_type) + bit_width_search = re.search(r"FHE\.eint<([0-9]+)>", resulting_type_str) + + assert bit_width_search is not None + bit_width_str = bit_width_search.group(1) + + bit_width = int(bit_width_str) + lut_values += [0] * ((2**bit_width) - len(lut_values)) + + lut_element_type = IntegerType.get_signless(64, context=self.ctx) + lut_type = RankedTensorType.get((len(lut_values),), lut_element_type) + lut_attr = Attribute.parse(f"dense<{str(lut_values)}> : {lut_type}") + lut = self._create_constant(resulting_type, lut_attr).result + + if self.one_of_the_inputs_is_a_tensor: + return fhelinalg.ApplyLookupTableEintOp(resulting_type, pred, lut).result + + return fhe.ApplyLookupTableEintOp(resulting_type, pred, lut).result + + def _convert_bitwise(self, operation: Callable) -> OpResult: + """ + Convert `bitwise_{and,or,xor}` node to their corresponding MLIR representation. + + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + if not self.all_of_the_inputs_are_encrypted: + return self._convert_tlu() + + resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output) + + x = self.preds[0] + y = self.preds[1] + + x_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.inputs[0]) + y_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.inputs[1]) + + assert isinstance(self.node.output.dtype, Integer) + bit_width = self.node.output.dtype.bit_width + + chunk_size = int(np.floor(bit_width / 2)) + mask = (2**chunk_size) - 1 + + original_bit_width = max( + pred_node.properties["original_bit_width"] + for pred_node in self.graph.ordered_preds_of(self.node) + ) + + chunks = [] + for offset in range(0, original_bit_width, chunk_size): + x_lut = [((x >> offset) & mask) << chunk_size for x in range(2**bit_width)] + y_lut = [(y >> offset) & mask for y in range(2**bit_width)] + + x_chunk = self._create_tlu(x_type, x, x_lut) + y_chunk = self._create_tlu(y_type, y, y_lut) + + packed_x_and_y_chunks = self._create_add(resulting_type, x_chunk, y_chunk) + result_chunk = self._create_tlu( + resulting_type, + packed_x_and_y_chunks, + [ + operation(x, y) << offset + for x in range(2**chunk_size) + for y in range(2**chunk_size) + ], + ) + + chunks.append(result_chunk) + + # add all chunks together in a tree to maximize dataflow parallelization + + # c1 c2 c3 c4 + # \/ \/ + # s2 s1 + # \ / + # \ / + # s3 + + while len(chunks) > 1: + a = chunks.pop() + b = chunks.pop() + + result = self._create_add(resulting_type, a, b) + chunks.insert(0, result) + + return chunks[0] + + def _convert_compare(self, accept: Set[Comparison], invert_operands: bool = False) -> OpResult: + if not self.all_of_the_inputs_are_encrypted: + return self._convert_tlu() + + inputs = self.node.inputs + preds = self.preds + + if invert_operands: + inputs = inputs[::-1] + preds = preds[::-1] + + x = preds[0] + y = preds[1] + + x_is_signed = cast(Integer, inputs[0].dtype).is_signed + y_is_signed = cast(Integer, inputs[1].dtype).is_signed + + x_is_unsigned = not x_is_signed + y_is_unsigned = not y_is_signed + + pred_nodes = self.graph.ordered_preds_of(self.node) + if invert_operands: + pred_nodes = list(reversed(pred_nodes)) + + x_bit_width, y_bit_width = [node.properties["original_bit_width"] for node in pred_nodes] + comparison_bit_width = max(x_bit_width, y_bit_width) + + x_dtype = Integer(is_signed=x_is_signed, bit_width=x_bit_width) + y_dtype = Integer(is_signed=y_is_signed, bit_width=y_bit_width) + + x_minus_y_min = x_dtype.min() - y_dtype.max() + x_minus_y_max = x_dtype.max() - y_dtype.min() + + x_minus_y_range = [x_minus_y_min, x_minus_y_max] + x_minus_y_dtype = Integer.that_can_represent(x_minus_y_range) + + resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output) + + bit_width = cast(Integer, self.node.output.dtype).bit_width + signed_offset = 2 ** (bit_width - 1) + + sanitizer = self._create_constant_integer(bit_width, signed_offset) + if x_minus_y_dtype.bit_width <= bit_width: + x_minus_y = self._create_sub(resulting_type, x, y) + sanitized_x_minus_y = self._create_add_clear(resulting_type, x_minus_y, sanitizer) + + accept_less = int(Comparison.LESS in accept) + accept_equal = int(Comparison.EQUAL in accept) + accept_greater = int(Comparison.GREATER in accept) + + all_cells = 2**bit_width + + less_cells = 2 ** (bit_width - 1) + equal_cells = 1 + greater_cells = less_cells - equal_cells + + assert less_cells + equal_cells + greater_cells == all_cells + + table = ( + [accept_less] * less_cells + + [accept_equal] * equal_cells + + [accept_greater] * greater_cells + ) + return self._create_tlu(resulting_type, sanitized_x_minus_y, table) + + # Comparison between signed and unsigned is tricky. + # To deal with them, we add -min of the signed number to both operands + # such that they are both positive. To avoid overflowing + # the unsigned operand this addition is done "virtually" + # while constructing one of the luts. + + # A flag ("is_unsigned_greater_than_half") is emitted in MLIR to keep track + # if the unsigned operand was greater than the max signed number as it + # is needed to determine the result of the comparison. + + # Exemple: to compare x and y where x is an int3 and y and uint3, when y + # is greater than 4 we are sure than x will be less than x. + + if inputs[0].shape != self.node.output.shape: + x = self._create_add(resulting_type, x, self._convert_zeros()) + if inputs[1].shape != self.node.output.shape: + y = self._create_add(resulting_type, y, self._convert_zeros()) + + offset_x_by = 0 + offset_y_by = 0 + + if x_is_signed or y_is_signed: + if x_is_signed: + x = self._create_add_clear(resulting_type, x, sanitizer) + else: + offset_x_by = signed_offset + + if y_is_signed: + y = self._create_add_clear(resulting_type, y, sanitizer) + else: + offset_y_by = signed_offset + + def compare(x, y): + if x < y: + return Comparison.LESS + + if x > y: + return Comparison.GREATER + + return Comparison.EQUAL + + chunk_size = int(np.floor(bit_width / 2)) + carries = self._pack_to_chunk_groups_and_map( + resulting_type, + comparison_bit_width, + chunk_size, + x, + y, + lambda i, x, y: compare(x, y) << (min(i, 1) * 2), + x_offset=offset_x_by, + y_offset=offset_y_by, + ) + + # This is the reduction step -- we have an array where the entry i is the + # result of comparing the chunks of x and y at position i. + + all_comparisons = [Comparison.EQUAL, Comparison.LESS, Comparison.GREATER, Comparison.UNUSED] + pick_first_not_equal_lut = [ + int( + current_comparison + if previous_comparison == Comparison.EQUAL + else previous_comparison + ) + for current_comparison in all_comparisons + for previous_comparison in all_comparisons + ] + + carry = carries[0] + for next_carry in carries[1:]: + combined_carries = self._create_add(resulting_type, next_carry, carry) + carry = self._create_tlu(resulting_type, combined_carries, pick_first_not_equal_lut) + + if x_is_signed != y_is_signed: + carry_bit_width = 2 + is_less_mask = int(Comparison.LESS) + + is_unsigned_greater_than_half = self._create_tlu( + resulting_type, + x if x_is_unsigned else y, + [int(value >= signed_offset) << carry_bit_width for value in range(2**bit_width)], + ) + packed_carry_and_is_unsigned_greater_than_half = self._create_add( + resulting_type, + is_unsigned_greater_than_half, + carry, + ) + + # this function is actually converting either + # - lhs < rhs + # - lhs <= rhs + + # in the implementation, we call + # - x = lhs + # - y = rhs + + # so if y is unsigned and greater than half + # - y is definitely bigger than x + # - is_unsigned_greater_than_half == 1 + # - result == (lhs < rhs) == (x < y) == 1 + + # so if x is unsigned and greater than half + # - x is definitely bigger than y + # - is_unsigned_greater_than_half == 1 + # - result == (lhs < rhs) == (x < y) == 0 + + if y_is_unsigned: + result_table = [ + 1 if (i >> carry_bit_width) else (i & is_less_mask) for i in range(2**3) + ] + else: + result_table = [ + 0 if (i >> carry_bit_width) else (i & is_less_mask) for i in range(2**3) + ] + + result = self._create_tlu( + resulting_type, + packed_carry_and_is_unsigned_greater_than_half, + result_table, + ) + else: + boolean_result_lut = [int(comparison in accept) for comparison in all_comparisons] + result = self._create_tlu(resulting_type, carry, boolean_result_lut) + + return result + + def _convert_equality(self, equals: bool) -> OpResult: + if not self.all_of_the_inputs_are_encrypted: + return self._convert_tlu() + + x = self.preds[0] + y = self.preds[1] + + x_is_signed = cast(Integer, self.node.inputs[0].dtype).is_signed + y_is_signed = cast(Integer, self.node.inputs[1].dtype).is_signed + + x_is_unsigned = not x_is_signed + y_is_unsigned = not y_is_signed + + x_bit_width, y_bit_width = [ + node.properties["original_bit_width"] for node in self.graph.ordered_preds_of(self.node) + ] + comparison_bit_width = max(x_bit_width, y_bit_width) + + x_dtype = Integer(is_signed=x_is_signed, bit_width=x_bit_width) + y_dtype = Integer(is_signed=y_is_signed, bit_width=y_bit_width) + + x_minus_y_min = x_dtype.min() - y_dtype.max() + x_minus_y_max = x_dtype.max() - y_dtype.min() + + x_minus_y_range = [x_minus_y_min, x_minus_y_max] + x_minus_y_dtype = Integer.that_can_represent(x_minus_y_range) + + resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output) + + bit_width = cast(Integer, self.node.output.dtype).bit_width + signed_offset = 2 ** (bit_width - 1) + + sanitizer = self._create_constant_integer(bit_width, signed_offset) + if x_minus_y_dtype.bit_width <= bit_width: + x_minus_y = self._create_sub(resulting_type, x, y) + sanitized_x_minus_y = self._create_add_clear(resulting_type, x_minus_y, sanitizer) + + zero_position = 2 ** (bit_width - 1) + if equals: + operation_lut = [int(i == zero_position) for i in range(2**bit_width)] + else: + operation_lut = [int(i != zero_position) for i in range(2**bit_width)] + + return self._create_tlu(resulting_type, sanitized_x_minus_y, operation_lut) + + chunk_size = int(np.floor(bit_width / 2)) + number_of_chunks = int(np.ceil(bit_width / chunk_size)) + + if x_is_signed != y_is_signed: + number_of_chunks += 1 + + if self.node.inputs[0].shape != self.node.output.shape: + x = self._create_add(resulting_type, x, self._convert_zeros()) + if self.node.inputs[1].shape != self.node.output.shape: + y = self._create_add(resulting_type, y, self._convert_zeros()) + + greater_than_half_lut = [ + int(i >= signed_offset) << (number_of_chunks - 1) for i in range(2**bit_width) + ] + if x_is_unsigned and y_is_signed: + is_unsigned_greater_than_half = self._create_tlu( + resulting_type, + x, + greater_than_half_lut, + ) + elif x_is_signed and y_is_unsigned: + is_unsigned_greater_than_half = self._create_tlu( + resulting_type, + y, + greater_than_half_lut, + ) + else: + is_unsigned_greater_than_half = None + + offset_x_by = 0 + offset_y_by = 0 + + if x_is_signed or y_is_signed: + if x_is_signed: + x = self._create_add_clear(resulting_type, x, sanitizer) + else: + offset_x_by = signed_offset + + if y_is_signed: + y = self._create_add_clear(resulting_type, y, sanitizer) + else: + offset_y_by = signed_offset + + carries = self._pack_to_chunk_groups_and_map( + resulting_type, + comparison_bit_width, + chunk_size, + x, + y, + lambda _, x, y: int(x != y), + x_offset=offset_x_by, + y_offset=offset_y_by, + ) + + if is_unsigned_greater_than_half: + carries.append(is_unsigned_greater_than_half) + + while len(carries) > 1: + a = carries.pop() + b = carries.pop() + + result = self._create_add(resulting_type, a, b) + carries.insert(0, result) + + carry = carries[0] + + return self._create_tlu( + resulting_type, + carry, + [int(i == 0 if equals else i != 0) for i in range(2**bit_width)], + ) + + def _convert_shift(self, orientation: str) -> OpResult: + if not self.all_of_the_inputs_are_encrypted: + return self._convert_tlu() + + resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output) + + x = self.preds[0] + b = self.preds[1] + + b_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.inputs[1]) + + assert isinstance(self.node.output.dtype, Integer) + bit_width = self.node.output.dtype.bit_width + + pred_nodes = self.graph.ordered_preds_of(self.node) + + x_original_bit_width = pred_nodes[0].properties["original_bit_width"] + b_original_bit_width = pred_nodes[1].properties["original_bit_width"] + + if self.node.inputs[0].shape != self.node.output.shape: + x = self._create_add(resulting_type, x, self._convert_zeros()) + + if x_original_bit_width + b_original_bit_width <= bit_width: + shift_multiplier = self._create_constant_integer(bit_width, 2**b_original_bit_width) + shifted_x = self._create_mul_clear(resulting_type, x, shift_multiplier) + packed_x_and_b = self._create_add(resulting_type, shifted_x, b) + return self._create_tlu( + resulting_type, + packed_x_and_b, + [ + (x << b) if orientation == "left" else (x >> b) + for x in range(2**x_original_bit_width) + for b in range(2**b_original_bit_width) + ], + ) + + # Left_shifts of x << b can be done as follows: + # - left shift of x by 8 if b & 0b1000 > 0 + # - left shift of x by 4 if b & 0b0100 > 0 + # - left shift of x by 2 if b & 0b0010 > 0 + # - left shift of x by 1 if b & 0b0001 > 0 + + # Encoding this condition is non-trivial -- however, + # it can be done using the following trick: + + # y = (b & 0b1000 > 0) * ((x << 8) - x) + x + + # When b & 0b1000, then: + # y = 1 * ((x << 8) - x) + x = (x << 8) - x + x = x << 8 + + # When b & 0b1000 == 0 then: + # y = 0 * ((x << 8) - x) + x = x + + # The same trick can be used for right shift but with: + # y = x - (b & 0b1000 > 0) * (x - (x >> 8)) + + original_bit_width = self.node.properties["original_bit_width"] + chunk_size = min(original_bit_width, bit_width - 1) + + for i in reversed(range(b_original_bit_width)): + to_check = 2**i + + should_shift = self._create_tlu( + b_type, + b, + [int((b & to_check) > 0) for b in range(2**bit_width)], + ) + shifted_x = self._create_tlu( + resulting_type, + x, + ( + [(x << to_check) - x for x in range(2**bit_width)] + if orientation == "left" + else [x - (x >> to_check) for x in range(2**bit_width)] + ), + ) + + chunks = [] + for offset in range(0, original_bit_width, chunk_size): + bits_to_process = min(chunk_size, original_bit_width - offset) + right_shift_by = original_bit_width - offset - bits_to_process + mask = (2**bits_to_process) - 1 + + chunk_x = self._create_tlu( + resulting_type, + shifted_x, + [(((x >> right_shift_by) & mask) << 1) for x in range(2**bit_width)], + ) + packed_chunk_x_and_should_shift = self._create_add( + resulting_type, chunk_x, should_shift + ) + + chunk = self._create_tlu( + resulting_type, + packed_chunk_x_and_should_shift, + [ + (x << right_shift_by) if b else 0 + for x in range(2**chunk_size) + for b in [0, 1] + ], + ) + chunks.append(chunk) + + difference = chunks[0] + for chunk in chunks[1:]: + difference = self._create_add(resulting_type, difference, chunk) + + x = ( + self._create_add(resulting_type, difference, x) + if orientation == "left" + else self._create_sub(resulting_type, x, difference) + ) + + return x + + def _pack_to_chunk_groups_and_map( + self, + resulting_type: Type, + bit_width: int, + chunk_size: int, + x: OpResult, + y: OpResult, + mapper: Callable, + x_offset: int = 0, + y_offset: int = 0, + ) -> List[OpResult]: + """ + Split x and y into chunks, pack the chunks and map it to another integer. + + Split x and y into chunks of size `chunk_size`. + Pack those chunks into an integer and apply `mapper` function to it. + Combine those results into a list and return it. + + If `x_offset` (resp. `y_offset`) is provided, execute the function + for `x + offset_x_by` (resp. `y + offset_y_by`) instead of `x` and `y`. + """ + + result = [] + for chunk_index, offset in enumerate(range(0, bit_width, chunk_size)): + bits_to_process = min(chunk_size, bit_width - offset) + right_shift_by = bit_width - offset - bits_to_process + mask = (2**bits_to_process) - 1 + + chunk_x = self._create_tlu( + resulting_type, + x, + [ + ((((x + x_offset) >> right_shift_by) & mask) << bits_to_process) + for x in range(2**bit_width) + ], + ) + chunk_y = self._create_tlu( + resulting_type, + y, + [((y + y_offset) >> right_shift_by) & mask for y in range(2**bit_width)], + ) + + packed_chunks = self._create_add(resulting_type, chunk_x, chunk_y) + mapped_chunks = self._create_tlu( + resulting_type, + packed_chunks, + [mapper(chunk_index, x, y) for x in range(mask + 1) for y in range(mask + 1)], + ) + + result.append(mapped_chunks) + + return result + + # pylint: enable=no-self-use,too-many-branches,too-many-locals,too-many-statements diff --git a/concrete/numpy/representation/node.py b/concrete/numpy/representation/node.py index 35c00df52..c8f083047 100644 --- a/concrete/numpy/representation/node.py +++ b/concrete/numpy/representation/node.py @@ -363,6 +363,26 @@ class Node: True if the node is converted to a table lookup, False otherwise """ + if ( + all(value.is_encrypted for value in self.inputs) + and self.operation == Operation.Generic + and self.properties["name"] + in [ + "bitwise_and", + "bitwise_or", + "bitwise_xor", + "equal", + "greater", + "greater_equal", + "left_shift", + "less", + "less_equal", + "not_equal", + "right_shift", + ] + ): + return False + return self.operation == Operation.Generic and self.properties["name"] not in [ "add", "array", diff --git a/pyproject.toml b/pyproject.toml index 9b78276fe..c078a1db3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,7 +102,7 @@ select = [ "PLC", "PLE", "PLR", "PLW", "RUF" ] ignore = [ - "A", "D", "FBT", "T20", "ANN", "N806", "ARG001", "S101", "BLE001", "RUF100", + "A", "D", "FBT", "T20", "ANN", "N806", "ARG001", "S101", "BLE001", "RUF100", "ERA001", "RET504", "TID252", "PD011", "I001", "UP015", "C901", "A001", "SIM118", "PGH003" ] diff --git a/tests/conftest.py b/tests/conftest.py index d1a251a63..37456403c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -244,14 +244,24 @@ class Helpers: if not isinstance(sample, list): sample = [sample] - for i in range(retries): - expected = function(*sample) - actual = circuit.encrypt_run_decrypt(*sample) + def sanitize(values): + if not isinstance(values, tuple): + values = (values,) - if not isinstance(expected, tuple): - expected = (expected,) - if not isinstance(actual, tuple): - actual = (actual,) + result = [] + for value in values: + if isinstance(value, (bool, np.bool_)): + value = int(value) + elif isinstance(value, np.ndarray) and value.dtype == np.bool_: + value = value.astype(np.int64) + + result.append(value) + + return tuple(result) + + for i in range(retries): + expected = sanitize(function(*sample)) + actual = sanitize(circuit.encrypt_run_decrypt(*sample)) if all(np.array_equal(e, a) for e, a in zip(expected, actual)): break diff --git a/tests/execution/test_bitwise.py b/tests/execution/test_bitwise.py new file mode 100644 index 000000000..ad13984aa --- /dev/null +++ b/tests/execution/test_bitwise.py @@ -0,0 +1,62 @@ +""" +Tests of execution of bitwise operations. +""" + +import pytest + +import concrete.numpy as cnp + + +@pytest.mark.parametrize( + "function", + [ + pytest.param( + lambda x, y: x & y, + id="x & y", + ), + pytest.param( + lambda x, y: x | y, + id="x | y", + ), + pytest.param( + lambda x, y: x ^ y, + id="x ^ y", + ), + ], +) +@pytest.mark.parametrize( + "parameters", + [ + { + "x": {"range": [0, 255], "status": "encrypted"}, + "y": {"range": [0, 255], "status": "encrypted"}, + }, + { + "x": {"range": [0, 7], "status": "encrypted"}, + "y": {"range": [0, 7], "status": "encrypted", "shape": (3,)}, + }, + { + "x": {"range": [0, 7], "status": "encrypted", "shape": (3,)}, + "y": {"range": [0, 7], "status": "encrypted"}, + }, + { + "x": {"range": [0, 7], "status": "encrypted", "shape": (3,)}, + "y": {"range": [0, 7], "status": "encrypted", "shape": (3,)}, + }, + ], +) +def test_bitwise(function, parameters, helpers): + """ + Test bitwise operations between encrypted integers. + """ + + parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) + configuration = helpers.configuration() + + compiler = cnp.Compiler(function, parameter_encryption_statuses) + + inputset = helpers.generate_inputset(parameters) + circuit = compiler.compile(inputset, configuration) + + sample = helpers.generate_sample(parameters) + helpers.check_execution(circuit, function, sample) diff --git a/tests/execution/test_comparison.py b/tests/execution/test_comparison.py new file mode 100644 index 000000000..34e892357 --- /dev/null +++ b/tests/execution/test_comparison.py @@ -0,0 +1,161 @@ +""" +Tests of execution of comparison operations. +""" + +import pytest + +import concrete.numpy as cnp + + +@pytest.mark.parametrize( + "function", + [ + pytest.param( + lambda x, y: x == y, + id="x == y", + ), + pytest.param( + lambda x, y: x != y, + id="x != y", + ), + pytest.param( + lambda x, y: x < y, + id="x < y", + ), + pytest.param( + lambda x, y: x <= y, + id="x <= y", + ), + pytest.param( + lambda x, y: x > y, + id="x > y", + ), + pytest.param( + lambda x, y: x >= y, + id="x >= y", + ), + ], +) +@pytest.mark.parametrize( + "parameters", + [ + { + "x": {"range": [0, 3], "status": "encrypted"}, + "y": {"range": [0, 3], "status": "encrypted"}, + }, + { + "x": {"range": [0, 255], "status": "encrypted"}, + "y": {"range": [0, 255], "status": "encrypted"}, + }, + { + "x": {"range": [-128, 127], "status": "encrypted"}, + "y": {"range": [-128, 127], "status": "encrypted"}, + }, + { + "x": {"range": [-128, 127], "status": "encrypted"}, + "y": {"range": [0, 255], "status": "encrypted"}, + }, + { + "x": {"range": [0, 255], "status": "encrypted"}, + "y": {"range": [-128, 127], "status": "encrypted"}, + }, + { + "x": {"range": [-8, 7], "status": "encrypted"}, + "y": {"range": [-8, 7], "status": "encrypted", "shape": (2,)}, + }, + { + "x": {"range": [-8, 7], "status": "encrypted", "shape": (2,)}, + "y": {"range": [-8, 7], "status": "encrypted"}, + }, + { + "x": {"range": [-8, 7], "status": "encrypted", "shape": (2,)}, + "y": {"range": [-8, 7], "status": "encrypted", "shape": (2,)}, + }, + ], +) +def test_comparison(function, parameters, helpers): + """ + Test comparison operations between encrypted integers. + """ + + parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) + configuration = helpers.configuration() + + compiler = cnp.Compiler(function, parameter_encryption_statuses) + + inputset = helpers.generate_inputset(parameters) + circuit = compiler.compile(inputset, configuration) + + sample = helpers.generate_sample(parameters) + helpers.check_execution(circuit, function, sample) + + +@pytest.mark.parametrize( + "function", + [ + pytest.param( + lambda x, y: (x == y) + 200, + id="(x == y) + 200", + ), + pytest.param( + lambda x, y: (x != y) + 200, + id="(x != y) + 200", + ), + pytest.param( + lambda x, y: (x < y) + 200, + id="(x < y) + 200", + ), + pytest.param( + lambda x, y: (x <= y) + 200, + id="(x <= y) + 200", + ), + pytest.param( + lambda x, y: (x > y) + 200, + id="(x > y) + 200", + ), + pytest.param( + lambda x, y: (x >= y) + 200, + id="(x >= y) + 200", + ), + ], +) +@pytest.mark.parametrize( + "parameters", + [ + { + "x": {"range": [0, 15], "status": "encrypted"}, + "y": {"range": [0, 15], "status": "encrypted"}, + }, + { + "x": {"range": [-8, 7], "status": "encrypted"}, + "y": {"range": [-8, 7], "status": "encrypted"}, + }, + { + "x": {"range": [0, 15], "status": "encrypted"}, + "y": {"range": [0, 15], "status": "encrypted", "shape": (2,)}, + }, + { + "x": {"range": [-8, 7], "status": "encrypted", "shape": (2,)}, + "y": {"range": [-8, 7], "status": "encrypted"}, + }, + { + "x": {"range": [-10, 10], "status": "encrypted", "shape": (2,)}, + "y": {"range": [-10, 10], "status": "encrypted", "shape": (2,)}, + }, + ], +) +def test_optimized_comparison(function, parameters, helpers): + """ + Test comparison operations between encrypted integers with a single TLU. + """ + + parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) + configuration = helpers.configuration() + + compiler = cnp.Compiler(function, parameter_encryption_statuses) + + inputset = helpers.generate_inputset(parameters) + circuit = compiler.compile(inputset, configuration) + + sample = helpers.generate_sample(parameters) + helpers.check_execution(circuit, function, sample) diff --git a/tests/execution/test_shift.py b/tests/execution/test_shift.py new file mode 100644 index 000000000..f93fbf0e1 --- /dev/null +++ b/tests/execution/test_shift.py @@ -0,0 +1,178 @@ +""" +Tests of execution of shift operations. +""" + +import pytest + +import concrete.numpy as cnp + + +@pytest.mark.parametrize( + "function", + [ + pytest.param( + lambda x, y: x << y, + id="x << y", + ), + ], +) +@pytest.mark.parametrize( + "parameters", + [ + { + "x": {"range": [0, 1], "status": "encrypted"}, + "y": {"range": [0, 7], "status": "encrypted"}, + }, + { + "x": {"range": [0, 3], "status": "encrypted"}, + "y": {"range": [0, 3], "status": "encrypted", "shape": (2,)}, + }, + { + "x": {"range": [0, 3], "status": "encrypted", "shape": (2,)}, + "y": {"range": [0, 3], "status": "encrypted"}, + }, + { + "x": {"range": [0, 3], "status": "encrypted", "shape": (2,)}, + "y": {"range": [0, 3], "status": "encrypted", "shape": (2,)}, + }, + ], +) +def test_left_shift(function, parameters, helpers): + """ + Test left shift between encrypted integers. + """ + + parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) + configuration = helpers.configuration() + + compiler = cnp.Compiler(function, parameter_encryption_statuses) + + inputset = helpers.generate_inputset(parameters) + circuit = compiler.compile(inputset, configuration) + + sample = helpers.generate_sample(parameters) + helpers.check_execution(circuit, function, sample) + + +@pytest.mark.parametrize( + "function", + [ + pytest.param( + lambda x, y: x >> y, + id="x >> y", + ), + ], +) +@pytest.mark.parametrize( + "parameters", + [ + { + "x": {"range": [0, 1 << 7], "status": "encrypted"}, + "y": {"range": [0, 7], "status": "encrypted"}, + }, + { + "x": {"range": [0, 1 << 4], "status": "encrypted"}, + "y": {"range": [0, 3], "status": "encrypted", "shape": (2,)}, + }, + { + "x": {"range": [0, 1 << 4], "status": "encrypted", "shape": (2,)}, + "y": {"range": [0, 3], "status": "encrypted"}, + }, + { + "x": {"range": [0, 1 << 4], "status": "encrypted", "shape": (2,)}, + "y": {"range": [0, 3], "status": "encrypted", "shape": (2,)}, + }, + ], +) +def test_right_shift(function, parameters, helpers): + """ + Test right shift between encrypted integers. + """ + + parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) + configuration = helpers.configuration() + + compiler = cnp.Compiler(function, parameter_encryption_statuses) + + inputset = helpers.generate_inputset(parameters) + circuit = compiler.compile(inputset, configuration) + + sample = helpers.generate_sample(parameters) + helpers.check_execution(circuit, function, sample) + + +@pytest.mark.parametrize( + "function", + [ + pytest.param( + lambda x, y: x << y, + id="x << y", + ), + ], +) +@pytest.mark.parametrize( + "parameters", + [ + { + "x": {"range": [0, 1], "status": "encrypted"}, + "y": {"range": [0, 7], "status": "encrypted"}, + }, + ], +) +def test_left_shift_coverage(function, parameters, helpers): + """ + Test left shift between encrypted integers all cases. + """ + + parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) + configuration = helpers.configuration() + + compiler = cnp.Compiler(function, parameter_encryption_statuses) + + inputset = helpers.generate_inputset(parameters) + circuit = compiler.compile(inputset, configuration) + + for i in range(2): + for j in range(8): + helpers.check_execution(circuit, function, [i, j]) + + +@pytest.mark.parametrize( + "function", + [ + pytest.param( + lambda x, y: x >> y, + id="x >> y", + ), + ], +) +@pytest.mark.parametrize( + "parameters", + [ + { + "x": {"range": [0, 1 << 7], "status": "encrypted"}, + "y": {"range": [0, 7], "status": "encrypted"}, + }, + ], +) +def test_right_shift_coverage(function, parameters, helpers): + """ + Test right shift between encrypted integers all cases. + """ + + parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) + configuration = helpers.configuration() + + compiler = cnp.Compiler(function, parameter_encryption_statuses) + + inputset = helpers.generate_inputset(parameters) + circuit = compiler.compile(inputset, configuration) + + helpers.check_execution(circuit, function, [0b11, 0]) + helpers.check_execution(circuit, function, [0b11, 1]) + helpers.check_execution(circuit, function, [0b110, 2]) + helpers.check_execution(circuit, function, [0b1100, 3]) + helpers.check_execution(circuit, function, [0b11000, 4]) + helpers.check_execution(circuit, function, [0b110000, 5]) + helpers.check_execution(circuit, function, [0b110000, 6]) + helpers.check_execution(circuit, function, [0b1100000, 7]) diff --git a/tests/mlir/test_graph_converter.py b/tests/mlir/test_graph_converter.py index 1fbfac691..72c8da990 100644 --- a/tests/mlir/test_graph_converter.py +++ b/tests/mlir/test_graph_converter.py @@ -400,6 +400,40 @@ Subgraphs: %5 = astype(%4, dtype=int_) # EncryptedScalar return %5 + """, # noqa: E501 + ), + pytest.param( + lambda x, y: x << y, + {"x": "encrypted", "y": "encrypted"}, + [(-1, 1), (-2, 3)], + RuntimeError, + """ + +Function you are trying to compile cannot be converted to MLIR + +%0 = x # EncryptedScalar ∈ [-2, -1] +%1 = y # EncryptedScalar ∈ [1, 3] +%2 = left_shift(%0, %1) # EncryptedScalar ∈ [-16, -2] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned bitwise operations are supported +return %2 + + """, # noqa: E501 + ), + pytest.param( + lambda x, y: x << y, + {"x": "encrypted", "y": "encrypted"}, + [(1, 20), (2, 10)], + RuntimeError, + """ + +Function you are trying to compile cannot be converted to MLIR + +%0 = x # EncryptedScalar ∈ [1, 2] +%1 = y # EncryptedScalar ∈ [10, 20] +%2 = left_shift(%0, %1) # EncryptedScalar ∈ [2048, 1048576] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only up to 4-bit shifts are supported +return %2 + """, # noqa: E501 ), ],