refactor: remove the old hack to support operations between tensors and scalars

This commit is contained in:
Umut
2022-07-12 16:33:58 +02:00
parent d199358c0f
commit 84bdb65529
2 changed files with 65 additions and 96 deletions

View File

@@ -5,7 +5,7 @@ Declaration of `GraphConverter` class.
# pylint: disable=no-member,no-name-in-module
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, cast
from typing import Any, Dict, List, Optional, cast
import concrete.lang as concretelang
import networkx as nx
@@ -299,11 +299,63 @@ class GraphConverter:
nx_graph.add_edge(add_offset, node, input_idx=variable_input_index)
@staticmethod
def _sanitize_signed_inputs(
graph: Graph,
args: List[Any],
ctx: Context,
) -> Tuple[List[Any], List[str]]:
def _tensorize_scalars_for_fhelinalg(graph: Graph):
    """
    Tensorize scalars if they are used within fhelinalg operations.

    Every generic "add", "dot", "multiply" or "subtract" node which mixes one scalar input
    with one non-scalar input gets its scalar predecessor routed through a generic "array"
    node whose output value has shape `(1,)`. The "array" node created for a given
    predecessor is remembered and shared by all the operations consuming that predecessor.

    Args:
        graph (Graph):
            computation graph to update
    """

    ops_to_tensorize = ["add", "dot", "multiply", "subtract"]

    # scalar predecessor -> "array" node created for it (so each one is wrapped only once)
    array_nodes: Dict[Node, Node] = {}

    nx_graph = graph.graph

    # iterate over a snapshot, the underlying graph gains nodes and edges within the loop
    for node in list(nx_graph.nodes):
        if node.operation != Operation.Generic:
            continue
        if node.properties["name"] not in ops_to_tensorize:
            continue

        assert_that(len(node.inputs) == 2)

        # only mixed operations (one scalar input and one non-scalar input) are rewritten
        if {value.is_scalar for value in node.inputs} != {True, False}:
            continue

        # locate the first predecessor producing a scalar, along with its input position
        match = next(
            (
                (position, candidate)
                for position, candidate in enumerate(graph.ordered_preds_of(node))
                if candidate.output.is_scalar
            ),
            None,
        )
        assert match is not None
        scalar_position, scalar_pred = match

        array_node = array_nodes.get(scalar_pred)
        if array_node is None:
            # same value as the scalar, but reshaped to a tensor of a single element
            array_value = deepcopy(scalar_pred.output)
            array_value.shape = (1,)

            array_node = Node.generic(
                "array",
                [scalar_pred.output],
                array_value,
                lambda *args: np.array(args),
            )
            nx_graph.add_edge(scalar_pred, array_node, input_idx=0)

            array_nodes[scalar_pred] = array_node

        assert array_node is not None

        # reconnect the operation so that it consumes the "array" node at the same position
        nx_graph.remove_edge(scalar_pred, node)
        nx_graph.add_edge(array_node, node, input_idx=scalar_position)
@staticmethod
def _sanitize_signed_inputs(graph: Graph, args: List[Any], ctx: Context) -> List[Any]:
"""
Apply table lookup to signed inputs in the beginning of evaluation to sanitize them.
@@ -341,8 +393,6 @@ class GraphConverter:
"""
sanitized_args = []
arg_mlir_names = []
for i, arg in enumerate(args):
input_node = graph.input_nodes[i]
input_value = input_node.output
@@ -370,14 +420,10 @@ class GraphConverter:
sanitized = fhelinalg.ApplyLookupTableEintOp(resulting_type, arg, lut).result
sanitized_args.append(sanitized)
mlir_name = str(sanitized).replace("Value(", "").split("=", maxsplit=1)[0].strip()
else:
sanitized_args.append(arg)
mlir_name = f"%arg{i}"
arg_mlir_names.append(mlir_name)
return sanitized_args, arg_mlir_names
return sanitized_args
@staticmethod
def convert(graph: Graph, virtual: bool = False) -> str:
@@ -404,21 +450,7 @@ class GraphConverter:
GraphConverter._update_bit_widths(graph)
GraphConverter._offset_negative_lookup_table_inputs(graph)
# There are no tensor +*- scalar operations in the compiler
# But such operations are used commonly, so we need to support them
# So, we implemented some workarounds (pull request #970)
# Once we have native support, this workaround shall be removed (issue #837)
# (most changes in #970 shall be reverted)
# { node1: "%arg0", node2: "%0", node3: "%1" }
nodes_to_mlir_names: Dict[Node, str] = {}
# { "%arg0": "i5", "%0": "tensor<2x3x!FHE.eint<4>>" }
mlir_names_to_mlir_types: Dict[str, str] = {}
# { "%0": ["%c1_i5"] } == for %0 we need to convert %c1_i5 to 1d tensor
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]] = {}
GraphConverter._tensorize_scalars_for_fhelinalg(graph)
# { "%0": "tensor.from_elements ..." } == we need to convert the part after "=" for %0
direct_replacements: Dict[str, str] = {}
@@ -435,20 +467,12 @@ class GraphConverter:
@builtin.FuncOp.from_py_func(*parameters)
def main(*args):
sanitized_args, arg_mlir_names = GraphConverter._sanitize_signed_inputs(
graph,
args,
ctx,
)
sanitized_args = GraphConverter._sanitize_signed_inputs(graph, args, ctx)
ir_to_mlir = {}
for arg_num, node in graph.input_nodes.items():
ir_to_mlir[node] = sanitized_args[arg_num]
mlir_name = arg_mlir_names[arg_num]
nodes_to_mlir_names[node] = mlir_name
mlir_names_to_mlir_types[mlir_name] = str(parameters[arg_num])
for node in nx.topological_sort(graph.graph):
if node.operation == Operation.Input:
continue
@@ -459,9 +483,6 @@ class GraphConverter:
graph,
node,
preds,
nodes_to_mlir_names,
mlir_names_to_mlir_types,
scalar_to_1d_tensor_conversion_hacks,
direct_replacements,
)
ir_to_mlir[node] = node_converter.convert()
@@ -472,39 +493,11 @@ class GraphConverter:
module_lines_after_hacks_are_applied = []
for line in str(module).split("\n"):
mlir_name = line.split("=")[0].strip()
if mlir_name in direct_replacements:
new_value = direct_replacements[mlir_name]
module_lines_after_hacks_are_applied.append(f" {mlir_name} = {new_value}")
continue
if mlir_name not in scalar_to_1d_tensor_conversion_hacks:
if mlir_name not in direct_replacements:
module_lines_after_hacks_are_applied.append(line)
continue
to_be_replaced = scalar_to_1d_tensor_conversion_hacks[mlir_name]
for arg_name in to_be_replaced:
new_name = f"%hack_{mlir_name.replace('%', '')}_{arg_name.replace('%', '')}"
mlir_type = mlir_names_to_mlir_types[arg_name]
hack_line = (
f" {new_name} = tensor.from_elements {arg_name} : tensor<1x{mlir_type}>"
)
module_lines_after_hacks_are_applied.append(hack_line)
line = line.replace(arg_name, new_name)
new_arg_types = []
arg_types = line.split(":")[1].split("->")[0].strip()[1:-1]
for arg in arg_types.split(", "):
if arg.startswith("tensor"):
new_arg_types.append(arg)
else:
new_arg_types.append(f"tensor<1x{arg}>")
line = line.replace(arg_types, ", ".join(new_arg_types))
module_lines_after_hacks_are_applied.append(line)
new_value = direct_replacements[mlir_name]
module_lines_after_hacks_are_applied.append(f" {mlir_name} = {new_value}")
return "\n".join(module_lines_after_hacks_are_applied).strip()

View File

@@ -51,9 +51,6 @@ class NodeConverter:
all_of_the_inputs_are_tensors: bool
one_of_the_inputs_is_a_tensor: bool
nodes_to_mlir_names: Dict[Node, str]
mlir_names_to_mlir_types: Dict[str, str]
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]]
direct_replacements: Dict[str, str]
# pylint: enable=too-many-instance-attributes
@@ -115,9 +112,6 @@ class NodeConverter:
graph: Graph,
node: Node,
preds: List[OpResult],
nodes_to_mlir_names: Dict[OpResult, str],
mlir_names_to_mlir_types: Dict[str, str],
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]],
direct_replacements: Dict[str, str],
):
self.ctx = ctx
@@ -138,9 +132,6 @@ class NodeConverter:
else:
self.one_of_the_inputs_is_a_tensor = True
self.nodes_to_mlir_names = nodes_to_mlir_names
self.mlir_names_to_mlir_types = mlir_names_to_mlir_types
self.scalar_to_1d_tensor_conversion_hacks = scalar_to_1d_tensor_conversion_hacks
self.direct_replacements = direct_replacements
def convert(self) -> OpResult:
@@ -216,21 +207,6 @@ class NodeConverter:
assert_that(self.node.converted_to_table_lookup)
result = self._convert_tlu()
mlir_name = NodeConverter.mlir_name(result)
self.nodes_to_mlir_names[self.node] = mlir_name
self.mlir_names_to_mlir_types[mlir_name] = str(result.type)
if self.node.operation == Operation.Generic:
name = self.node.properties["name"]
if name in ["add", "dot", "multiply", "subtract"]:
if self.one_of_the_inputs_is_a_tensor and not self.all_of_the_inputs_are_tensors:
to_be_converted = []
for pred in self.graph.ordered_preds_of(self.node):
if pred.output.is_scalar:
to_be_converted.append(self.nodes_to_mlir_names[pred])
self.scalar_to_1d_tensor_conversion_hacks[mlir_name] = to_be_converted
return result
# pylint: enable=too-many-branches
@@ -284,7 +260,7 @@ class NodeConverter:
pred_names = []
for pred, value in zip(preds, self.node.inputs):
if value.is_encrypted:
if value.is_encrypted or self.node.output.is_clear:
pred_names.append(NodeConverter.mlir_name(pred))
continue