diff --git a/concrete/numpy/mlir/__init__.py b/concrete/numpy/mlir/__init__.py new file mode 100644 index 000000000..262d3d042 --- /dev/null +++ b/concrete/numpy/mlir/__init__.py @@ -0,0 +1,6 @@ +""" +Declaration of `concrete.numpy.mlir` namespace. +""" + +from .graph_converter import GraphConverter +from .node_converter import NodeConverter diff --git a/concrete/numpy/mlir/graph_converter.py b/concrete/numpy/mlir/graph_converter.py new file mode 100644 index 000000000..70c75877d --- /dev/null +++ b/concrete/numpy/mlir/graph_converter.py @@ -0,0 +1,380 @@ +""" +Declaration of `GraphConverter` class. +""" + +# pylint: disable=no-member,no-name-in-module + +from copy import deepcopy +from typing import Dict, List, Optional, cast + +import concrete.lang as concretelang +import networkx as nx +import numpy as np +from mlir.dialects import builtin +from mlir.ir import Context, InsertionPoint, Location, Module + +from ..dtypes import Integer, SignedInteger +from ..internal.utils import assert_that +from ..representation import Graph, Node, Operation +from ..values import ClearScalar +from .node_converter import NodeConverter +from .utils import MAXIMUM_BIT_WIDTH + +# pylint: enable=no-member,no-name-in-module + + +class GraphConverter: + """ + GraphConverter class, to convert computation graphs to their MLIR equivalent. + """ + + @staticmethod + def _check_node_convertibility(graph: Graph, node: Node) -> Optional[str]: + """ + Check node convertibility to MLIR. 
+ + Args: + graph (Graph): + computation graph of the node + + node (Node): + node to be checked + + Returns: + Optional[str]: + None if node is convertible to MLIR, the reason for inconvertibility otherwise + """ + + # pylint: disable=too-many-branches,too-many-return-statements + + inputs = node.inputs + output = node.output + + if node.operation == Operation.Constant: + assert_that(len(inputs) == 0) + if not isinstance(output.dtype, Integer): + return "only integer constants are supported" + + elif node.operation == Operation.Input: + assert_that(len(inputs) == 1) + assert_that(inputs[0] == output) + if not isinstance(output.dtype, Integer) or output.dtype.is_signed: + return "only unsigned integer inputs are supported" + + else: + assert_that(node.operation == Operation.Generic) + + if not isinstance(output.dtype, Integer): + return "only integer operations are supported" + + name = node.properties["name"] + + if name == "add": + assert_that(len(inputs) == 2) + + elif name == "concatenate": + if not all(input.is_encrypted for input in inputs): + return "only all encrypted concatenate is supported" + + elif name == "conv2d": + assert_that(len(inputs) == 2 or len(inputs) == 3) + if not (inputs[0].is_encrypted and inputs[1].is_clear): + return "only conv2d with encrypted input and clear weight is supported" + + elif name == "dot": + assert_that(len(inputs) == 2) + if inputs[0].is_encrypted and inputs[1].is_encrypted: + return "only dot product between encrypted and clear is supported" + + elif name == "index.static": + assert_that(len(inputs) == 1) + if not inputs[0].is_encrypted: + return "only encrypted indexing supported" + + elif name == "matmul": + assert_that(len(inputs) == 2) + if inputs[0].is_encrypted and inputs[1].is_encrypted: + return "only matrix multiplication between encrypted and clear is supported" + + elif name == "multiply": + assert_that(len(inputs) == 2) + if inputs[0].is_encrypted and inputs[1].is_encrypted: + return "only multiplication 
between encrypted and clear is supported" + + elif name == "negative": + assert_that(len(inputs) == 1) + if not inputs[0].is_encrypted: + return "only encrypted negation is supported" + + elif name == "reshape": + assert_that(len(inputs) == 1) + if not inputs[0].is_encrypted: + return "only encrypted reshape is supported" + + elif name == "subtract": + assert_that(len(inputs) == 2) + if not (inputs[0].is_clear and inputs[1].is_encrypted): + return "only subtraction of encrypted from clear is supported" + + elif name == "sum": + assert_that(len(inputs) == 1) + if not inputs[0].is_encrypted: + return "only encrypted sum is supported" + + else: + variable_input_indices = [ + idx + for idx, pred in enumerate(graph.ordered_preds_of(node)) + if not pred.operation == Operation.Constant + ] + if len(variable_input_indices) != 1: + return "only single input table lookups are supported" + + if all(input.is_clear for input in inputs): + return "one of the operands must be encrypted" + + return None + + # pylint: enable=too-many-branches,too-many-return-statements + + @staticmethod + def _check_graph_convertibility(graph: Graph): + """ + Check graph convertibility to MLIR. 
+ + Args: + graph (Graph): + computation graph to be checked + + Raises: + RuntimeError: + if `graph` is not convertible to MLIR + """ + + offending_nodes = {} + + if len(graph.output_nodes) > 1: + offending_nodes.update( + { + node: ["only a single output is supported"] + for node in graph.output_nodes.values() + } + ) + + if len(offending_nodes) == 0: + for node in graph.graph.nodes: + if (reason := GraphConverter._check_node_convertibility(graph, node)) is not None: + offending_nodes[node] = [reason] + + if len(offending_nodes) != 0: + raise RuntimeError( + "Function you are trying to compile cannot be converted to MLIR\n\n" + + graph.format(highlighted_nodes=offending_nodes) + ) + + @staticmethod + def _update_bit_widths(graph: Graph): + """ + Update bit-widths in a computation graph to be convertible to MLIR. + + Args: + graph (Graph): + computation graph to be updated + """ + + offending_nodes: Dict[Node, List[str]] = {} + + max_bit_width = 0 + for node in graph.graph.nodes: + dtype = node.output.dtype + assert_that(isinstance(dtype, Integer)) + + current_node_bit_width = ( + dtype.bit_width - 1 if node.output.is_clear else dtype.bit_width + ) + max_bit_width = max(max_bit_width, current_node_bit_width) + + if current_node_bit_width > MAXIMUM_BIT_WIDTH: + offending_nodes[node] = [ + f"only up to {MAXIMUM_BIT_WIDTH}-bit integers are supported" + ] + + if len(offending_nodes) != 0: + raise RuntimeError( + "Function you are trying to compile cannot be converted to MLIR:\n\n" + + graph.format(highlighted_nodes=offending_nodes) + ) + + for node in graph.graph.nodes: + for value in node.inputs + [node.output]: + dtype = value.dtype + assert_that(isinstance(dtype, Integer)) + dtype.bit_width = max_bit_width + 1 if value.is_clear else max_bit_width + + @staticmethod + def _offset_negative_lookup_table_inputs(graph: Graph): + """ + Offset negative table lookup inputs to be convertible to MLIR. 
+ + Args: + graph (Graph): + computation graph to apply offset + """ + + # ugly hack to add an offset before entering a TLU + # if its variable input node has a signed output. + # this makes hardcoded assumptions about the way bit widths are handled in MLIR. + # this does not update the TLU input values to allow for proper table generation. + + nx_graph = graph.graph + for node in list(nx_graph.nodes): + if node.operation == Operation.Generic: + if not node.converted_to_table_lookup: + continue + + variable_input_index = -1 + + preds = graph.ordered_preds_of(node) + for index, pred in enumerate(preds): + if pred.operation != Operation.Constant: + variable_input_index = index + break + + variable_input_node = preds[variable_input_index] + + variable_input_value = variable_input_node.output + variable_input_dtype = variable_input_value.dtype + + assert_that(isinstance(variable_input_dtype, Integer)) + variable_input_dtype = cast(Integer, variable_input_dtype) + + if not variable_input_dtype.is_signed: + continue + + variable_input_bit_width = variable_input_dtype.bit_width + offset_constant_dtype = SignedInteger(variable_input_bit_width + 1) + + offset_constant = Node.constant(abs(variable_input_dtype.min())) + offset_constant.output.dtype = offset_constant_dtype + + add_offset = Node.generic( + "add", + [variable_input_value, ClearScalar(offset_constant_dtype)], + variable_input_value, + np.add, + ) + + nx_graph.remove_edge(variable_input_node, node) + + nx_graph.add_edge(variable_input_node, add_offset, input_idx=0) + nx_graph.add_edge(offset_constant, add_offset, input_idx=1) + + nx_graph.add_edge(add_offset, node, input_idx=variable_input_index) + + @staticmethod + def convert(graph: Graph) -> str: + """ + Convert a computation graph to its corresponding MLIR representation. 
+ + Args: + graph (Graph): + computation graph to be converted + + Returns: + str: + textual MLIR representation corresponding to `graph` + """ + + graph = deepcopy(graph) + + GraphConverter._check_graph_convertibility(graph) + GraphConverter._update_bit_widths(graph) + GraphConverter._offset_negative_lookup_table_inputs(graph) + + # There are no tensor +*- scalar operations in the compiler + # But such operations are used commonly, so we need to support them + # So, we implemented some workarounds (pull request #970) + # Once we have native support, this workaround shall be removed (issue #837) + # (most changes in #970 shall be reverted) + + # { node1: "%arg0", node2: "%0", node3: "%1" } + nodes_to_mlir_names: Dict[Node, str] = {} + + # { "%arg0": "i5", "%0": "tensor<2x3x!FHE.eint<4>>" } + mlir_names_to_mlir_types: Dict[str, str] = {} + + # { "%0": ["%c1_i5"] } == for %0 we need to convert %c1_i5 to 1d tensor + scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]] = {} + + with Context() as ctx, Location.unknown(): + concretelang.register_dialects(ctx) + + module = Module.create() + with InsertionPoint(module.body): + parameters = [ + NodeConverter.value_to_mlir_type(ctx, input_node.output) + for input_node in graph.ordered_inputs() + ] + + @builtin.FuncOp.from_py_func(*parameters) + def main(*arg): + ir_to_mlir = {} + for arg_num, node in graph.input_nodes.items(): + ir_to_mlir[node] = arg[arg_num] + + mlir_name = f"%arg{arg_num}" + nodes_to_mlir_names[node] = mlir_name + mlir_names_to_mlir_types[mlir_name] = str(parameters[arg_num]) + + for node in nx.topological_sort(graph.graph): + if node.operation == Operation.Input: + continue + + preds = [ir_to_mlir[pred] for pred in graph.ordered_preds_of(node)] + node_converter = NodeConverter( + ctx, + graph, + node, + preds, + nodes_to_mlir_names, + mlir_names_to_mlir_types, + scalar_to_1d_tensor_conversion_hacks, + ) + ir_to_mlir[node] = node_converter.convert() + + results = (ir_to_mlir[output_node] for 
output_node in graph.ordered_outputs()) + return results + + module_lines_after_hacks_are_applied = [] + for line in str(module).split("\n"): + mlir_name = line.split("=")[0].strip() + if mlir_name not in scalar_to_1d_tensor_conversion_hacks: + module_lines_after_hacks_are_applied.append(line) + continue + + to_be_replaced = scalar_to_1d_tensor_conversion_hacks[mlir_name] + for arg_name in to_be_replaced: + new_name = f"%hack_{mlir_name.replace('%', '')}_{arg_name.replace('%', '')}" + mlir_type = mlir_names_to_mlir_types[arg_name] + + hack_line = ( + f" {new_name} = tensor.from_elements {arg_name} : tensor<1x{mlir_type}>" + ) + module_lines_after_hacks_are_applied.append(hack_line) + + line = line.replace(arg_name, new_name) + + new_arg_types = [] + + arg_types = line.split(":")[1].split("->")[0].strip()[1:-1] + for arg in arg_types.split(", "): + if arg.startswith("tensor"): + new_arg_types.append(arg) + else: + new_arg_types.append(f"tensor<1x{arg}>") + + line = line.replace(arg_types, ", ".join(new_arg_types)) + + module_lines_after_hacks_are_applied.append(line) + + return "\n".join(module_lines_after_hacks_are_applied).strip() diff --git a/concrete/numpy/mlir/node_converter.py b/concrete/numpy/mlir/node_converter.py new file mode 100644 index 000000000..bb36e2d93 --- /dev/null +++ b/concrete/numpy/mlir/node_converter.py @@ -0,0 +1,800 @@ +""" +Declaration of `NodeConverter` class. 
+""" + +# pylint: disable=no-member,no-name-in-module + +from typing import Dict, List, Tuple + +import numpy as np +from concrete.lang.dialects import fhe, fhelinalg +from concrete.lang.dialects.fhe import EncryptedIntegerType +from mlir.dialects import arith, linalg, tensor +from mlir.ir import ( + ArrayAttr, + Attribute, + BoolAttr, + Context, + DenseElementsAttr, + IndexType, + IntegerAttr, + IntegerType, + OpResult, + RankedTensorType, + Type, +) + +from ..dtypes import Integer +from ..internal.utils import assert_that +from ..representation import Graph, Node, Operation +from ..values import Value +from .utils import construct_deduplicated_tables + +# pylint: enable=no-member,no-name-in-module + + +class NodeConverter: + """ + NodeConverter class, to convert computation graph nodes to their MLIR equivalent. + """ + + ctx: Context + graph: Graph + node: Node + preds: List[OpResult] + + all_of_the_inputs_are_encrypted: bool + all_of_the_inputs_are_tensors: bool + one_of_the_inputs_is_a_tensor: bool + + nodes_to_mlir_names: Dict[Node, str] + mlir_names_to_mlir_types: Dict[str, str] + scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]] + + @staticmethod + def value_to_mlir_type(ctx: Context, value: Value) -> Type: + """ + Convert a `Value` to its corresponding MLIR `Type`. 
+ + Args: + ctx (Context): + MLIR context to perform the conversion + + value (Value): + value to convert + + Returns: + Type: + MLIR `Type` corresponding to `value` + """ + + dtype = value.dtype + + if isinstance(dtype, Integer): + if value.is_encrypted: + result = EncryptedIntegerType.get(ctx, dtype.bit_width) + else: + result = IntegerType.get_signless(dtype.bit_width) + + return result if value.is_scalar else RankedTensorType.get(value.shape, result) + + # the branch above is always taken due to compatibility checks + # still, it's a good idea to raise an appropriate error, just in case + + raise ValueError(f"{value} cannot be converted to MLIR") # pragma: no cover + + def __init__( + self, + ctx: Context, + graph: Graph, + node: Node, + preds: List[OpResult], + nodes_to_mlir_names: Dict[OpResult, str], + mlir_names_to_mlir_types: Dict[str, str], + scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]], + ): + self.ctx = ctx + self.graph = graph + self.node = node + self.preds = preds + + self.all_of_the_inputs_are_encrypted = True + self.all_of_the_inputs_are_tensors = True + self.one_of_the_inputs_is_a_tensor = False + + for inp in node.inputs: + if not inp.is_encrypted: + self.all_of_the_inputs_are_encrypted = False + + if inp.is_scalar: + self.all_of_the_inputs_are_tensors = False + else: + self.one_of_the_inputs_is_a_tensor = True + + self.nodes_to_mlir_names = nodes_to_mlir_names + self.mlir_names_to_mlir_types = mlir_names_to_mlir_types + self.scalar_to_1d_tensor_conversion_hacks = scalar_to_1d_tensor_conversion_hacks + + def convert(self) -> OpResult: + """ + Convert a node to its corresponding MLIR representation. 
+ + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + # pylint: disable=too-many-branches + + if self.node.operation == Operation.Constant: + result = self.convert_constant() + else: + assert_that(self.node.operation == Operation.Generic) + + name = self.node.properties["name"] + + if name == "add": + result = self.convert_add() + + elif name == "concatenate": + result = self.convert_concat() + + elif name == "conv2d": + result = self.convert_conv2d() + + elif name == "dot": + result = self.convert_dot() + + elif name == "index.static": + result = self.convert_static_indexing() + + elif name == "matmul": + result = self.convert_matmul() + + elif name == "multiply": + result = self.convert_mul() + + elif name == "negative": + result = self.convert_neg() + + elif name == "reshape": + result = self.convert_reshape() + + elif name == "subtract": + result = self.convert_sub() + + elif name == "sum": + result = self.convert_sum() + + else: + result = self.convert_tlu() + + mlir_name = str(result).replace("Value(", "").split("=", maxsplit=1)[0].strip() + + self.nodes_to_mlir_names[self.node] = mlir_name + self.mlir_names_to_mlir_types[mlir_name] = str(result.type) + + if self.node.operation == Operation.Generic: + name = self.node.properties["name"] + if name in ["add", "dot", "multiply", "subtract"]: + if self.one_of_the_inputs_is_a_tensor and not self.all_of_the_inputs_are_tensors: + to_be_converted = [] + for pred in self.graph.ordered_preds_of(self.node): + if pred.output.is_scalar: + to_be_converted.append(self.nodes_to_mlir_names[pred]) + self.scalar_to_1d_tensor_conversion_hacks[mlir_name] = to_be_converted + + return result + + # pylint: enable=too-many-branches + + def convert_add(self) -> OpResult: + """ + Convert "add" node to its corresponding MLIR representation. 
+ + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output) + preds = self.preds + + if self.all_of_the_inputs_are_encrypted: + if self.one_of_the_inputs_is_a_tensor: + result = fhelinalg.AddEintOp(resulting_type, *preds).result + else: + result = fhe.AddEintOp(resulting_type, *preds).result + else: + if self.node.inputs[0].is_clear: + preds = preds[::-1] + + if self.one_of_the_inputs_is_a_tensor: + result = fhelinalg.AddEintIntOp(resulting_type, *preds).result + else: + result = fhe.AddEintIntOp(resulting_type, *preds).result + + return result + + def convert_concat(self) -> OpResult: + """ + Convert "concatenate" node to its corresponding MLIR representation. + + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output) + axis = self.node.properties["kwargs"].get("axis", 0) + + if axis is not None: + if axis < 0: + axis += len(self.node.inputs[0].shape) + return fhelinalg.ConcatOp( + resulting_type, + self.preds, + IntegerAttr.get(IntegerType.get_signless(64), axis), + ).result + + flattened_preds = [] + for pred, input_value in zip(self.preds, self.node.inputs): + input_shape = input_value.shape + input_size = np.prod(input_shape) + + flattened_pred_type = RankedTensorType.get( + [input_size], + NodeConverter.value_to_mlir_type( + self.ctx, + Value(input_value.dtype, shape=(), is_encrypted=input_value.is_encrypted), + ), + ) + flattened_pred = linalg.TensorCollapseShapeOp( + flattened_pred_type, + pred, + ArrayAttr.get( + [ + ArrayAttr.get( + [ + IntegerAttr.get(IndexType.parse("index"), i) + for i in range(len(input_shape)) + ] + ) + ] + ), + ).result + flattened_preds.append(flattened_pred) + + return fhelinalg.ConcatOp( + resulting_type, + flattened_preds, + IntegerAttr.get(IntegerType.get_signless(64), 0), + ).result + + def 
convert_constant(self) -> OpResult: + """ + Convert Operation.Constant node to its corresponding MLIR representation. + + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output) + data = self.node() + + if self.node.output.is_scalar: + attr = IntegerAttr.get(resulting_type, data) + else: + # usage of `Attribute.parse` is the result of some limitations in the MLIR module + # provided by LLVM + + # what should have been used is `DenseElementsAttr` but it's impossible to assign + # custom bit-widths using it (e.g., uint5) + + # since we couldn't create a `DenseElementsAttr` with a custom bit width using + # the python api we use `Attribute.parse` to let the underlying library do it by itself + + attr = Attribute.parse(f"dense<{str(data.tolist())}> : {resulting_type}") + + return arith.ConstantOp(resulting_type, attr).result + + def convert_conv2d(self) -> OpResult: + """ + Convert "conv2d" node to its corresponding MLIR representation. 
+ + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output) + preds = self.preds + + integer_type = IntegerType.get_signless(64, context=self.ctx) + + strides = DenseElementsAttr.get( + np.array(list(self.node.properties["kwargs"]["strides"]), dtype=np.uint64), + type=integer_type, + context=self.ctx, + ) + dilations = DenseElementsAttr.get( + np.array(list(self.node.properties["kwargs"]["dilations"]), dtype=np.uint64), + type=integer_type, + context=self.ctx, + ) + pads = DenseElementsAttr.get( + np.array(list(self.node.properties["kwargs"]["pads"]), dtype=np.uint64), + type=integer_type, + context=self.ctx, + ) + + has_bias = len(self.node.inputs) == 3 + if not has_bias: + preds.append(None) + + return fhelinalg.Conv2dOp(resulting_type, *preds, pads, strides, dilations).result + + def convert_dot(self) -> OpResult: + """ + Convert "dot" node to its corresponding MLIR representation. + + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output) + preds = self.preds + + if self.node.inputs[0].is_clear: + preds = preds[::-1] + + if self.all_of_the_inputs_are_tensors: + # numpy.dot(x, y) where x and y are both vectors = regular dot product + result = fhelinalg.Dot(resulting_type, *preds).result + + elif not self.one_of_the_inputs_is_a_tensor: + # numpy.dot(x, y) where x and y are both scalars = x * y + result = fhe.MulEintIntOp(resulting_type, *preds).result + + else: + # numpy.dot(x, y) where one of x or y is a scalar and the other one is a vector = x * y + result = fhelinalg.MulEintIntOp(resulting_type, *preds).result + + return result + + def convert_matmul(self) -> OpResult: + """Convert a MatMul node to its corresponding MLIR representation. 
+ + Returns: + str: textual MLIR representation corresponding to self.node + """ + + resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output) + preds = self.preds + + if self.node.output.shape == (): + if self.node.inputs[0].is_clear: + preds = preds[::-1] + result = fhelinalg.Dot(resulting_type, *preds).result + + elif self.node.inputs[0].is_clear: + result = fhelinalg.MatMulIntEintOp(resulting_type, *preds).result + else: + result = fhelinalg.MatMulEintIntOp(resulting_type, *preds).result + + return result + + def convert_mul(self) -> OpResult: + """ + Convert "multiply" node to its corresponding MLIR representation. + + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output) + preds = self.preds + + if self.node.inputs[0].is_clear: + preds = preds[::-1] + + if self.one_of_the_inputs_is_a_tensor: + result = fhelinalg.MulEintIntOp(resulting_type, *preds).result + else: + result = fhe.MulEintIntOp(resulting_type, *preds).result + + return result + + def convert_neg(self) -> OpResult: + """ + Convert "negative" node to its corresponding MLIR representation. + + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output) + pred = self.preds[0] + + if self.one_of_the_inputs_is_a_tensor: + result = fhelinalg.NegEintOp(resulting_type, pred).result + else: + result = fhe.NegEintOp(resulting_type, pred).result + + return result + + def convert_reshape(self) -> OpResult: + """ + Convert "reshape" node to its corresponding MLIR representation. 
+ + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + input_shape = self.node.inputs[0].shape + output_shape = self.node.output.shape + + pred = self.preds[0] + if input_shape == output_shape: + return pred + + # we can either collapse or expand, which changes the number of dimensions + # this is a limitation of the current compiler, it will be improved in the future (#1060) + can_be_converted_directly = len(input_shape) != len(output_shape) + + reassociation: List[List[int]] = [] + if can_be_converted_directly: + if len(output_shape) == 1: + # output is 1 dimensional so collapse every dimension into the same dimension + reassociation.append(list(range(len(input_shape)))) + else: + # input is m dimensional + # output is n dimensional + # and m is different from n + + # we don't want to duplicate code, so we forget about input and output, + # and we focus on smaller shape and bigger shape + + smaller_shape, bigger_shape = ( + (output_shape, input_shape) + if len(output_shape) < len(input_shape) + else (input_shape, output_shape) + ) + s_index, b_index = 0, 0 + + # now we will figure out how to group the bigger shape to get the smaller shape + # think of the algorithm below as + # keep merging the dimensions of the bigger shape + # until we have a match on the smaller shape + # then try to match the next dimension of the smaller shape + # if all dimensions of the smaller shape is matched + # we can convert it + + group = [] + size = 1 + while s_index < len(smaller_shape) and b_index < len(bigger_shape): + # dimension `b_index` of `bigger_shape` belongs to current group + group.append(b_index) + + # and current group has `size * bigger_shape[b_index]` elements now + size *= bigger_shape[b_index] + + # if current group size matches the dimension `s_index` of `smaller_shape` + if size == smaller_shape[s_index]: + # we finalize this group and reset everything + size = 1 + reassociation.append(group) + group = [] + + # now try to 
match the next dimension of `smaller_shape` + s_index += 1 + + # now process the next dimension of `bigger_shape` + b_index += 1 + + # handle the case where bigger shape has proceeding 1s + # e.g., (5,) -> (5, 1) + while b_index < len(bigger_shape) and bigger_shape[b_index] == 1: + reassociation[-1].append(b_index) + b_index += 1 + + # if not all dimensions of both shapes are processed exactly + if s_index != len(smaller_shape) or b_index != len(bigger_shape): + # we cannot convert + can_be_converted_directly = False + + index_type = IndexType.parse("index") + resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output) + + if can_be_converted_directly: + reassociation_attr = ArrayAttr.get( + [ + ArrayAttr.get([IntegerAttr.get(index_type, dimension) for dimension in group]) + for group in reassociation + ] + ) + if len(output_shape) < len(input_shape): + return linalg.TensorCollapseShapeOp(resulting_type, pred, reassociation_attr).result + return linalg.TensorExpandShapeOp(resulting_type, pred, reassociation_attr).result + + flattened_type = NodeConverter.value_to_mlir_type( + self.ctx, + Value( + dtype=self.node.inputs[0].dtype, + shape=(np.prod(input_shape),), + is_encrypted=self.node.inputs[0].is_encrypted, + ), + ) + flattened_result = linalg.TensorCollapseShapeOp( + flattened_type, + pred, + ArrayAttr.get( + [ArrayAttr.get([IntegerAttr.get(index_type, i) for i in range(len(input_shape))])] + ), + ).result + + return linalg.TensorExpandShapeOp( + resulting_type, + flattened_result, + ArrayAttr.get( + [ArrayAttr.get([IntegerAttr.get(index_type, i) for i in range(len(output_shape))])] + ), + ).result + + def convert_static_indexing(self) -> OpResult: + """ + Convert "index.static" node to its corresponding MLIR representation. 
+ + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + input_value = self.node.inputs[0] + input_shape = input_value.shape + + index = list(self.node.properties["attributes"]["index"]) + index_type = IndexType.parse("index") + + while len(index) < input_value.ndim: + index.append(slice(None, None, None)) + + output_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output) + if len(index) == len(input_shape) and all(isinstance(i, (int, np.integer)) for i in index): + indices = [] + for value, dimension_size in zip(index, input_shape): + value = int(value) + attr = IntegerAttr.get(index_type, value if value >= 0 else value + dimension_size) + indices.append(arith.ConstantOp(index_type, attr).result) + return tensor.ExtractOp(output_type, self.preds[0], indices).result + + offsets = [] + sizes = [] + strides = [] + + destroyed_dimensions = [] + for dimension, (indexing_element, dimension_size) in enumerate(zip(index, input_shape)): + + if isinstance(indexing_element, slice): + size = np.zeros(dimension_size)[indexing_element].shape[0] + stride = indexing_element.step if isinstance(indexing_element.step, int) else 1 + offset = ( + ( + indexing_element.start + if indexing_element.start >= 0 + else indexing_element.start + dimension_size + ) + if isinstance(indexing_element.start, int) + else (0 if stride > 0 else dimension_size - 1) + ) + + else: + destroyed_dimensions.append(dimension) + size = 1 + stride = 1 + offset = int( + indexing_element if indexing_element >= 0 else indexing_element + dimension_size + ) + + offsets.append(offset) + sizes.append(size) + strides.append(stride) + + if len(destroyed_dimensions) == 0: + return tensor.ExtractSliceOp( + output_type, + self.preds[0], + [], + [], + [], + ArrayAttr.get([IntegerAttr.get(index_type, value) for value in offsets]), + ArrayAttr.get([IntegerAttr.get(index_type, value) for value in sizes]), + ArrayAttr.get([IntegerAttr.get(index_type, value) for value in 
strides]), + ).result + + output_value = self.node.output + + intermediate_shape = list(output_value.shape) + for dimension in destroyed_dimensions: + intermediate_shape.insert(dimension, 1) + + intermediate = tensor.ExtractSliceOp( + RankedTensorType.get( + intermediate_shape, + NodeConverter.value_to_mlir_type( + self.ctx, + Value(output_value.dtype, shape=(), is_encrypted=output_value.is_encrypted), + ), + ), + self.preds[0], + [], + [], + [], + ArrayAttr.get([IntegerAttr.get(index_type, value) for value in offsets]), + ArrayAttr.get([IntegerAttr.get(index_type, value) for value in sizes]), + ArrayAttr.get([IntegerAttr.get(index_type, value) for value in strides]), + ).result + + reassociaton = [] + + current_intermediate_dimension = 0 + for _ in range(len(output_value.shape)): + indices = [current_intermediate_dimension] + while current_intermediate_dimension in destroyed_dimensions: + current_intermediate_dimension += 1 + indices.append(current_intermediate_dimension) + + reassociaton.append(indices) + current_intermediate_dimension += 1 + while current_intermediate_dimension < len(intermediate_shape): + reassociaton[-1].append(current_intermediate_dimension) + current_intermediate_dimension += 1 + + return linalg.TensorCollapseShapeOp( + output_type, + intermediate, + ArrayAttr.get( + [ + ArrayAttr.get( + [IntegerAttr.get(index_type, index) for index in indices], + ) + for indices in reassociaton + ], + ), + ).result + + def convert_sub(self) -> OpResult: + """ + Convert "subtract" node to its corresponding MLIR representation. 
+ + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output) + preds = self.preds + + if self.one_of_the_inputs_is_a_tensor: + result = fhelinalg.SubIntEintOp(resulting_type, *preds).result + else: + result = fhe.SubIntEintOp(resulting_type, *preds).result + + return result + + def convert_sum(self) -> OpResult: + """ + Convert "sum" node to its corresponding MLIR representation. + + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output) + + axes = self.node.properties["kwargs"].get("axis", []) + keep_dims = self.node.properties["kwargs"].get("keepdims", False) + + if isinstance(axes, int): + axes = [axes] + elif isinstance(axes, tuple): + axes = list(axes) + + input_dimensions = self.node.inputs[0].ndim + for i, axis in enumerate(axes): + if axis < 0: + axes[i] += input_dimensions + + return fhelinalg.SumOp( + resulting_type, + self.preds[0], + ArrayAttr.get([IntegerAttr.get(IntegerType.get_signless(64), axis) for axis in axes]), + BoolAttr.get(keep_dims), + ).result + + def convert_tlu(self) -> OpResult: + """ + Convert Operation.Generic node to its corresponding MLIR representation. + + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + variable_input_index = -1 + + preds = self.graph.ordered_preds_of(self.node) + for i, pred in enumerate(preds): + if pred.operation != Operation.Constant: + variable_input_index = i + break + + assert_that(variable_input_index != -1) + + tables = construct_deduplicated_tables(self.node, preds) + assert_that(len(tables) > 0) + + lut_shape: Tuple[int, ...] = () + map_shape: Tuple[int, ...] 
= () + + if len(tables) == 1: + table = tables[0][0] + + # The reduction on 63b is to avoid problems like doing a TLU of + # the form T[j] = 2< bool: + return isinstance(other, HashableNdarray) and np.array_equal(self.array, other.array) + + def __hash__(self) -> int: + return hash(self.array.tobytes()) + + +def flood_replace_none_values(table: list): + """ + Use flooding algorithm to replace `None` values. + + Args: + table (list): + the list in which there are `None` values that need to be replaced + with copies of the closest non `None` data from the list + """ + + assert_that(any(value is not None for value in table)) + + not_none_values_idx = deque(idx for idx, value in enumerate(table) if value is not None) + while not_none_values_idx: + + current_idx = not_none_values_idx.popleft() + current_value = table[current_idx] + + previous_idx = current_idx - 1 + next_idx = current_idx + 1 + + if previous_idx >= 0 and table[previous_idx] is None: + table[previous_idx] = deepcopy(current_value) + not_none_values_idx.append(previous_idx) + + if next_idx < len(table) and table[next_idx] is None: + table[next_idx] = deepcopy(current_value) + not_none_values_idx.append(next_idx) + + assert_that(all(value is not None for value in table)) + + +def construct_table(node: Node, preds: List[Node]) -> List[Any]: + """ + Construct the lookup table for an Operation.Generic node. 
+ + Args: + node (Node): + Operation.Generic to construct the table + + preds (List[Node]): + ordered predecessors to `node` + + Returns: + List[Any]: + lookup table corresponding to `node` and its input value + """ + + variable_input_index = -1 + for index, pred in enumerate(preds): + if pred.operation != Operation.Constant: + variable_input_index = index + break + assert_that(variable_input_index != -1) + + variable_input_dtype = node.inputs[variable_input_index].dtype + variable_input_shape = node.inputs[variable_input_index].shape + + assert_that(isinstance(variable_input_dtype, Integer)) + variable_input_dtype = cast(Integer, variable_input_dtype) + + inputs: List[Any] = [pred() if pred.operation == Operation.Constant else None for pred in preds] + + table: List[Optional[Union[np.bool_, np.integer, np.floating, np.ndarray]]] = [] + for value in range(variable_input_dtype.min(), variable_input_dtype.max() + 1): + try: + inputs[variable_input_index] = np.ones(variable_input_shape, dtype=np.int64) * value + table.append(node(*inputs)) + except Exception: # pylint: disable=broad-except + # here we try our best to fill the table + # if it fails, we append None and let flooding algorithm replace None values below + table.append(None) + + flood_replace_none_values(table) + + return table + + +def construct_deduplicated_tables( + node: Node, + preds: List[Node], +) -> Tuple[Tuple[np.ndarray, List[Tuple[int, ...]]], ...]: + """ + Construct lookup tables for each cell of the input for an Operation.Generic node. + + Args: + node (Node): + Operation.Generic to construct the table + + preds (List[Node]): + ordered predecessors to `node` + + Returns: + Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]: + tuple containing tuples of 2 for + - constructed table + - list of indices of the input that use the constructed table + + e.g., + + ..
code-block:: python + + ( + (np.array([3, 1, 2, 4]), [(1, 0), (2, 1)]), + (np.array([5, 8, 6, 7]), [(0, 0), (0, 1), (1, 1), (2, 0)]), + ) + + means the lookup on 3x2 input will result in + + .. code-block:: python + + [ [5, 8, 6, 7][input[0, 0]] , [5, 8, 6, 7][input[0, 1]] ] + [ [3, 1, 2, 4][input[1, 0]] , [5, 8, 6, 7][input[1, 1]] ] + [ [5, 8, 6, 7][input[2, 0]] , [3, 1, 2, 4][input[2, 1]] ] + """ + + node_complete_table = np.concatenate( + tuple(np.expand_dims(array, -1) for array in construct_table(node, preds)), + axis=-1, + ) + + all_cells_idx = product(*tuple(range(max_val) for max_val in node_complete_table.shape[:-1])) + tables_to_cell_idx: DefaultDict[HashableNdarray, List[Tuple[int, ...]]] = defaultdict(list) + + idx: Tuple[int, ...] + all_idx_set = set() + for idx in all_cells_idx: + hashable_array = HashableNdarray(node_complete_table[idx]) + tables_to_cell_idx[hashable_array].append(idx) + all_idx_set.add(idx) + + assert_that(len(all_idx_set) == math.prod(node_complete_table.shape[:-1])) + + return tuple( + (hashable_array.array, indices) for hashable_array, indices in tables_to_cell_idx.items() + )