mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
refactor: remove the old hack to support operations between tensors and scalars
This commit is contained in:
@@ -5,7 +5,7 @@ Declaration of `GraphConverter` class.
|
||||
# pylint: disable=no-member,no-name-in-module
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
import concrete.lang as concretelang
|
||||
import networkx as nx
|
||||
@@ -299,11 +299,63 @@ class GraphConverter:
|
||||
nx_graph.add_edge(add_offset, node, input_idx=variable_input_index)
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_signed_inputs(
|
||||
graph: Graph,
|
||||
args: List[Any],
|
||||
ctx: Context,
|
||||
) -> Tuple[List[Any], List[str]]:
|
||||
def _tensorize_scalars_for_fhelinalg(graph: Graph):
|
||||
"""
|
||||
Tensorize scalars if they are used within fhelinalg operations.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
computation graph to update
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
OPS_TO_TENSORIZE = ["add", "dot", "multiply", "subtract"]
|
||||
# pylint: enable=invalid-name
|
||||
|
||||
tensorized_scalars: Dict[Node, Node] = {}
|
||||
|
||||
nx_graph = graph.graph
|
||||
for node in list(nx_graph.nodes):
|
||||
if node.operation == Operation.Generic and node.properties["name"] in OPS_TO_TENSORIZE:
|
||||
assert_that(len(node.inputs) == 2)
|
||||
|
||||
if set(inp.is_scalar for inp in node.inputs) != {True, False}:
|
||||
continue
|
||||
|
||||
pred_to_tensorize: Optional[Node] = None
|
||||
pred_to_tensorize_index = 0
|
||||
|
||||
preds = graph.ordered_preds_of(node)
|
||||
for index, pred in enumerate(preds):
|
||||
if pred.output.is_scalar:
|
||||
pred_to_tensorize = pred
|
||||
pred_to_tensorize_index = index
|
||||
break
|
||||
|
||||
assert pred_to_tensorize is not None
|
||||
|
||||
tensorized_pred = tensorized_scalars.get(pred_to_tensorize)
|
||||
if tensorized_pred is None:
|
||||
tensorized_value = deepcopy(pred_to_tensorize.output)
|
||||
tensorized_value.shape = (1,)
|
||||
|
||||
tensorized_pred = Node.generic(
|
||||
"array",
|
||||
[pred_to_tensorize.output],
|
||||
tensorized_value,
|
||||
lambda *args: np.array(args),
|
||||
)
|
||||
nx_graph.add_edge(pred_to_tensorize, tensorized_pred, input_idx=0)
|
||||
|
||||
tensorized_scalars[pred_to_tensorize] = tensorized_pred
|
||||
|
||||
assert tensorized_pred is not None
|
||||
|
||||
nx_graph.remove_edge(pred_to_tensorize, node)
|
||||
nx_graph.add_edge(tensorized_pred, node, input_idx=pred_to_tensorize_index)
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_signed_inputs(graph: Graph, args: List[Any], ctx: Context) -> List[Any]:
|
||||
"""
|
||||
Apply table lookup to signed inputs in the beginning of evaluation to sanitize them.
|
||||
|
||||
@@ -341,8 +393,6 @@ class GraphConverter:
|
||||
"""
|
||||
|
||||
sanitized_args = []
|
||||
arg_mlir_names = []
|
||||
|
||||
for i, arg in enumerate(args):
|
||||
input_node = graph.input_nodes[i]
|
||||
input_value = input_node.output
|
||||
@@ -370,14 +420,10 @@ class GraphConverter:
|
||||
sanitized = fhelinalg.ApplyLookupTableEintOp(resulting_type, arg, lut).result
|
||||
|
||||
sanitized_args.append(sanitized)
|
||||
mlir_name = str(sanitized).replace("Value(", "").split("=", maxsplit=1)[0].strip()
|
||||
else:
|
||||
sanitized_args.append(arg)
|
||||
mlir_name = f"%arg{i}"
|
||||
|
||||
arg_mlir_names.append(mlir_name)
|
||||
|
||||
return sanitized_args, arg_mlir_names
|
||||
return sanitized_args
|
||||
|
||||
@staticmethod
|
||||
def convert(graph: Graph, virtual: bool = False) -> str:
|
||||
@@ -404,21 +450,7 @@ class GraphConverter:
|
||||
|
||||
GraphConverter._update_bit_widths(graph)
|
||||
GraphConverter._offset_negative_lookup_table_inputs(graph)
|
||||
|
||||
# There are no tensor +*- scalar operations in the compiler
|
||||
# But such operations are used commonly, so we need to support them
|
||||
# So, we implemented some workarounds (pull request #970)
|
||||
# Once we have native support, this workaround shall be removed (issue #837)
|
||||
# (most changes in #970 shall be reverted)
|
||||
|
||||
# { node1: "%arg0", node2: "%0", node3: "%1" }
|
||||
nodes_to_mlir_names: Dict[Node, str] = {}
|
||||
|
||||
# { "%arg0": "i5", "%0": "tensor<2x3x!FHE.eint<4>>" }
|
||||
mlir_names_to_mlir_types: Dict[str, str] = {}
|
||||
|
||||
# { "%0": ["%c1_i5"] } == for %0 we need to convert %c1_i5 to 1d tensor
|
||||
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]] = {}
|
||||
GraphConverter._tensorize_scalars_for_fhelinalg(graph)
|
||||
|
||||
# { "%0": "tensor.from_elements ..." } == we need to convert the part after "=" for %0
|
||||
direct_replacements: Dict[str, str] = {}
|
||||
@@ -435,20 +467,12 @@ class GraphConverter:
|
||||
|
||||
@builtin.FuncOp.from_py_func(*parameters)
|
||||
def main(*args):
|
||||
sanitized_args, arg_mlir_names = GraphConverter._sanitize_signed_inputs(
|
||||
graph,
|
||||
args,
|
||||
ctx,
|
||||
)
|
||||
sanitized_args = GraphConverter._sanitize_signed_inputs(graph, args, ctx)
|
||||
|
||||
ir_to_mlir = {}
|
||||
for arg_num, node in graph.input_nodes.items():
|
||||
ir_to_mlir[node] = sanitized_args[arg_num]
|
||||
|
||||
mlir_name = arg_mlir_names[arg_num]
|
||||
nodes_to_mlir_names[node] = mlir_name
|
||||
mlir_names_to_mlir_types[mlir_name] = str(parameters[arg_num])
|
||||
|
||||
for node in nx.topological_sort(graph.graph):
|
||||
if node.operation == Operation.Input:
|
||||
continue
|
||||
@@ -459,9 +483,6 @@ class GraphConverter:
|
||||
graph,
|
||||
node,
|
||||
preds,
|
||||
nodes_to_mlir_names,
|
||||
mlir_names_to_mlir_types,
|
||||
scalar_to_1d_tensor_conversion_hacks,
|
||||
direct_replacements,
|
||||
)
|
||||
ir_to_mlir[node] = node_converter.convert()
|
||||
@@ -472,39 +493,11 @@ class GraphConverter:
|
||||
module_lines_after_hacks_are_applied = []
|
||||
for line in str(module).split("\n"):
|
||||
mlir_name = line.split("=")[0].strip()
|
||||
|
||||
if mlir_name in direct_replacements:
|
||||
new_value = direct_replacements[mlir_name]
|
||||
module_lines_after_hacks_are_applied.append(f" {mlir_name} = {new_value}")
|
||||
continue
|
||||
|
||||
if mlir_name not in scalar_to_1d_tensor_conversion_hacks:
|
||||
if mlir_name not in direct_replacements:
|
||||
module_lines_after_hacks_are_applied.append(line)
|
||||
continue
|
||||
|
||||
to_be_replaced = scalar_to_1d_tensor_conversion_hacks[mlir_name]
|
||||
for arg_name in to_be_replaced:
|
||||
new_name = f"%hack_{mlir_name.replace('%', '')}_{arg_name.replace('%', '')}"
|
||||
mlir_type = mlir_names_to_mlir_types[arg_name]
|
||||
|
||||
hack_line = (
|
||||
f" {new_name} = tensor.from_elements {arg_name} : tensor<1x{mlir_type}>"
|
||||
)
|
||||
module_lines_after_hacks_are_applied.append(hack_line)
|
||||
|
||||
line = line.replace(arg_name, new_name)
|
||||
|
||||
new_arg_types = []
|
||||
|
||||
arg_types = line.split(":")[1].split("->")[0].strip()[1:-1]
|
||||
for arg in arg_types.split(", "):
|
||||
if arg.startswith("tensor"):
|
||||
new_arg_types.append(arg)
|
||||
else:
|
||||
new_arg_types.append(f"tensor<1x{arg}>")
|
||||
|
||||
line = line.replace(arg_types, ", ".join(new_arg_types))
|
||||
|
||||
module_lines_after_hacks_are_applied.append(line)
|
||||
new_value = direct_replacements[mlir_name]
|
||||
module_lines_after_hacks_are_applied.append(f" {mlir_name} = {new_value}")
|
||||
|
||||
return "\n".join(module_lines_after_hacks_are_applied).strip()
|
||||
|
||||
@@ -51,9 +51,6 @@ class NodeConverter:
|
||||
all_of_the_inputs_are_tensors: bool
|
||||
one_of_the_inputs_is_a_tensor: bool
|
||||
|
||||
nodes_to_mlir_names: Dict[Node, str]
|
||||
mlir_names_to_mlir_types: Dict[str, str]
|
||||
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]]
|
||||
direct_replacements: Dict[str, str]
|
||||
|
||||
# pylint: enable=too-many-instance-attributes
|
||||
@@ -115,9 +112,6 @@ class NodeConverter:
|
||||
graph: Graph,
|
||||
node: Node,
|
||||
preds: List[OpResult],
|
||||
nodes_to_mlir_names: Dict[OpResult, str],
|
||||
mlir_names_to_mlir_types: Dict[str, str],
|
||||
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]],
|
||||
direct_replacements: Dict[str, str],
|
||||
):
|
||||
self.ctx = ctx
|
||||
@@ -138,9 +132,6 @@ class NodeConverter:
|
||||
else:
|
||||
self.one_of_the_inputs_is_a_tensor = True
|
||||
|
||||
self.nodes_to_mlir_names = nodes_to_mlir_names
|
||||
self.mlir_names_to_mlir_types = mlir_names_to_mlir_types
|
||||
self.scalar_to_1d_tensor_conversion_hacks = scalar_to_1d_tensor_conversion_hacks
|
||||
self.direct_replacements = direct_replacements
|
||||
|
||||
def convert(self) -> OpResult:
|
||||
@@ -216,21 +207,6 @@ class NodeConverter:
|
||||
assert_that(self.node.converted_to_table_lookup)
|
||||
result = self._convert_tlu()
|
||||
|
||||
mlir_name = NodeConverter.mlir_name(result)
|
||||
|
||||
self.nodes_to_mlir_names[self.node] = mlir_name
|
||||
self.mlir_names_to_mlir_types[mlir_name] = str(result.type)
|
||||
|
||||
if self.node.operation == Operation.Generic:
|
||||
name = self.node.properties["name"]
|
||||
if name in ["add", "dot", "multiply", "subtract"]:
|
||||
if self.one_of_the_inputs_is_a_tensor and not self.all_of_the_inputs_are_tensors:
|
||||
to_be_converted = []
|
||||
for pred in self.graph.ordered_preds_of(self.node):
|
||||
if pred.output.is_scalar:
|
||||
to_be_converted.append(self.nodes_to_mlir_names[pred])
|
||||
self.scalar_to_1d_tensor_conversion_hacks[mlir_name] = to_be_converted
|
||||
|
||||
return result
|
||||
|
||||
# pylint: enable=too-many-branches
|
||||
@@ -284,7 +260,7 @@ class NodeConverter:
|
||||
|
||||
pred_names = []
|
||||
for pred, value in zip(preds, self.node.inputs):
|
||||
if value.is_encrypted:
|
||||
if value.is_encrypted or self.node.output.is_clear:
|
||||
pred_names.append(NodeConverter.mlir_name(pred))
|
||||
continue
|
||||
|
||||
|
||||
Reference in New Issue
Block a user