feat: implement mlir module

Umut
2022-04-04 13:29:09 +02:00
parent 58328aa42f
commit 92651a12ee
4 changed files with 1358 additions and 0 deletions

concrete/numpy/mlir/__init__.py

@@ -0,0 +1,6 @@
"""
Declaration of `concrete.numpy.mlir` namespace.
"""
from .graph_converter import GraphConverter
from .node_converter import NodeConverter

concrete/numpy/mlir/graph_converter.py

@@ -0,0 +1,380 @@
"""
Declaration of `GraphConverter` class.
"""
# pylint: disable=no-member,no-name-in-module
from copy import deepcopy
from typing import Dict, List, Optional, cast
import concrete.lang as concretelang
import networkx as nx
import numpy as np
from mlir.dialects import builtin
from mlir.ir import Context, InsertionPoint, Location, Module
from ..dtypes import Integer, SignedInteger
from ..internal.utils import assert_that
from ..representation import Graph, Node, Operation
from ..values import ClearScalar
from .node_converter import NodeConverter
from .utils import MAXIMUM_BIT_WIDTH
# pylint: enable=no-member,no-name-in-module
class GraphConverter:
"""
GraphConverter class, to convert computation graphs to their MLIR equivalent.
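    A minimal usage sketch, assuming `graph` is a `Graph` traced by the frontend:

        mlir_text = GraphConverter.convert(graph)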
"""
@staticmethod
def _check_node_convertibility(graph: Graph, node: Node) -> Optional[str]:
"""
Check node convertibility to MLIR.
Args:
graph (Graph):
computation graph of the node
node (Node):
node to be checked
Returns:
Optional[str]:
None if node is convertible to MLIR, the reason for inconvertibility otherwise
"""
# pylint: disable=too-many-branches,too-many-return-statements
inputs = node.inputs
output = node.output
if node.operation == Operation.Constant:
assert_that(len(inputs) == 0)
if not isinstance(output.dtype, Integer):
return "only integer constants are supported"
elif node.operation == Operation.Input:
assert_that(len(inputs) == 1)
assert_that(inputs[0] == output)
if not isinstance(output.dtype, Integer) or output.dtype.is_signed:
return "only unsigned integer inputs are supported"
else:
assert_that(node.operation == Operation.Generic)
if not isinstance(output.dtype, Integer):
return "only integer operations are supported"
name = node.properties["name"]
if name == "add":
assert_that(len(inputs) == 2)
elif name == "concatenate":
if not all(input.is_encrypted for input in inputs):
return "only all encrypted concatenate is supported"
elif name == "conv2d":
assert_that(len(inputs) == 2 or len(inputs) == 3)
if not (inputs[0].is_encrypted and inputs[1].is_clear):
return "only conv2d with encrypted input and clear weight is supported"
elif name == "dot":
assert_that(len(inputs) == 2)
if inputs[0].is_encrypted and inputs[1].is_encrypted:
return "only dot product between encrypted and clear is supported"
elif name == "index.static":
assert_that(len(inputs) == 1)
if not inputs[0].is_encrypted:
return "only encrypted indexing supported"
elif name == "matmul":
assert_that(len(inputs) == 2)
if inputs[0].is_encrypted and inputs[1].is_encrypted:
return "only matrix multiplication between encrypted and clear is supported"
elif name == "multiply":
assert_that(len(inputs) == 2)
if inputs[0].is_encrypted and inputs[1].is_encrypted:
return "only multiplication between encrypted and clear is supported"
elif name == "negative":
assert_that(len(inputs) == 1)
if not inputs[0].is_encrypted:
return "only encrypted negation is supported"
elif name == "reshape":
assert_that(len(inputs) == 1)
if not inputs[0].is_encrypted:
return "only encrypted reshape is supported"
elif name == "subtract":
assert_that(len(inputs) == 2)
if not (inputs[0].is_clear and inputs[1].is_encrypted):
return "only subtraction of encrypted from clear is supported"
elif name == "sum":
assert_that(len(inputs) == 1)
if not inputs[0].is_encrypted:
return "only encrypted sum is supported"
else:
variable_input_indices = [
idx
for idx, pred in enumerate(graph.ordered_preds_of(node))
if not pred.operation == Operation.Constant
]
if len(variable_input_indices) != 1:
return "only single input table lookups are supported"
if all(input.is_clear for input in inputs):
return "one of the operands must be encrypted"
return None
# pylint: enable=too-many-branches,too-many-return-statements
@staticmethod
def _check_graph_convertibility(graph: Graph):
"""
Check graph convertibility to MLIR.
Args:
graph (Graph):
computation graph to be checked
Raises:
RuntimeError:
if `graph` is not convertible to MLIR
"""
offending_nodes = {}
if len(graph.output_nodes) > 1:
offending_nodes.update(
{
node: ["only a single output is supported"]
for node in graph.output_nodes.values()
}
)
if len(offending_nodes) == 0:
for node in graph.graph.nodes:
if (reason := GraphConverter._check_node_convertibility(graph, node)) is not None:
offending_nodes[node] = [reason]
if len(offending_nodes) != 0:
raise RuntimeError(
"Function you are trying to compile cannot be converted to MLIR\n\n"
+ graph.format(highlighted_nodes=offending_nodes)
)
@staticmethod
def _update_bit_widths(graph: Graph):
"""
Update bit-widths in a computation graph to be convertible to MLIR.
Args:
graph (Graph):
computation graph to be updated
"""
offending_nodes: Dict[Node, List[str]] = {}
max_bit_width = 0
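        # the circuit is converted to a single precision:
        # clear values count one bit less towards the maximum here
        # and get one bit more in the assignment loop at the end,
        # as the compiler expects clear operands to be one bit wider
        # than their encrypted counterparts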
for node in graph.graph.nodes:
dtype = node.output.dtype
assert_that(isinstance(dtype, Integer))
current_node_bit_width = (
dtype.bit_width - 1 if node.output.is_clear else dtype.bit_width
)
max_bit_width = max(max_bit_width, current_node_bit_width)
if current_node_bit_width > MAXIMUM_BIT_WIDTH:
offending_nodes[node] = [
f"only up to {MAXIMUM_BIT_WIDTH}-bit integers are supported"
]
if len(offending_nodes) != 0:
raise RuntimeError(
"Function you are trying to compile cannot be converted to MLIR:\n\n"
+ graph.format(highlighted_nodes=offending_nodes)
)
for node in graph.graph.nodes:
for value in node.inputs + [node.output]:
dtype = value.dtype
assert_that(isinstance(dtype, Integer))
dtype.bit_width = max_bit_width + 1 if value.is_clear else max_bit_width
@staticmethod
def _offset_negative_lookup_table_inputs(graph: Graph):
"""
Offset negative table lookup inputs to be convertible to MLIR.
Args:
graph (Graph):
computation graph to apply offset
"""
# ugly hack to add an offset before entering a TLU
# if its variable input node has a signed output.
# this makes hardcoded assumptions about the way bit widths are handled in MLIR.
# this does not update the TLU input values to allow for proper table generation.
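        # e.g., a signed 3-bit variable input has the range [-4, 3],
        # and adding abs(-4) == 4 before the lookup shifts it to [0, 7],
        # which is a valid unsigned table index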
nx_graph = graph.graph
for node in list(nx_graph.nodes):
if node.operation == Operation.Generic:
if not node.converted_to_table_lookup:
continue
variable_input_index = -1
preds = graph.ordered_preds_of(node)
for index, pred in enumerate(preds):
if pred.operation != Operation.Constant:
variable_input_index = index
break
variable_input_node = preds[variable_input_index]
variable_input_value = variable_input_node.output
variable_input_dtype = variable_input_value.dtype
assert_that(isinstance(variable_input_dtype, Integer))
variable_input_dtype = cast(Integer, variable_input_dtype)
if not variable_input_dtype.is_signed:
continue
variable_input_bit_width = variable_input_dtype.bit_width
offset_constant_dtype = SignedInteger(variable_input_bit_width + 1)
offset_constant = Node.constant(abs(variable_input_dtype.min()))
offset_constant.output.dtype = offset_constant_dtype
add_offset = Node.generic(
"add",
[variable_input_value, ClearScalar(offset_constant_dtype)],
variable_input_value,
np.add,
)
nx_graph.remove_edge(variable_input_node, node)
nx_graph.add_edge(variable_input_node, add_offset, input_idx=0)
nx_graph.add_edge(offset_constant, add_offset, input_idx=1)
nx_graph.add_edge(add_offset, node, input_idx=variable_input_index)
@staticmethod
def convert(graph: Graph) -> str:
"""
Convert a computation graph to its corresponding MLIR representation.
Args:
graph (Graph):
computation graph to be converted
Returns:
str:
textual MLIR representation corresponding to `graph`
"""
graph = deepcopy(graph)
GraphConverter._check_graph_convertibility(graph)
GraphConverter._update_bit_widths(graph)
GraphConverter._offset_negative_lookup_table_inputs(graph)
        # there are no tensor +*- scalar operations in the compiler,
        # but such operations are commonly used, so we need to support them,
        # and we implemented some workarounds to do so (pull request #970)
        # once we have native support, this workaround shall be removed (issue #837)
        # (most changes in #970 shall be reverted)
# { node1: "%arg0", node2: "%0", node3: "%1" }
nodes_to_mlir_names: Dict[Node, str] = {}
# { "%arg0": "i5", "%0": "tensor<2x3x!FHE.eint<4>>" }
mlir_names_to_mlir_types: Dict[str, str] = {}
# { "%0": ["%c1_i5"] } == for %0 we need to convert %c1_i5 to 1d tensor
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]] = {}
with Context() as ctx, Location.unknown():
concretelang.register_dialects(ctx)
module = Module.create()
with InsertionPoint(module.body):
parameters = [
NodeConverter.value_to_mlir_type(ctx, input_node.output)
for input_node in graph.ordered_inputs()
]
@builtin.FuncOp.from_py_func(*parameters)
def main(*arg):
ir_to_mlir = {}
for arg_num, node in graph.input_nodes.items():
ir_to_mlir[node] = arg[arg_num]
mlir_name = f"%arg{arg_num}"
nodes_to_mlir_names[node] = mlir_name
mlir_names_to_mlir_types[mlir_name] = str(parameters[arg_num])
for node in nx.topological_sort(graph.graph):
if node.operation == Operation.Input:
continue
preds = [ir_to_mlir[pred] for pred in graph.ordered_preds_of(node)]
node_converter = NodeConverter(
ctx,
graph,
node,
preds,
nodes_to_mlir_names,
mlir_names_to_mlir_types,
scalar_to_1d_tensor_conversion_hacks,
)
ir_to_mlir[node] = node_converter.convert()
results = (ir_to_mlir[output_node] for output_node in graph.ordered_outputs())
return results
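        # an illustration of the hack below (the op name is representative):
        #     %1 = "FHELinalg.add_eint_int"(%0, %c1_i5) : (tensor<2x!FHE.eint<4>>, i5) -> tensor<2x!FHE.eint<4>>
        # is rewritten to
        #     %hack_1_c1_i5 = tensor.from_elements %c1_i5 : tensor<1xi5>
        #     %1 = "FHELinalg.add_eint_int"(%0, %hack_1_c1_i5) : (tensor<2x!FHE.eint<4>>, tensor<1xi5>) -> tensor<2x!FHE.eint<4>>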
module_lines_after_hacks_are_applied = []
for line in str(module).split("\n"):
mlir_name = line.split("=")[0].strip()
if mlir_name not in scalar_to_1d_tensor_conversion_hacks:
module_lines_after_hacks_are_applied.append(line)
continue
to_be_replaced = scalar_to_1d_tensor_conversion_hacks[mlir_name]
for arg_name in to_be_replaced:
new_name = f"%hack_{mlir_name.replace('%', '')}_{arg_name.replace('%', '')}"
mlir_type = mlir_names_to_mlir_types[arg_name]
hack_line = (
f" {new_name} = tensor.from_elements {arg_name} : tensor<1x{mlir_type}>"
)
module_lines_after_hacks_are_applied.append(hack_line)
line = line.replace(arg_name, new_name)
new_arg_types = []
arg_types = line.split(":")[1].split("->")[0].strip()[1:-1]
for arg in arg_types.split(", "):
if arg.startswith("tensor"):
new_arg_types.append(arg)
else:
new_arg_types.append(f"tensor<1x{arg}>")
line = line.replace(arg_types, ", ".join(new_arg_types))
module_lines_after_hacks_are_applied.append(line)
return "\n".join(module_lines_after_hacks_are_applied).strip()

concrete/numpy/mlir/node_converter.py

@@ -0,0 +1,800 @@
"""
Declaration of `NodeConverter` class.
"""
# pylint: disable=no-member,no-name-in-module
from typing import Dict, List, Tuple
import numpy as np
from concrete.lang.dialects import fhe, fhelinalg
from concrete.lang.dialects.fhe import EncryptedIntegerType
from mlir.dialects import arith, linalg, tensor
from mlir.ir import (
ArrayAttr,
Attribute,
BoolAttr,
Context,
DenseElementsAttr,
IndexType,
IntegerAttr,
IntegerType,
OpResult,
RankedTensorType,
Type,
)
from ..dtypes import Integer
from ..internal.utils import assert_that
from ..representation import Graph, Node, Operation
from ..values import Value
from .utils import construct_deduplicated_tables
# pylint: enable=no-member,no-name-in-module
class NodeConverter:
"""
NodeConverter class, to convert computation graph nodes to their MLIR equivalent.
"""
ctx: Context
graph: Graph
node: Node
preds: List[OpResult]
all_of_the_inputs_are_encrypted: bool
all_of_the_inputs_are_tensors: bool
one_of_the_inputs_is_a_tensor: bool
nodes_to_mlir_names: Dict[Node, str]
mlir_names_to_mlir_types: Dict[str, str]
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]]
@staticmethod
def value_to_mlir_type(ctx: Context, value: Value) -> Type:
"""
Convert a `Value` to its corresponding MLIR `Type`.
Args:
ctx (Context):
MLIR context to perform the conversion
value (Value):
value to convert
Returns:
Type:
MLIR `Type` corresponding to `value`
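        e.g., an encrypted 4-bit scalar maps to `!FHE.eint<4>`,
        a clear 5-bit scalar maps to `i5`,
        and an encrypted 4-bit tensor of shape (2, 3) maps to `tensor<2x3x!FHE.eint<4>>`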
"""
dtype = value.dtype
if isinstance(dtype, Integer):
if value.is_encrypted:
result = EncryptedIntegerType.get(ctx, dtype.bit_width)
else:
result = IntegerType.get_signless(dtype.bit_width)
return result if value.is_scalar else RankedTensorType.get(value.shape, result)
# the branch above is always taken due to compatibility checks
# still, it's a good idea to raise an appropriate error, just in case
raise ValueError(f"{value} cannot be converted to MLIR") # pragma: no cover
def __init__(
self,
ctx: Context,
graph: Graph,
node: Node,
preds: List[OpResult],
        nodes_to_mlir_names: Dict[Node, str],
mlir_names_to_mlir_types: Dict[str, str],
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]],
):
self.ctx = ctx
self.graph = graph
self.node = node
self.preds = preds
self.all_of_the_inputs_are_encrypted = True
self.all_of_the_inputs_are_tensors = True
self.one_of_the_inputs_is_a_tensor = False
for inp in node.inputs:
if not inp.is_encrypted:
self.all_of_the_inputs_are_encrypted = False
if inp.is_scalar:
self.all_of_the_inputs_are_tensors = False
else:
self.one_of_the_inputs_is_a_tensor = True
self.nodes_to_mlir_names = nodes_to_mlir_names
self.mlir_names_to_mlir_types = mlir_names_to_mlir_types
self.scalar_to_1d_tensor_conversion_hacks = scalar_to_1d_tensor_conversion_hacks
def convert(self) -> OpResult:
"""
Convert a node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
# pylint: disable=too-many-branches
if self.node.operation == Operation.Constant:
result = self.convert_constant()
else:
assert_that(self.node.operation == Operation.Generic)
name = self.node.properties["name"]
if name == "add":
result = self.convert_add()
elif name == "concatenate":
result = self.convert_concat()
elif name == "conv2d":
result = self.convert_conv2d()
elif name == "dot":
result = self.convert_dot()
elif name == "index.static":
result = self.convert_static_indexing()
elif name == "matmul":
result = self.convert_matmul()
elif name == "multiply":
result = self.convert_mul()
elif name == "negative":
result = self.convert_neg()
elif name == "reshape":
result = self.convert_reshape()
elif name == "subtract":
result = self.convert_sub()
elif name == "sum":
result = self.convert_sum()
else:
result = self.convert_tlu()
mlir_name = str(result).replace("Value(", "").split("=", maxsplit=1)[0].strip()
self.nodes_to_mlir_names[self.node] = mlir_name
self.mlir_names_to_mlir_types[mlir_name] = str(result.type)
if self.node.operation == Operation.Generic:
name = self.node.properties["name"]
if name in ["add", "dot", "multiply", "subtract"]:
if self.one_of_the_inputs_is_a_tensor and not self.all_of_the_inputs_are_tensors:
to_be_converted = []
for pred in self.graph.ordered_preds_of(self.node):
if pred.output.is_scalar:
to_be_converted.append(self.nodes_to_mlir_names[pred])
self.scalar_to_1d_tensor_conversion_hacks[mlir_name] = to_be_converted
return result
# pylint: enable=too-many-branches
def convert_add(self) -> OpResult:
"""
Convert "add" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
preds = self.preds
if self.all_of_the_inputs_are_encrypted:
if self.one_of_the_inputs_is_a_tensor:
result = fhelinalg.AddEintOp(resulting_type, *preds).result
else:
result = fhe.AddEintOp(resulting_type, *preds).result
else:
if self.node.inputs[0].is_clear:
preds = preds[::-1]
if self.one_of_the_inputs_is_a_tensor:
result = fhelinalg.AddEintIntOp(resulting_type, *preds).result
else:
result = fhe.AddEintIntOp(resulting_type, *preds).result
return result
def convert_concat(self) -> OpResult:
"""
Convert "concatenate" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
axis = self.node.properties["kwargs"].get("axis", 0)
if axis is not None:
if axis < 0:
axis += len(self.node.inputs[0].shape)
return fhelinalg.ConcatOp(
resulting_type,
self.preds,
IntegerAttr.get(IntegerType.get_signless(64), axis),
).result
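        # `axis` is None, and numpy flattens all of the inputs
        # and concatenates them on the first axis in that case,
        # so the same behavior is reproduced below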
flattened_preds = []
for pred, input_value in zip(self.preds, self.node.inputs):
input_shape = input_value.shape
input_size = np.prod(input_shape)
flattened_pred_type = RankedTensorType.get(
[input_size],
NodeConverter.value_to_mlir_type(
self.ctx,
Value(input_value.dtype, shape=(), is_encrypted=input_value.is_encrypted),
),
)
flattened_pred = linalg.TensorCollapseShapeOp(
flattened_pred_type,
pred,
ArrayAttr.get(
[
ArrayAttr.get(
[
IntegerAttr.get(IndexType.parse("index"), i)
for i in range(len(input_shape))
]
)
]
),
).result
flattened_preds.append(flattened_pred)
return fhelinalg.ConcatOp(
resulting_type,
flattened_preds,
IntegerAttr.get(IntegerType.get_signless(64), 0),
).result
def convert_constant(self) -> OpResult:
"""
Convert Operation.Constant node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
data = self.node()
if self.node.output.is_scalar:
attr = IntegerAttr.get(resulting_type, data)
else:
# usage of `Attribute.parse` is the result of some limitations in the MLIR module
# provided by LLVM
# what should have been used is `DenseElementsAttr` but it's impossible to assign
# custom bit-widths using it (e.g., uint5)
# since we couldn't create a `DenseElementsAttr` with a custom bit width using
# the python api we use `Attribute.parse` to let the underlying library do it by itself
attr = Attribute.parse(f"dense<{str(data.tolist())}> : {resulting_type}")
return arith.ConstantOp(resulting_type, attr).result
def convert_conv2d(self) -> OpResult:
"""
Convert "conv2d" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
preds = self.preds
integer_type = IntegerType.get_signless(64, context=self.ctx)
strides = DenseElementsAttr.get(
np.array(list(self.node.properties["kwargs"]["strides"]), dtype=np.uint64),
type=integer_type,
context=self.ctx,
)
dilations = DenseElementsAttr.get(
np.array(list(self.node.properties["kwargs"]["dilations"]), dtype=np.uint64),
type=integer_type,
context=self.ctx,
)
pads = DenseElementsAttr.get(
np.array(list(self.node.properties["kwargs"]["pads"]), dtype=np.uint64),
type=integer_type,
context=self.ctx,
)
has_bias = len(self.node.inputs) == 3
if not has_bias:
preds.append(None)
return fhelinalg.Conv2dOp(resulting_type, *preds, pads, strides, dilations).result
def convert_dot(self) -> OpResult:
"""
Convert "dot" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
preds = self.preds
if self.node.inputs[0].is_clear:
preds = preds[::-1]
if self.all_of_the_inputs_are_tensors:
# numpy.dot(x, y) where x and y are both vectors = regular dot product
result = fhelinalg.Dot(resulting_type, *preds).result
elif not self.one_of_the_inputs_is_a_tensor:
# numpy.dot(x, y) where x and y are both scalars = x * y
result = fhe.MulEintIntOp(resulting_type, *preds).result
else:
# numpy.dot(x, y) where one of x or y is a scalar and the other one is a vector = x * y
result = fhelinalg.MulEintIntOp(resulting_type, *preds).result
return result
def convert_matmul(self) -> OpResult:
"""Convert a MatMul node to its corresponding MLIR representation.
Returns:
str: textual MLIR representation corresponding to self.node
"""
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
preds = self.preds
if self.node.output.shape == ():
if self.node.inputs[0].is_clear:
preds = preds[::-1]
result = fhelinalg.Dot(resulting_type, *preds).result
elif self.node.inputs[0].is_clear:
result = fhelinalg.MatMulIntEintOp(resulting_type, *preds).result
else:
result = fhelinalg.MatMulEintIntOp(resulting_type, *preds).result
return result
def convert_mul(self) -> OpResult:
"""
Convert "multiply" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
preds = self.preds
if self.node.inputs[0].is_clear:
preds = preds[::-1]
if self.one_of_the_inputs_is_a_tensor:
result = fhelinalg.MulEintIntOp(resulting_type, *preds).result
else:
result = fhe.MulEintIntOp(resulting_type, *preds).result
return result
def convert_neg(self) -> OpResult:
"""
Convert "negative" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
pred = self.preds[0]
if self.one_of_the_inputs_is_a_tensor:
result = fhelinalg.NegEintOp(resulting_type, pred).result
else:
result = fhe.NegEintOp(resulting_type, pred).result
return result
def convert_reshape(self) -> OpResult:
"""
Convert "reshape" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
input_shape = self.node.inputs[0].shape
output_shape = self.node.output.shape
pred = self.preds[0]
if input_shape == output_shape:
return pred
# we can either collapse or expand, which changes the number of dimensions
# this is a limitation of the current compiler, it will be improved in the future (#1060)
can_be_converted_directly = len(input_shape) != len(output_shape)
reassociation: List[List[int]] = []
if can_be_converted_directly:
if len(output_shape) == 1:
# output is 1 dimensional so collapse every dimension into the same dimension
reassociation.append(list(range(len(input_shape))))
else:
# input is m dimensional
# output is n dimensional
# and m is different from n
# we don't want to duplicate code, so we forget about input and output,
# and we focus on smaller shape and bigger shape
smaller_shape, bigger_shape = (
(output_shape, input_shape)
if len(output_shape) < len(input_shape)
else (input_shape, output_shape)
)
s_index, b_index = 0, 0
# now we will figure out how to group the bigger shape to get the smaller shape
# think of the algorithm below as
# keep merging the dimensions of the bigger shape
# until we have a match on the smaller shape
# then try to match the next dimension of the smaller shape
                # if all dimensions of the smaller shape are matched
# we can convert it
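                # e.g., reshaping from (2, 3, 4) to (6, 4) results in
                # the reassociation [[0, 1], [2]] since 2 * 3 == 6 and 4 == 4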
group = []
size = 1
while s_index < len(smaller_shape) and b_index < len(bigger_shape):
# dimension `b_index` of `bigger_shape` belongs to current group
group.append(b_index)
# and current group has `size * bigger_shape[b_index]` elements now
size *= bigger_shape[b_index]
# if current group size matches the dimension `s_index` of `smaller_shape`
if size == smaller_shape[s_index]:
# we finalize this group and reset everything
size = 1
reassociation.append(group)
group = []
# now try to match the next dimension of `smaller_shape`
s_index += 1
# now process the next dimension of `bigger_shape`
b_index += 1
                # handle the case where the bigger shape has trailing 1s
# e.g., (5,) -> (5, 1)
while b_index < len(bigger_shape) and bigger_shape[b_index] == 1:
reassociation[-1].append(b_index)
b_index += 1
# if not all dimensions of both shapes are processed exactly
if s_index != len(smaller_shape) or b_index != len(bigger_shape):
# we cannot convert
can_be_converted_directly = False
index_type = IndexType.parse("index")
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
if can_be_converted_directly:
reassociation_attr = ArrayAttr.get(
[
ArrayAttr.get([IntegerAttr.get(index_type, dimension) for dimension in group])
for group in reassociation
]
)
if len(output_shape) < len(input_shape):
return linalg.TensorCollapseShapeOp(resulting_type, pred, reassociation_attr).result
return linalg.TensorExpandShapeOp(resulting_type, pred, reassociation_attr).result
flattened_type = NodeConverter.value_to_mlir_type(
self.ctx,
Value(
dtype=self.node.inputs[0].dtype,
shape=(np.prod(input_shape),),
is_encrypted=self.node.inputs[0].is_encrypted,
),
)
flattened_result = linalg.TensorCollapseShapeOp(
flattened_type,
pred,
ArrayAttr.get(
[ArrayAttr.get([IntegerAttr.get(index_type, i) for i in range(len(input_shape))])]
),
).result
return linalg.TensorExpandShapeOp(
resulting_type,
flattened_result,
ArrayAttr.get(
[ArrayAttr.get([IntegerAttr.get(index_type, i) for i in range(len(output_shape))])]
),
).result
def convert_static_indexing(self) -> OpResult:
"""
Convert "index.static" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
input_value = self.node.inputs[0]
input_shape = input_value.shape
index = list(self.node.properties["attributes"]["index"])
index_type = IndexType.parse("index")
while len(index) < input_value.ndim:
index.append(slice(None, None, None))
output_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
if len(index) == len(input_shape) and all(isinstance(i, (int, np.integer)) for i in index):
indices = []
for value, dimension_size in zip(index, input_shape):
value = int(value)
attr = IntegerAttr.get(index_type, value if value >= 0 else value + dimension_size)
indices.append(arith.ConstantOp(index_type, attr).result)
return tensor.ExtractOp(output_type, self.preds[0], indices).result
offsets = []
sizes = []
strides = []
destroyed_dimensions = []
for dimension, (indexing_element, dimension_size) in enumerate(zip(index, input_shape)):
if isinstance(indexing_element, slice):
size = np.zeros(dimension_size)[indexing_element].shape[0]
stride = indexing_element.step if isinstance(indexing_element.step, int) else 1
offset = (
(
indexing_element.start
if indexing_element.start >= 0
else indexing_element.start + dimension_size
)
if isinstance(indexing_element.start, int)
else (0 if stride > 0 else dimension_size - 1)
)
else:
destroyed_dimensions.append(dimension)
size = 1
stride = 1
offset = int(
indexing_element if indexing_element >= 0 else indexing_element + dimension_size
)
offsets.append(offset)
sizes.append(size)
strides.append(stride)
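        # e.g., indexing a tensor of shape (4, 5) with [2, 1:4] results in
        # offsets == [2, 1], sizes == [1, 3], strides == [1, 1],
        # and destroyed_dimensions == [0] as dimension 0 is dropped from the output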
if len(destroyed_dimensions) == 0:
return tensor.ExtractSliceOp(
output_type,
self.preds[0],
[],
[],
[],
ArrayAttr.get([IntegerAttr.get(index_type, value) for value in offsets]),
ArrayAttr.get([IntegerAttr.get(index_type, value) for value in sizes]),
ArrayAttr.get([IntegerAttr.get(index_type, value) for value in strides]),
).result
output_value = self.node.output
intermediate_shape = list(output_value.shape)
for dimension in destroyed_dimensions:
intermediate_shape.insert(dimension, 1)
intermediate = tensor.ExtractSliceOp(
RankedTensorType.get(
intermediate_shape,
NodeConverter.value_to_mlir_type(
self.ctx,
Value(output_value.dtype, shape=(), is_encrypted=output_value.is_encrypted),
),
),
self.preds[0],
[],
[],
[],
ArrayAttr.get([IntegerAttr.get(index_type, value) for value in offsets]),
ArrayAttr.get([IntegerAttr.get(index_type, value) for value in sizes]),
ArrayAttr.get([IntegerAttr.get(index_type, value) for value in strides]),
).result
        reassociation = []
current_intermediate_dimension = 0
for _ in range(len(output_value.shape)):
indices = [current_intermediate_dimension]
while current_intermediate_dimension in destroyed_dimensions:
current_intermediate_dimension += 1
indices.append(current_intermediate_dimension)
            reassociation.append(indices)
current_intermediate_dimension += 1
while current_intermediate_dimension < len(intermediate_shape):
            reassociation[-1].append(current_intermediate_dimension)
current_intermediate_dimension += 1
return linalg.TensorCollapseShapeOp(
output_type,
intermediate,
ArrayAttr.get(
[
ArrayAttr.get(
[IntegerAttr.get(index_type, index) for index in indices],
)
                    for indices in reassociation
],
),
).result
def convert_sub(self) -> OpResult:
"""
Convert "subtract" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
preds = self.preds
if self.one_of_the_inputs_is_a_tensor:
result = fhelinalg.SubIntEintOp(resulting_type, *preds).result
else:
result = fhe.SubIntEintOp(resulting_type, *preds).result
return result
def convert_sum(self) -> OpResult:
"""
Convert "sum" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
axes = self.node.properties["kwargs"].get("axis", [])
keep_dims = self.node.properties["kwargs"].get("keepdims", False)
if isinstance(axes, int):
axes = [axes]
elif isinstance(axes, tuple):
axes = list(axes)
input_dimensions = self.node.inputs[0].ndim
for i, axis in enumerate(axes):
if axis < 0:
axes[i] += input_dimensions
return fhelinalg.SumOp(
resulting_type,
self.preds[0],
ArrayAttr.get([IntegerAttr.get(IntegerType.get_signless(64), axis) for axis in axes]),
BoolAttr.get(keep_dims),
).result
def convert_tlu(self) -> OpResult:
"""
Convert Operation.Generic node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
variable_input_index = -1
preds = self.graph.ordered_preds_of(self.node)
for i, pred in enumerate(preds):
if pred.operation != Operation.Constant:
variable_input_index = i
break
assert_that(variable_input_index != -1)
tables = construct_deduplicated_tables(self.node, preds)
assert_that(len(tables) > 0)
lut_shape: Tuple[int, ...] = ()
map_shape: Tuple[int, ...] = ()
if len(tables) == 1:
table = tables[0][0]
            # the table is reduced modulo `2 << 63` to avoid problems with TLUs
            # of the form T[j] = 2 << j, where j is supposed to be 7-bit as per
            # the constraint of the compiler while being a small value in practice
            # (reducing on 64 bits was not ok for some reason)
lut_shape = (len(table),)
lut_values = np.array(table % (2 << 63), dtype=np.uint64)
map_shape = ()
map_values = None
else:
individual_table_size = len(tables[0][0])
lut_shape = (len(tables), individual_table_size)
map_shape = self.node.output.shape
lut_values = np.zeros(lut_shape, dtype=np.uint64)
map_values = np.zeros(map_shape, dtype=np.intp)
for i, (table, indices) in enumerate(tables):
assert_that(len(table) == individual_table_size)
lut_values[i, :] = table
for index in indices:
map_values[index] = i
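            # e.g., with the tables in the docstring of `construct_deduplicated_tables`,
            # lut_values == [[3, 1, 2, 4], [5, 8, 6, 7]]
            # and map_values == [[1, 1], [0, 1], [1, 0]]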
lut_type = RankedTensorType.get(lut_shape, IntegerType.get_signless(64, context=self.ctx))
lut_attr = DenseElementsAttr.get(lut_values, context=self.ctx)
lut = arith.ConstantOp(lut_type, lut_attr).result
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
pred = self.preds[variable_input_index]
if self.one_of_the_inputs_is_a_tensor:
if len(tables) == 1:
result = fhelinalg.ApplyLookupTableEintOp(resulting_type, pred, lut).result
else:
assert_that(map_shape != ())
assert_that(map_values is not None)
index_type = IndexType.parse("index")
map_type = RankedTensorType.get(map_shape, index_type)
map_attr = DenseElementsAttr.get(map_values, context=self.ctx, type=index_type)
result = fhelinalg.ApplyMappedLookupTableEintOp(
resulting_type,
pred,
lut,
arith.ConstantOp(map_type, map_attr).result,
).result
else:
result = fhe.ApplyLookupTableEintOp(resulting_type, pred, lut).result
return result

concrete/numpy/mlir/utils.py

@@ -0,0 +1,172 @@
"""
Declaration of various functions and constants related to MLIR conversion.
"""
import math
from collections import defaultdict, deque
from copy import deepcopy
from itertools import product
from typing import Any, DefaultDict, List, Optional, Tuple, Union, cast
import numpy as np
from ..dtypes import Integer
from ..internal.utils import assert_that
from ..representation import Node, Operation
MAXIMUM_BIT_WIDTH = 8
class HashableNdarray:
"""
HashableNdarray class, to use numpy arrays in dictionaries.
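    e.g., `{HashableNdarray(np.array([1, 2])): "table"}` can be queried
    with an equal array wrapped in another `HashableNdarray`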
"""
array: np.ndarray
def __init__(self, array: np.ndarray):
self.array = array
def __eq__(self, other: object) -> bool:
return isinstance(other, HashableNdarray) and np.array_equal(self.array, other.array)
def __hash__(self) -> int:
return hash(self.array.tobytes())
def flood_replace_none_values(table: list):
"""
Use flooding algorithm to replace `None` values.
Args:
table (list):
the list in which there are `None` values that need to be replaced
with copies of the closest non `None` data from the list
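            e.g., `[1, None, None, 4]` is flooded into `[1, 1, 4, 4]`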
"""
assert_that(any(value is not None for value in table))
not_none_values_idx = deque(idx for idx, value in enumerate(table) if value is not None)
while not_none_values_idx:
current_idx = not_none_values_idx.popleft()
current_value = table[current_idx]
previous_idx = current_idx - 1
next_idx = current_idx + 1
if previous_idx >= 0 and table[previous_idx] is None:
table[previous_idx] = deepcopy(current_value)
not_none_values_idx.append(previous_idx)
if next_idx < len(table) and table[next_idx] is None:
table[next_idx] = deepcopy(current_value)
not_none_values_idx.append(next_idx)
assert_that(all(value is not None for value in table))
def construct_table(node: Node, preds: List[Node]) -> List[Any]:
"""
Construct the lookup table for an Operation.Generic node.
Args:
node (Node):
Operation.Generic to construct the table
preds (List[Node]):
ordered predecessors to `node`
Returns:
List[Any]:
lookup table corresponding to `node` and its input value
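            e.g., for a node computing `x ** 2` on a 2-bit unsigned scalar input,
            the resulting table is `[0, 1, 4, 9]`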
"""
variable_input_index = -1
for index, pred in enumerate(preds):
if pred.operation != Operation.Constant:
variable_input_index = index
break
assert_that(variable_input_index != -1)
variable_input_dtype = node.inputs[variable_input_index].dtype
variable_input_shape = node.inputs[variable_input_index].shape
assert_that(isinstance(variable_input_dtype, Integer))
variable_input_dtype = cast(Integer, variable_input_dtype)
inputs: List[Any] = [pred() if pred.operation == Operation.Constant else None for pred in preds]
table: List[Optional[Union[np.bool_, np.integer, np.floating, np.ndarray]]] = []
for value in range(variable_input_dtype.min(), variable_input_dtype.max() + 1):
try:
inputs[variable_input_index] = np.ones(variable_input_shape, dtype=np.int64) * value
table.append(node(*inputs))
except Exception: # pylint: disable=broad-except
# here we try our best to fill the table
            # if it fails, we append None and let the flooding algorithm replace None values below
table.append(None)
flood_replace_none_values(table)
return table
def construct_deduplicated_tables(
node: Node,
preds: List[Node],
) -> Tuple[Tuple[np.ndarray, List[Tuple[int, ...]]], ...]:
"""
Construct lookup tables for each cell of the input for an Operation.Generic node.
Args:
node (Node):
Operation.Generic to construct the table
preds (List[Node]):
ordered predecessors to `node`
Returns:
Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]:
tuple containing tuples of 2 for
- constructed table
- list of indices of the input that use the constructed table
e.g.,
.. code-block:: python
(
(np.array([3, 1, 2, 4]), [(1, 0), (2, 1)]),
(np.array([5, 8, 6, 7]), [(0, 0), (0, 1), (1, 1), (2, 0)]),
)
means the lookup on 3x2 input will result in
.. code-block:: python
[ [5, 8, 6, 7][input[0, 0]] , [5, 8, 6, 7][input[0, 1]] ]
[ [3, 1, 2, 4][input[1, 0]] , [5, 8, 6, 7][input[1, 1]] ]
[ [5, 8, 6, 7][input[2, 0]] , [3, 1, 2, 4][input[2, 1]] ]
"""
node_complete_table = np.concatenate(
tuple(np.expand_dims(array, -1) for array in construct_table(node, preds)),
axis=-1,
)
all_cells_idx = product(*tuple(range(max_val) for max_val in node_complete_table.shape[:-1]))
tables_to_cell_idx: DefaultDict[HashableNdarray, List[Tuple[int, ...]]] = defaultdict(list)
idx: Tuple[int, ...]
all_idx_set = set()
for idx in all_cells_idx:
hashable_array = HashableNdarray(node_complete_table[idx])
tables_to_cell_idx[hashable_array].append(idx)
all_idx_set.add(idx)
assert_that(len(all_idx_set) == math.prod(node_complete_table.shape[:-1]))
return tuple(
(hashable_array.array, indices) for hashable_array, indices in tables_to_cell_idx.items()
)