Mirror of https://github.com/zama-ai/concrete.git
feat: implement mlir module
concrete/numpy/mlir/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
"""
Declaration of `concrete.numpy.mlir` namespace.
"""

from .graph_converter import GraphConverter
from .node_converter import NodeConverter
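
A minimal usage sketch (assuming `graph` is a traced computation graph produced elsewhere in the package; the tracer is not part of this diff):

    from concrete.numpy.mlir import GraphConverter

    mlir_text = GraphConverter.convert(graph)  # textual MLIR module as a string
    print(mlir_text)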
concrete/numpy/mlir/graph_converter.py (new file, 380 lines)
@@ -0,0 +1,380 @@
"""
Declaration of `GraphConverter` class.
"""

# pylint: disable=no-member,no-name-in-module

from copy import deepcopy
from typing import Dict, List, Optional, cast

import concrete.lang as concretelang
import networkx as nx
import numpy as np
from mlir.dialects import builtin
from mlir.ir import Context, InsertionPoint, Location, Module

from ..dtypes import Integer, SignedInteger
from ..internal.utils import assert_that
from ..representation import Graph, Node, Operation
from ..values import ClearScalar
from .node_converter import NodeConverter
from .utils import MAXIMUM_BIT_WIDTH

# pylint: enable=no-member,no-name-in-module

class GraphConverter:
    """
    GraphConverter class, to convert computation graphs to their MLIR equivalent.
    """

    @staticmethod
    def _check_node_convertibility(graph: Graph, node: Node) -> Optional[str]:
        """
        Check node convertibility to MLIR.

        Args:
            graph (Graph):
                computation graph of the node

            node (Node):
                node to be checked

        Returns:
            Optional[str]:
                None if node is convertible to MLIR, the reason for inconvertibility otherwise
        """

        # pylint: disable=too-many-branches,too-many-return-statements

        inputs = node.inputs
        output = node.output

        if node.operation == Operation.Constant:
            assert_that(len(inputs) == 0)
            if not isinstance(output.dtype, Integer):
                return "only integer constants are supported"

        elif node.operation == Operation.Input:
            assert_that(len(inputs) == 1)
            assert_that(inputs[0] == output)
            if not isinstance(output.dtype, Integer) or output.dtype.is_signed:
                return "only unsigned integer inputs are supported"

        else:
            assert_that(node.operation == Operation.Generic)

            if not isinstance(output.dtype, Integer):
                return "only integer operations are supported"

            name = node.properties["name"]

            if name == "add":
                assert_that(len(inputs) == 2)

            elif name == "concatenate":
                if not all(input.is_encrypted for input in inputs):
                    return "only all encrypted concatenate is supported"

            elif name == "conv2d":
                assert_that(len(inputs) == 2 or len(inputs) == 3)
                if not (inputs[0].is_encrypted and inputs[1].is_clear):
                    return "only conv2d with encrypted input and clear weight is supported"

            elif name == "dot":
                assert_that(len(inputs) == 2)
                if inputs[0].is_encrypted and inputs[1].is_encrypted:
                    return "only dot product between encrypted and clear is supported"

            elif name == "index.static":
                assert_that(len(inputs) == 1)
                if not inputs[0].is_encrypted:
                    return "only encrypted indexing is supported"

            elif name == "matmul":
                assert_that(len(inputs) == 2)
                if inputs[0].is_encrypted and inputs[1].is_encrypted:
                    return "only matrix multiplication between encrypted and clear is supported"

            elif name == "multiply":
                assert_that(len(inputs) == 2)
                if inputs[0].is_encrypted and inputs[1].is_encrypted:
                    return "only multiplication between encrypted and clear is supported"

            elif name == "negative":
                assert_that(len(inputs) == 1)
                if not inputs[0].is_encrypted:
                    return "only encrypted negation is supported"

            elif name == "reshape":
                assert_that(len(inputs) == 1)
                if not inputs[0].is_encrypted:
                    return "only encrypted reshape is supported"

            elif name == "subtract":
                assert_that(len(inputs) == 2)
                if not (inputs[0].is_clear and inputs[1].is_encrypted):
                    return "only subtraction of encrypted from clear is supported"

            elif name == "sum":
                assert_that(len(inputs) == 1)
                if not inputs[0].is_encrypted:
                    return "only encrypted sum is supported"

            else:
                variable_input_indices = [
                    idx
                    for idx, pred in enumerate(graph.ordered_preds_of(node))
                    if not pred.operation == Operation.Constant
                ]
                if len(variable_input_indices) != 1:
                    return "only single input table lookups are supported"

            if all(input.is_clear for input in inputs):
                return "one of the operands must be encrypted"

        return None

    # pylint: enable=too-many-branches,too-many-return-statements

    @staticmethod
    def _check_graph_convertibility(graph: Graph):
        """
        Check graph convertibility to MLIR.

        Args:
            graph (Graph):
                computation graph to be checked

        Raises:
            RuntimeError:
                if `graph` is not convertible to MLIR
        """

        offending_nodes = {}

        if len(graph.output_nodes) > 1:
            offending_nodes.update(
                {
                    node: ["only a single output is supported"]
                    for node in graph.output_nodes.values()
                }
            )

        if len(offending_nodes) == 0:
            for node in graph.graph.nodes:
                if (reason := GraphConverter._check_node_convertibility(graph, node)) is not None:
                    offending_nodes[node] = [reason]

        if len(offending_nodes) != 0:
            raise RuntimeError(
                "Function you are trying to compile cannot be converted to MLIR:\n\n"
                + graph.format(highlighted_nodes=offending_nodes)
            )

    @staticmethod
    def _update_bit_widths(graph: Graph):
        """
        Update bit-widths in a computation graph to be convertible to MLIR.

        Args:
            graph (Graph):
                computation graph to be updated
        """

        offending_nodes: Dict[Node, List[str]] = {}

        max_bit_width = 0
        for node in graph.graph.nodes:
            dtype = node.output.dtype
            assert_that(isinstance(dtype, Integer))

            current_node_bit_width = (
                dtype.bit_width - 1 if node.output.is_clear else dtype.bit_width
            )
            max_bit_width = max(max_bit_width, current_node_bit_width)

            if current_node_bit_width > MAXIMUM_BIT_WIDTH:
                offending_nodes[node] = [
                    f"only up to {MAXIMUM_BIT_WIDTH}-bit integers are supported"
                ]

        if len(offending_nodes) != 0:
            raise RuntimeError(
                "Function you are trying to compile cannot be converted to MLIR:\n\n"
                + graph.format(highlighted_nodes=offending_nodes)
            )

        for node in graph.graph.nodes:
            for value in node.inputs + [node.output]:
                dtype = value.dtype
                assert_that(isinstance(dtype, Integer))
                dtype.bit_width = max_bit_width + 1 if value.is_clear else max_bit_width
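
    # worked example of the unification above (illustrative): with an
    # encrypted 3-bit value and a clear 5-bit value, the effective widths
    # are 3 and 5 - 1 = 4, so `max_bit_width` becomes 4; every encrypted
    # value is then set to 4 bits and every clear value to 4 + 1 = 5 bits
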
    @staticmethod
    def _offset_negative_lookup_table_inputs(graph: Graph):
        """
        Offset negative table lookup inputs to be convertible to MLIR.

        Args:
            graph (Graph):
                computation graph to apply offset
        """

        # ugly hack to add an offset before entering a TLU
        # if its variable input node has a signed output.
        # this makes hardcoded assumptions about the way bit widths are handled in MLIR.
        # this does not update the TLU input values, to allow for proper table generation.
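
        # worked example (illustrative): a signed 3-bit input ranges over
        # [-4, 3]; abs(min) = 4, so an "add 4" node is spliced in front of
        # the table lookup, and the values feeding it land in [0, 7], which
        # are valid unsigned table indices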

        nx_graph = graph.graph
        for node in list(nx_graph.nodes):
            if node.operation == Operation.Generic:
                if not node.converted_to_table_lookup:
                    continue

                variable_input_index = -1

                preds = graph.ordered_preds_of(node)
                for index, pred in enumerate(preds):
                    if pred.operation != Operation.Constant:
                        variable_input_index = index
                        break

                variable_input_node = preds[variable_input_index]

                variable_input_value = variable_input_node.output
                variable_input_dtype = variable_input_value.dtype

                assert_that(isinstance(variable_input_dtype, Integer))
                variable_input_dtype = cast(Integer, variable_input_dtype)

                if not variable_input_dtype.is_signed:
                    continue

                variable_input_bit_width = variable_input_dtype.bit_width
                offset_constant_dtype = SignedInteger(variable_input_bit_width + 1)

                offset_constant = Node.constant(abs(variable_input_dtype.min()))
                offset_constant.output.dtype = offset_constant_dtype

                add_offset = Node.generic(
                    "add",
                    [variable_input_value, ClearScalar(offset_constant_dtype)],
                    variable_input_value,
                    np.add,
                )

                nx_graph.remove_edge(variable_input_node, node)

                nx_graph.add_edge(variable_input_node, add_offset, input_idx=0)
                nx_graph.add_edge(offset_constant, add_offset, input_idx=1)

                nx_graph.add_edge(add_offset, node, input_idx=variable_input_index)

    @staticmethod
    def convert(graph: Graph) -> str:
        """
        Convert a computation graph to its corresponding MLIR representation.

        Args:
            graph (Graph):
                computation graph to be converted

        Returns:
            str:
                textual MLIR representation corresponding to `graph`
        """

        graph = deepcopy(graph)

        GraphConverter._check_graph_convertibility(graph)
        GraphConverter._update_bit_widths(graph)
        GraphConverter._offset_negative_lookup_table_inputs(graph)

        # There are no tensor (+, *, -) scalar operations in the compiler,
        # but such operations are commonly used, so we need to support them.
        # To do so, we implemented some workarounds (pull request #970).
        # Once we have native support, this workaround shall be removed (issue #837)
        # (and most changes in #970 shall be reverted).

        # { node1: "%arg0", node2: "%0", node3: "%1" }
        nodes_to_mlir_names: Dict[Node, str] = {}

        # { "%arg0": "i5", "%0": "tensor<2x3x!FHE.eint<4>>" }
        mlir_names_to_mlir_types: Dict[str, str] = {}

        # { "%0": ["%c1_i5"] } == for %0, we need to convert %c1_i5 to a 1d tensor
        scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]] = {}

        with Context() as ctx, Location.unknown():
            concretelang.register_dialects(ctx)

            module = Module.create()
            with InsertionPoint(module.body):
                parameters = [
                    NodeConverter.value_to_mlir_type(ctx, input_node.output)
                    for input_node in graph.ordered_inputs()
                ]

                @builtin.FuncOp.from_py_func(*parameters)
                def main(*arg):
                    ir_to_mlir = {}
                    for arg_num, node in graph.input_nodes.items():
                        ir_to_mlir[node] = arg[arg_num]

                        mlir_name = f"%arg{arg_num}"
                        nodes_to_mlir_names[node] = mlir_name
                        mlir_names_to_mlir_types[mlir_name] = str(parameters[arg_num])

                    for node in nx.topological_sort(graph.graph):
                        if node.operation == Operation.Input:
                            continue

                        preds = [ir_to_mlir[pred] for pred in graph.ordered_preds_of(node)]
                        node_converter = NodeConverter(
                            ctx,
                            graph,
                            node,
                            preds,
                            nodes_to_mlir_names,
                            mlir_names_to_mlir_types,
                            scalar_to_1d_tensor_conversion_hacks,
                        )
                        ir_to_mlir[node] = node_converter.convert()

                    results = (ir_to_mlir[output_node] for output_node in graph.ordered_outputs())
                    return results

        module_lines_after_hacks_are_applied = []
        for line in str(module).split("\n"):
            mlir_name = line.split("=")[0].strip()
            if mlir_name not in scalar_to_1d_tensor_conversion_hacks:
                module_lines_after_hacks_are_applied.append(line)
                continue

            to_be_replaced = scalar_to_1d_tensor_conversion_hacks[mlir_name]
            for arg_name in to_be_replaced:
                new_name = f"%hack_{mlir_name.replace('%', '')}_{arg_name.replace('%', '')}"
                mlir_type = mlir_names_to_mlir_types[arg_name]

                hack_line = (
                    f"    {new_name} = tensor.from_elements {arg_name} : tensor<1x{mlir_type}>"
                )
                module_lines_after_hacks_are_applied.append(hack_line)

                line = line.replace(arg_name, new_name)

            new_arg_types = []

            arg_types = line.split(":")[1].split("->")[0].strip()[1:-1]
            for arg in arg_types.split(", "):
                if arg.startswith("tensor"):
                    new_arg_types.append(arg)
                else:
                    new_arg_types.append(f"tensor<1x{arg}>")

            line = line.replace(arg_types, ", ".join(new_arg_types))

            module_lines_after_hacks_are_applied.append(line)

        return "\n".join(module_lines_after_hacks_are_applied).strip()
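
    # For illustration (hypothetical MLIR; SSA names and types are made up,
    # the rewrite pattern is the one implemented above), a line such as
    #
    #   %1 = "FHELinalg.add_eint_int"(%0, %c1_i5)
    #          : (tensor<2x!FHE.eint<4>>, i5) -> tensor<2x!FHE.eint<4>>
    #
    # is turned into
    #
    #   %hack_1_c1_i5 = tensor.from_elements %c1_i5 : tensor<1xi5>
    #   %1 = "FHELinalg.add_eint_int"(%0, %hack_1_c1_i5)
    #          : (tensor<2x!FHE.eint<4>>, tensor<1xi5>) -> tensor<2x!FHE.eint<4>>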
concrete/numpy/mlir/node_converter.py (new file, 800 lines)
@@ -0,0 +1,800 @@
"""
Declaration of `NodeConverter` class.
"""

# pylint: disable=no-member,no-name-in-module

from typing import Dict, List, Tuple

import numpy as np
from concrete.lang.dialects import fhe, fhelinalg
from concrete.lang.dialects.fhe import EncryptedIntegerType
from mlir.dialects import arith, linalg, tensor
from mlir.ir import (
    ArrayAttr,
    Attribute,
    BoolAttr,
    Context,
    DenseElementsAttr,
    IndexType,
    IntegerAttr,
    IntegerType,
    OpResult,
    RankedTensorType,
    Type,
)

from ..dtypes import Integer
from ..internal.utils import assert_that
from ..representation import Graph, Node, Operation
from ..values import Value
from .utils import construct_deduplicated_tables

# pylint: enable=no-member,no-name-in-module

class NodeConverter:
    """
    NodeConverter class, to convert computation graph nodes to their MLIR equivalent.
    """

    ctx: Context
    graph: Graph
    node: Node
    preds: List[OpResult]

    all_of_the_inputs_are_encrypted: bool
    all_of_the_inputs_are_tensors: bool
    one_of_the_inputs_is_a_tensor: bool

    nodes_to_mlir_names: Dict[Node, str]
    mlir_names_to_mlir_types: Dict[str, str]
    scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]]

    @staticmethod
    def value_to_mlir_type(ctx: Context, value: Value) -> Type:
        """
        Convert a `Value` to its corresponding MLIR `Type`.

        Args:
            ctx (Context):
                MLIR context to perform the conversion

            value (Value):
                value to convert

        Returns:
            Type:
                MLIR `Type` corresponding to `value`
        """

        dtype = value.dtype

        if isinstance(dtype, Integer):
            if value.is_encrypted:
                result = EncryptedIntegerType.get(ctx, dtype.bit_width)
            else:
                result = IntegerType.get_signless(dtype.bit_width)

            return result if value.is_scalar else RankedTensorType.get(value.shape, result)

        # the branch above is always taken due to compatibility checks
        # still, it's a good idea to raise an appropriate error, just in case
        raise ValueError(f"{value} cannot be converted to MLIR")  # pragma: no cover
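
    # value_to_mlir_type in action (illustrative):
    #   encrypted scalar of 4 bits        -> !FHE.eint<4>
    #   clear scalar of 5 bits            -> i5
    #   encrypted (2, 3) tensor of 4 bits -> tensor<2x3x!FHE.eint<4>>
    #   clear (2, 3) tensor of 5 bits     -> tensor<2x3xi5>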

    def __init__(
        self,
        ctx: Context,
        graph: Graph,
        node: Node,
        preds: List[OpResult],
        nodes_to_mlir_names: Dict[Node, str],
        mlir_names_to_mlir_types: Dict[str, str],
        scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]],
    ):
        self.ctx = ctx
        self.graph = graph
        self.node = node
        self.preds = preds

        self.all_of_the_inputs_are_encrypted = True
        self.all_of_the_inputs_are_tensors = True
        self.one_of_the_inputs_is_a_tensor = False

        for inp in node.inputs:
            if not inp.is_encrypted:
                self.all_of_the_inputs_are_encrypted = False

            if inp.is_scalar:
                self.all_of_the_inputs_are_tensors = False
            else:
                self.one_of_the_inputs_is_a_tensor = True

        self.nodes_to_mlir_names = nodes_to_mlir_names
        self.mlir_names_to_mlir_types = mlir_names_to_mlir_types
        self.scalar_to_1d_tensor_conversion_hacks = scalar_to_1d_tensor_conversion_hacks

    def convert(self) -> OpResult:
        """
        Convert a node to its corresponding MLIR representation.

        Returns:
            OpResult:
                in-memory MLIR representation corresponding to `self.node`
        """

        # pylint: disable=too-many-branches

        if self.node.operation == Operation.Constant:
            result = self.convert_constant()
        else:
            assert_that(self.node.operation == Operation.Generic)

            name = self.node.properties["name"]

            if name == "add":
                result = self.convert_add()

            elif name == "concatenate":
                result = self.convert_concat()

            elif name == "conv2d":
                result = self.convert_conv2d()

            elif name == "dot":
                result = self.convert_dot()

            elif name == "index.static":
                result = self.convert_static_indexing()

            elif name == "matmul":
                result = self.convert_matmul()

            elif name == "multiply":
                result = self.convert_mul()

            elif name == "negative":
                result = self.convert_neg()

            elif name == "reshape":
                result = self.convert_reshape()

            elif name == "subtract":
                result = self.convert_sub()

            elif name == "sum":
                result = self.convert_sum()

            else:
                result = self.convert_tlu()

        mlir_name = str(result).replace("Value(", "").split("=", maxsplit=1)[0].strip()

        self.nodes_to_mlir_names[self.node] = mlir_name
        self.mlir_names_to_mlir_types[mlir_name] = str(result.type)

        if self.node.operation == Operation.Generic:
            name = self.node.properties["name"]
            if name in ["add", "dot", "multiply", "subtract"]:
                if self.one_of_the_inputs_is_a_tensor and not self.all_of_the_inputs_are_tensors:
                    to_be_converted = []
                    for pred in self.graph.ordered_preds_of(self.node):
                        if pred.output.is_scalar:
                            to_be_converted.append(self.nodes_to_mlir_names[pred])
                    self.scalar_to_1d_tensor_conversion_hacks[mlir_name] = to_be_converted

        return result

    # pylint: enable=too-many-branches

    def convert_add(self) -> OpResult:
        """
        Convert "add" node to its corresponding MLIR representation.

        Returns:
            OpResult:
                in-memory MLIR representation corresponding to `self.node`
        """

        resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
        preds = self.preds

        if self.all_of_the_inputs_are_encrypted:
            if self.one_of_the_inputs_is_a_tensor:
                result = fhelinalg.AddEintOp(resulting_type, *preds).result
            else:
                result = fhe.AddEintOp(resulting_type, *preds).result
        else:
            if self.node.inputs[0].is_clear:
                preds = preds[::-1]

            if self.one_of_the_inputs_is_a_tensor:
                result = fhelinalg.AddEintIntOp(resulting_type, *preds).result
            else:
                result = fhe.AddEintIntOp(resulting_type, *preds).result

        return result
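
    # dispatch sketch, derived from the branches of `convert_add` above:
    #   all inputs encrypted + at least one tensor  -> fhelinalg.AddEintOp
    #   all inputs encrypted, all scalars           -> fhe.AddEintOp
    #   mixed clear/encrypted + at least one tensor -> fhelinalg.AddEintIntOp
    #   mixed clear/encrypted, all scalars          -> fhe.AddEintIntOp
    # the operand swap handles the clear-first case, since the swap implies
    # the `*EintInt` ops take the encrypted operand first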

    def convert_concat(self) -> OpResult:
        """
        Convert "concatenate" node to its corresponding MLIR representation.

        Returns:
            OpResult:
                in-memory MLIR representation corresponding to `self.node`
        """

        resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
        axis = self.node.properties["kwargs"].get("axis", 0)

        if axis is not None:
            if axis < 0:
                axis += len(self.node.inputs[0].shape)
            return fhelinalg.ConcatOp(
                resulting_type,
                self.preds,
                IntegerAttr.get(IntegerType.get_signless(64), axis),
            ).result

        # axis is None, so following numpy semantics, every input is
        # flattened before concatenation along the first (and only) dimension
        flattened_preds = []
        for pred, input_value in zip(self.preds, self.node.inputs):
            input_shape = input_value.shape
            input_size = np.prod(input_shape)

            flattened_pred_type = RankedTensorType.get(
                [input_size],
                NodeConverter.value_to_mlir_type(
                    self.ctx,
                    Value(input_value.dtype, shape=(), is_encrypted=input_value.is_encrypted),
                ),
            )
            flattened_pred = linalg.TensorCollapseShapeOp(
                flattened_pred_type,
                pred,
                ArrayAttr.get(
                    [
                        ArrayAttr.get(
                            [
                                IntegerAttr.get(IndexType.parse("index"), i)
                                for i in range(len(input_shape))
                            ]
                        )
                    ]
                ),
            ).result
            flattened_preds.append(flattened_pred)

        return fhelinalg.ConcatOp(
            resulting_type,
            flattened_preds,
            IntegerAttr.get(IntegerType.get_signless(64), 0),
        ).result

    def convert_constant(self) -> OpResult:
        """
        Convert Operation.Constant node to its corresponding MLIR representation.

        Returns:
            OpResult:
                in-memory MLIR representation corresponding to `self.node`
        """

        resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
        data = self.node()

        if self.node.output.is_scalar:
            attr = IntegerAttr.get(resulting_type, data)
        else:
            # usage of `Attribute.parse` is the result of some limitations in the MLIR module
            # provided by LLVM

            # what should have been used is `DenseElementsAttr`, but it's impossible to assign
            # custom bit-widths using it (e.g., uint5)

            # since we couldn't create a `DenseElementsAttr` with a custom bit width using
            # the python api, we use `Attribute.parse` to let the underlying library do it by itself
            attr = Attribute.parse(f"dense<{str(data.tolist())}> : {resulting_type}")

        return arith.ConstantOp(resulting_type, attr).result

    def convert_conv2d(self) -> OpResult:
        """
        Convert "conv2d" node to its corresponding MLIR representation.

        Returns:
            OpResult:
                in-memory MLIR representation corresponding to `self.node`
        """

        resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
        preds = self.preds

        integer_type = IntegerType.get_signless(64, context=self.ctx)

        strides = DenseElementsAttr.get(
            np.array(list(self.node.properties["kwargs"]["strides"]), dtype=np.uint64),
            type=integer_type,
            context=self.ctx,
        )
        dilations = DenseElementsAttr.get(
            np.array(list(self.node.properties["kwargs"]["dilations"]), dtype=np.uint64),
            type=integer_type,
            context=self.ctx,
        )
        pads = DenseElementsAttr.get(
            np.array(list(self.node.properties["kwargs"]["pads"]), dtype=np.uint64),
            type=integer_type,
            context=self.ctx,
        )

        # the bias operand is optional; None is passed for it when there is no bias
        has_bias = len(self.node.inputs) == 3
        if not has_bias:
            preds.append(None)

        return fhelinalg.Conv2dOp(resulting_type, *preds, pads, strides, dilations).result

    def convert_dot(self) -> OpResult:
        """
        Convert "dot" node to its corresponding MLIR representation.

        Returns:
            OpResult:
                in-memory MLIR representation corresponding to `self.node`
        """

        resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
        preds = self.preds

        if self.node.inputs[0].is_clear:
            preds = preds[::-1]

        if self.all_of_the_inputs_are_tensors:
            # numpy.dot(x, y) where x and y are both vectors = regular dot product
            result = fhelinalg.Dot(resulting_type, *preds).result

        elif not self.one_of_the_inputs_is_a_tensor:
            # numpy.dot(x, y) where x and y are both scalars = x * y
            result = fhe.MulEintIntOp(resulting_type, *preds).result

        else:
            # numpy.dot(x, y) where one of x or y is a scalar and the other one is a vector = x * y
            result = fhelinalg.MulEintIntOp(resulting_type, *preds).result

        return result

    def convert_matmul(self) -> OpResult:
        """
        Convert "matmul" node to its corresponding MLIR representation.

        Returns:
            OpResult:
                in-memory MLIR representation corresponding to `self.node`
        """

        resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
        preds = self.preds

        if self.node.output.shape == ():
            # matmul of two vectors produces a scalar, i.e., a dot product
            if self.node.inputs[0].is_clear:
                preds = preds[::-1]
            result = fhelinalg.Dot(resulting_type, *preds).result

        elif self.node.inputs[0].is_clear:
            result = fhelinalg.MatMulIntEintOp(resulting_type, *preds).result
        else:
            result = fhelinalg.MatMulEintIntOp(resulting_type, *preds).result

        return result

    def convert_mul(self) -> OpResult:
        """
        Convert "multiply" node to its corresponding MLIR representation.

        Returns:
            OpResult:
                in-memory MLIR representation corresponding to `self.node`
        """

        resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
        preds = self.preds

        if self.node.inputs[0].is_clear:
            preds = preds[::-1]

        if self.one_of_the_inputs_is_a_tensor:
            result = fhelinalg.MulEintIntOp(resulting_type, *preds).result
        else:
            result = fhe.MulEintIntOp(resulting_type, *preds).result

        return result

    def convert_neg(self) -> OpResult:
        """
        Convert "negative" node to its corresponding MLIR representation.

        Returns:
            OpResult:
                in-memory MLIR representation corresponding to `self.node`
        """

        resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
        pred = self.preds[0]

        if self.one_of_the_inputs_is_a_tensor:
            result = fhelinalg.NegEintOp(resulting_type, pred).result
        else:
            result = fhe.NegEintOp(resulting_type, pred).result

        return result

    def convert_reshape(self) -> OpResult:
        """
        Convert "reshape" node to its corresponding MLIR representation.

        Returns:
            OpResult:
                in-memory MLIR representation corresponding to `self.node`
        """

        input_shape = self.node.inputs[0].shape
        output_shape = self.node.output.shape

        pred = self.preds[0]
        if input_shape == output_shape:
            return pred

        # we can either collapse or expand, which changes the number of dimensions;
        # this is a limitation of the current compiler, it will be improved in the future (#1060)
        can_be_converted_directly = len(input_shape) != len(output_shape)

        reassociation: List[List[int]] = []
        if can_be_converted_directly:
            if len(output_shape) == 1:
                # output is 1 dimensional, so collapse every dimension into the same dimension
                reassociation.append(list(range(len(input_shape))))
            else:
                # input is m dimensional
                # output is n dimensional
                # and m is different from n

                # we don't want to duplicate code, so we forget about input and output,
                # and we focus on the smaller shape and the bigger shape

                smaller_shape, bigger_shape = (
                    (output_shape, input_shape)
                    if len(output_shape) < len(input_shape)
                    else (input_shape, output_shape)
                )
                s_index, b_index = 0, 0

                # now we will figure out how to group the bigger shape to get the smaller shape
                # think of the algorithm below as
                #     keep merging the dimensions of the bigger shape
                #     until we have a match on the smaller shape
                #     then try to match the next dimension of the smaller shape
                #     if all dimensions of the smaller shape are matched
                #     we can convert it
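
                # worked example (illustrative): (2, 3, 4) -> (6, 4) groups
                # dimensions 0 and 1 of the bigger shape, which gives the
                # reassociation [[0, 1], [2]]; (2, 3, 4) -> (4, 6) never hits
                # a matching prefix product, so it falls back to the
                # collapse-then-expand path below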

                group = []
                size = 1
                while s_index < len(smaller_shape) and b_index < len(bigger_shape):
                    # dimension `b_index` of `bigger_shape` belongs to current group
                    group.append(b_index)

                    # and current group has `size * bigger_shape[b_index]` elements now
                    size *= bigger_shape[b_index]

                    # if current group size matches the dimension `s_index` of `smaller_shape`
                    if size == smaller_shape[s_index]:
                        # we finalize this group and reset everything
                        size = 1
                        reassociation.append(group)
                        group = []

                        # now try to match the next dimension of `smaller_shape`
                        s_index += 1

                    # now process the next dimension of `bigger_shape`
                    b_index += 1

                # handle the case where the bigger shape has trailing 1s
                # e.g., (5,) -> (5, 1)
                while b_index < len(bigger_shape) and bigger_shape[b_index] == 1:
                    reassociation[-1].append(b_index)
                    b_index += 1

                # if not all dimensions of both shapes are processed exactly
                if s_index != len(smaller_shape) or b_index != len(bigger_shape):
                    # we cannot convert
                    can_be_converted_directly = False

        index_type = IndexType.parse("index")
        resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)

        if can_be_converted_directly:
            reassociation_attr = ArrayAttr.get(
                [
                    ArrayAttr.get([IntegerAttr.get(index_type, dimension) for dimension in group])
                    for group in reassociation
                ]
            )
            if len(output_shape) < len(input_shape):
                return linalg.TensorCollapseShapeOp(resulting_type, pred, reassociation_attr).result
            return linalg.TensorExpandShapeOp(resulting_type, pred, reassociation_attr).result

        flattened_type = NodeConverter.value_to_mlir_type(
            self.ctx,
            Value(
                dtype=self.node.inputs[0].dtype,
                shape=(np.prod(input_shape),),
                is_encrypted=self.node.inputs[0].is_encrypted,
            ),
        )
        flattened_result = linalg.TensorCollapseShapeOp(
            flattened_type,
            pred,
            ArrayAttr.get(
                [ArrayAttr.get([IntegerAttr.get(index_type, i) for i in range(len(input_shape))])]
            ),
        ).result

        return linalg.TensorExpandShapeOp(
            resulting_type,
            flattened_result,
            ArrayAttr.get(
                [ArrayAttr.get([IntegerAttr.get(index_type, i) for i in range(len(output_shape))])]
            ),
        ).result

    def convert_static_indexing(self) -> OpResult:
        """
        Convert "index.static" node to its corresponding MLIR representation.

        Returns:
            OpResult:
                in-memory MLIR representation corresponding to `self.node`
        """

        input_value = self.node.inputs[0]
        input_shape = input_value.shape

        index = list(self.node.properties["attributes"]["index"])
        index_type = IndexType.parse("index")

        while len(index) < input_value.ndim:
            index.append(slice(None, None, None))

        output_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
        if len(index) == len(input_shape) and all(isinstance(i, (int, np.integer)) for i in index):
            # every indexing element is an integer, so a single scalar is extracted
            indices = []
            for value, dimension_size in zip(index, input_shape):
                value = int(value)
                attr = IntegerAttr.get(index_type, value if value >= 0 else value + dimension_size)
                indices.append(arith.ConstantOp(index_type, attr).result)
            return tensor.ExtractOp(output_type, self.preds[0], indices).result

        offsets = []
        sizes = []
        strides = []

        destroyed_dimensions = []
        for dimension, (indexing_element, dimension_size) in enumerate(zip(index, input_shape)):
            if isinstance(indexing_element, slice):
                size = np.zeros(dimension_size)[indexing_element].shape[0]
                stride = indexing_element.step if isinstance(indexing_element.step, int) else 1
                offset = (
                    (
                        indexing_element.start
                        if indexing_element.start >= 0
                        else indexing_element.start + dimension_size
                    )
                    if isinstance(indexing_element.start, int)
                    else (0 if stride > 0 else dimension_size - 1)
                )

            else:
                destroyed_dimensions.append(dimension)
                size = 1
                stride = 1
                offset = int(
                    indexing_element if indexing_element >= 0 else indexing_element + dimension_size
                )

            offsets.append(offset)
            sizes.append(size)
            strides.append(stride)

        if len(destroyed_dimensions) == 0:
            return tensor.ExtractSliceOp(
                output_type,
                self.preds[0],
                [],
                [],
                [],
                ArrayAttr.get([IntegerAttr.get(index_type, value) for value in offsets]),
                ArrayAttr.get([IntegerAttr.get(index_type, value) for value in sizes]),
                ArrayAttr.get([IntegerAttr.get(index_type, value) for value in strides]),
            ).result

        output_value = self.node.output

        intermediate_shape = list(output_value.shape)
        for dimension in destroyed_dimensions:
            intermediate_shape.insert(dimension, 1)

        intermediate = tensor.ExtractSliceOp(
            RankedTensorType.get(
                intermediate_shape,
                NodeConverter.value_to_mlir_type(
                    self.ctx,
                    Value(output_value.dtype, shape=(), is_encrypted=output_value.is_encrypted),
                ),
            ),
            self.preds[0],
            [],
            [],
            [],
            ArrayAttr.get([IntegerAttr.get(index_type, value) for value in offsets]),
            ArrayAttr.get([IntegerAttr.get(index_type, value) for value in sizes]),
            ArrayAttr.get([IntegerAttr.get(index_type, value) for value in strides]),
        ).result

        reassociation = []

        current_intermediate_dimension = 0
        for _ in range(len(output_value.shape)):
            indices = [current_intermediate_dimension]
            while current_intermediate_dimension in destroyed_dimensions:
                current_intermediate_dimension += 1
                indices.append(current_intermediate_dimension)

            reassociation.append(indices)
            current_intermediate_dimension += 1
        while current_intermediate_dimension < len(intermediate_shape):
            reassociation[-1].append(current_intermediate_dimension)
            current_intermediate_dimension += 1

        return linalg.TensorCollapseShapeOp(
            output_type,
            intermediate,
            ArrayAttr.get(
                [
                    ArrayAttr.get(
                        [IntegerAttr.get(index_type, index) for index in indices],
                    )
                    for indices in reassociation
                ],
            ),
        ).result
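
    # convert_static_indexing in action (illustrative): for x[1, :] on a
    # (3, 4) input, offsets become [1, 0], sizes [1, 4] and strides [1, 1];
    # dimension 0 is destroyed by the integer index, so the intermediate
    # (1, 4) slice is collapsed to the (4,) output with reassociation [[0, 1]]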

    def convert_sub(self) -> OpResult:
        """
        Convert "subtract" node to its corresponding MLIR representation.

        Returns:
            OpResult:
                in-memory MLIR representation corresponding to `self.node`
        """

        resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
        preds = self.preds

        if self.one_of_the_inputs_is_a_tensor:
            result = fhelinalg.SubIntEintOp(resulting_type, *preds).result
        else:
            result = fhe.SubIntEintOp(resulting_type, *preds).result

        return result

    def convert_sum(self) -> OpResult:
        """
        Convert "sum" node to its corresponding MLIR representation.

        Returns:
            OpResult:
                in-memory MLIR representation corresponding to `self.node`
        """

        resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)

        axes = self.node.properties["kwargs"].get("axis", [])
        keep_dims = self.node.properties["kwargs"].get("keepdims", False)

        if isinstance(axes, int):
            axes = [axes]
        elif isinstance(axes, tuple):
            axes = list(axes)

        input_dimensions = self.node.inputs[0].ndim
        for i, axis in enumerate(axes):
            if axis < 0:
                axes[i] += input_dimensions

        return fhelinalg.SumOp(
            resulting_type,
            self.preds[0],
            ArrayAttr.get([IntegerAttr.get(IntegerType.get_signless(64), axis) for axis in axes]),
            BoolAttr.get(keep_dims),
        ).result

    def convert_tlu(self) -> OpResult:
        """
        Convert Operation.Generic node to its corresponding MLIR representation.

        Returns:
            OpResult:
                in-memory MLIR representation corresponding to `self.node`
        """

        variable_input_index = -1

        preds = self.graph.ordered_preds_of(self.node)
        for i, pred in enumerate(preds):
            if pred.operation != Operation.Constant:
                variable_input_index = i
                break

        assert_that(variable_input_index != -1)

        tables = construct_deduplicated_tables(self.node, preds)
        assert_that(len(tables) > 0)

        lut_shape: Tuple[int, ...] = ()
        map_shape: Tuple[int, ...] = ()

        if len(tables) == 1:
            table = tables[0][0]

            # the reduction on 63 bits is to avoid problems like doing a TLU of
            # the form T[j] = 2 << j, for a j which is supposed to be 7 bits as
            # per the constraints of the compiler, while in practice it is a
            # small value; reducing on 64 bits was not ok for some reason
            lut_shape = (len(table),)
            lut_values = np.array(table % (2 << 63), dtype=np.uint64)

            map_shape = ()
            map_values = None
        else:
            individual_table_size = len(tables[0][0])

            lut_shape = (len(tables), individual_table_size)
            map_shape = self.node.output.shape

            lut_values = np.zeros(lut_shape, dtype=np.uint64)
            map_values = np.zeros(map_shape, dtype=np.intp)

            for i, (table, indices) in enumerate(tables):
                assert_that(len(table) == individual_table_size)
                lut_values[i, :] = table
                for index in indices:
                    map_values[index] = i

        lut_type = RankedTensorType.get(lut_shape, IntegerType.get_signless(64, context=self.ctx))
        lut_attr = DenseElementsAttr.get(lut_values, context=self.ctx)
        lut = arith.ConstantOp(lut_type, lut_attr).result

        resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
        pred = self.preds[variable_input_index]

        if self.one_of_the_inputs_is_a_tensor:
            if len(tables) == 1:
                result = fhelinalg.ApplyLookupTableEintOp(resulting_type, pred, lut).result
            else:
                assert_that(map_shape != ())
                assert_that(map_values is not None)

                index_type = IndexType.parse("index")
                map_type = RankedTensorType.get(map_shape, index_type)
                map_attr = DenseElementsAttr.get(map_values, context=self.ctx, type=index_type)

                result = fhelinalg.ApplyMappedLookupTableEintOp(
                    resulting_type,
                    pred,
                    lut,
                    arith.ConstantOp(map_type, map_attr).result,
                ).result
        else:
            result = fhe.ApplyLookupTableEintOp(resulting_type, pred, lut).result

        return result
concrete/numpy/mlir/utils.py (new file, 172 lines)
@@ -0,0 +1,172 @@
"""
Declaration of various functions and constants related to MLIR conversion.
"""

import math
from collections import defaultdict, deque
from copy import deepcopy
from itertools import product
from typing import Any, DefaultDict, List, Optional, Tuple, Union, cast

import numpy as np

from ..dtypes import Integer
from ..internal.utils import assert_that
from ..representation import Node, Operation

MAXIMUM_BIT_WIDTH = 8

class HashableNdarray:
    """
    HashableNdarray class, to use numpy arrays in dictionaries.
    """

    array: np.ndarray

    def __init__(self, array: np.ndarray):
        self.array = array

    def __eq__(self, other: object) -> bool:
        return isinstance(other, HashableNdarray) and np.array_equal(self.array, other.array)

    def __hash__(self) -> int:
        return hash(self.array.tobytes())

def flood_replace_none_values(table: list):
    """
    Use a flooding algorithm to replace `None` values.

    Args:
        table (list):
            the list in which there are `None` values that need to be replaced
            with copies of the closest non-`None` data from the list
    """

    assert_that(any(value is not None for value in table))

    not_none_values_idx = deque(idx for idx, value in enumerate(table) if value is not None)
    while not_none_values_idx:
        current_idx = not_none_values_idx.popleft()
        current_value = table[current_idx]

        previous_idx = current_idx - 1
        next_idx = current_idx + 1

        if previous_idx >= 0 and table[previous_idx] is None:
            table[previous_idx] = deepcopy(current_value)
            not_none_values_idx.append(previous_idx)

        if next_idx < len(table) and table[next_idx] is None:
            table[next_idx] = deepcopy(current_value)
            not_none_values_idx.append(next_idx)

    assert_that(all(value is not None for value in table))
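
# flood_replace_none_values in action (illustrative):
#   [None, 1, None, None, 4, None] -> [1, 1, 1, 4, 4, 4]
# indices 0 and 2 are filled from the nearest non-None value 1, indices 3
# and 5 from 4, spreading outwards in breadth-first order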

def construct_table(node: Node, preds: List[Node]) -> List[Any]:
    """
    Construct the lookup table for an Operation.Generic node.

    Args:
        node (Node):
            Operation.Generic node to construct the table for

        preds (List[Node]):
            ordered predecessors to `node`

    Returns:
        List[Any]:
            lookup table corresponding to `node` and its input value
    """

    variable_input_index = -1
    for index, pred in enumerate(preds):
        if pred.operation != Operation.Constant:
            variable_input_index = index
            break
    assert_that(variable_input_index != -1)

    variable_input_dtype = node.inputs[variable_input_index].dtype
    variable_input_shape = node.inputs[variable_input_index].shape

    assert_that(isinstance(variable_input_dtype, Integer))
    variable_input_dtype = cast(Integer, variable_input_dtype)

    inputs: List[Any] = [pred() if pred.operation == Operation.Constant else None for pred in preds]

    table: List[Optional[Union[np.bool_, np.integer, np.floating, np.ndarray]]] = []
    for value in range(variable_input_dtype.min(), variable_input_dtype.max() + 1):
        try:
            inputs[variable_input_index] = np.ones(variable_input_shape, dtype=np.int64) * value
            table.append(node(*inputs))
        except Exception:  # pylint: disable=broad-except
            # here we try our best to fill the table;
            # if it fails, we append None and let the flooding algorithm
            # below replace the None values
            table.append(None)

    flood_replace_none_values(table)

    return table
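
# construct_table in action (a sketch): for a node computing x ** 2 on an
# unsigned 2-bit input, the sweep over 0..3 yields the table [0, 1, 4, 9];
# entries whose evaluation raises are recorded as None and later filled in
# by flood_replace_none_values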

def construct_deduplicated_tables(
    node: Node,
    preds: List[Node],
) -> Tuple[Tuple[np.ndarray, List[Tuple[int, ...]]], ...]:
    """
    Construct lookup tables for each cell of the input for an Operation.Generic node.

    Args:
        node (Node):
            Operation.Generic node to construct the tables for

        preds (List[Node]):
            ordered predecessors to `node`

    Returns:
        Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]:
            tuple containing tuples of 2 for

            - constructed table
            - list of indices of the input that use the constructed table

            e.g.,

            .. code-block:: python

                (
                    (np.array([3, 1, 2, 4]), [(1, 0), (2, 1)]),
                    (np.array([5, 8, 6, 7]), [(0, 0), (0, 1), (1, 1), (2, 0)]),
                )

            means the lookup on a 3x2 input will result in

            .. code-block:: python

                [ [5, 8, 6, 7][input[0, 0]] , [5, 8, 6, 7][input[0, 1]] ]
                [ [3, 1, 2, 4][input[1, 0]] , [5, 8, 6, 7][input[1, 1]] ]
                [ [5, 8, 6, 7][input[2, 0]] , [3, 1, 2, 4][input[2, 1]] ]
    """

    node_complete_table = np.concatenate(
        tuple(np.expand_dims(array, -1) for array in construct_table(node, preds)),
        axis=-1,
    )

    all_cells_idx = product(*tuple(range(max_val) for max_val in node_complete_table.shape[:-1]))
    tables_to_cell_idx: DefaultDict[HashableNdarray, List[Tuple[int, ...]]] = defaultdict(list)

    idx: Tuple[int, ...]
    all_idx_set = set()
    for idx in all_cells_idx:
        hashable_array = HashableNdarray(node_complete_table[idx])
        tables_to_cell_idx[hashable_array].append(idx)
        all_idx_set.add(idx)

    assert_that(len(all_idx_set) == math.prod(node_complete_table.shape[:-1]))

    return tuple(
        (hashable_array.array, indices) for hashable_array, indices in tables_to_cell_idx.items()
    )