mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor(mlir): implement support for operations between tensors and scalars using string processing hacks
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
# pylint: disable=no-name-in-module
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import networkx as nx
|
||||
import zamalang
|
||||
@@ -13,7 +13,7 @@ from mlir.dialects import builtin
|
||||
from mlir.ir import Context, InsertionPoint, Location, Module
|
||||
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import Input
|
||||
from ..representation.intermediate import Input, IntermediateNode
|
||||
from .conversion_helpers import value_to_mlir_type
|
||||
from .node_converter import IntermediateNodeConverter
|
||||
|
||||
@@ -35,6 +35,15 @@ class OPGraphConverter(ABC):
|
||||
|
||||
additional_conversion_info = self._generate_additional_info_dict(op_graph)
|
||||
|
||||
# { node1: "%arg0", node2: "%0", node3: "%1" }
|
||||
nodes_to_mlir_names: Dict[IntermediateNode, str] = {}
|
||||
|
||||
# { "%arg0": "i5", "%0": "tensor<2x3x!HLFHE.eint<4>>" }
|
||||
mlir_names_to_mlir_types: Dict[str, str] = {}
|
||||
|
||||
# { "%0": ["%c1_i5"] } == for %0 we need to convert %c1_i5 to 1d tensor
|
||||
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]] = {}
|
||||
|
||||
with Context() as ctx, Location.unknown():
|
||||
zamalang.register_dialects(ctx)
|
||||
|
||||
@@ -51,12 +60,24 @@ class OPGraphConverter(ABC):
|
||||
for arg_num, node in op_graph.input_nodes.items():
|
||||
ir_to_mlir[node] = arg[arg_num]
|
||||
|
||||
mlir_name = f"%arg{arg_num}"
|
||||
nodes_to_mlir_names[node] = mlir_name
|
||||
mlir_names_to_mlir_types[mlir_name] = str(parameters[arg_num])
|
||||
|
||||
for node in nx.topological_sort(op_graph.graph):
|
||||
if isinstance(node, Input):
|
||||
continue
|
||||
|
||||
preds = [ir_to_mlir[pred] for pred in op_graph.get_ordered_preds(node)]
|
||||
node_converter = IntermediateNodeConverter(ctx, op_graph, node, preds)
|
||||
node_converter = IntermediateNodeConverter(
|
||||
ctx,
|
||||
op_graph,
|
||||
node,
|
||||
preds,
|
||||
nodes_to_mlir_names,
|
||||
mlir_names_to_mlir_types,
|
||||
scalar_to_1d_tensor_conversion_hacks,
|
||||
)
|
||||
ir_to_mlir[node] = node_converter.convert(additional_conversion_info)
|
||||
|
||||
results = (
|
||||
@@ -64,7 +85,39 @@ class OPGraphConverter(ABC):
|
||||
)
|
||||
return results
|
||||
|
||||
return str(module)
|
||||
module_lines_after_hacks_are_applied = []
|
||||
for line in str(module).split("\n"):
|
||||
mlir_name = line.split("=")[0].strip()
|
||||
if mlir_name not in scalar_to_1d_tensor_conversion_hacks:
|
||||
module_lines_after_hacks_are_applied.append(line)
|
||||
continue
|
||||
|
||||
to_be_replaced = scalar_to_1d_tensor_conversion_hacks[mlir_name]
|
||||
for arg_name in to_be_replaced:
|
||||
new_name = f"%hack_{mlir_name.replace('%', '')}_{arg_name.replace('%', '')}"
|
||||
mlir_type = mlir_names_to_mlir_types[arg_name]
|
||||
|
||||
hack_line = (
|
||||
f" {new_name} = tensor.from_elements {arg_name} : tensor<1x{mlir_type}>"
|
||||
)
|
||||
module_lines_after_hacks_are_applied.append(hack_line)
|
||||
|
||||
line = line.replace(arg_name, new_name)
|
||||
|
||||
new_arg_types = []
|
||||
|
||||
arg_types = line.split(":")[1].split("->")[0].strip()[1:-1]
|
||||
for arg in arg_types.split(", "):
|
||||
if arg.startswith("tensor"):
|
||||
new_arg_types.append(arg)
|
||||
else:
|
||||
new_arg_types.append(f"tensor<1x{arg}>")
|
||||
|
||||
line = line.replace(arg_types, ", ".join(new_arg_types))
|
||||
|
||||
module_lines_after_hacks_are_applied.append(line)
|
||||
|
||||
return "\n".join(module_lines_after_hacks_are_applied)
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
|
||||
@@ -50,8 +50,19 @@ class IntermediateNodeConverter:
|
||||
all_of_the_inputs_are_tensors: bool
|
||||
one_of_the_inputs_is_a_tensor: bool
|
||||
|
||||
nodes_to_mlir_names: Dict[IntermediateNode, str]
|
||||
mlir_names_to_mlir_types: Dict[str, str]
|
||||
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]]
|
||||
|
||||
def __init__(
|
||||
self, ctx: Context, op_graph: OPGraph, node: IntermediateNode, preds: List[OpResult]
|
||||
self,
|
||||
ctx: Context,
|
||||
op_graph: OPGraph,
|
||||
node: IntermediateNode,
|
||||
preds: List[OpResult],
|
||||
nodes_to_mlir_names: Dict[OpResult, str],
|
||||
mlir_names_to_mlir_types: Dict[str, str],
|
||||
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]],
|
||||
):
|
||||
self.ctx = ctx
|
||||
self.op_graph = op_graph
|
||||
@@ -75,6 +86,10 @@ class IntermediateNodeConverter:
|
||||
# this branch is not covered as there are only TensorValues for now
|
||||
self.all_of_the_inputs_are_tensors = False
|
||||
|
||||
self.nodes_to_mlir_names = nodes_to_mlir_names
|
||||
self.mlir_names_to_mlir_types = mlir_names_to_mlir_types
|
||||
self.scalar_to_1d_tensor_conversion_hacks = scalar_to_1d_tensor_conversion_hacks
|
||||
|
||||
def convert(self, additional_conversion_info: Dict[str, Any]) -> OpResult:
|
||||
"""Convert an intermediate node to its corresponding MLIR representation.
|
||||
|
||||
@@ -111,6 +126,20 @@ class IntermediateNodeConverter:
|
||||
# this branch is not covered as unsupported opeations fail on check mlir compatibility
|
||||
raise NotImplementedError(f"{type(self.node)} nodes cannot be converted to MLIR yet")
|
||||
|
||||
mlir_name = str(result).replace("Value(", "").split("=", maxsplit=1)[0].strip()
|
||||
|
||||
self.nodes_to_mlir_names[self.node] = mlir_name
|
||||
self.mlir_names_to_mlir_types[mlir_name] = str(result.type)
|
||||
|
||||
if isinstance(self.node, (Add, Mul, Sub, Dot)):
|
||||
if self.one_of_the_inputs_is_a_tensor and not self.all_of_the_inputs_are_tensors:
|
||||
to_be_converted = []
|
||||
for (pred, output) in self.op_graph.get_ordered_inputs_of(self.node):
|
||||
inp = pred.outputs[output]
|
||||
if isinstance(inp, TensorValue) and inp.is_scalar:
|
||||
to_be_converted.append(self.nodes_to_mlir_names[pred])
|
||||
self.scalar_to_1d_tensor_conversion_hacks[mlir_name] = to_be_converted
|
||||
|
||||
return result
|
||||
|
||||
def convert_add(self) -> OpResult:
|
||||
|
||||
@@ -736,11 +736,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,inputset,test_input,expected_output",
|
||||
[
|
||||
# TODO: find a way to support this case
|
||||
# https://github.com/zama-ai/concretefhe-internal/issues/837
|
||||
#
|
||||
# the problem is that compiler doesn't support combining scalars and tensors
|
||||
# but they do support broadcasting, so scalars should be converted to (1,) shaped tensors
|
||||
pytest.param(
|
||||
lambda x: x + 1,
|
||||
{
|
||||
@@ -759,7 +754,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
[7, 2],
|
||||
[3, 6],
|
||||
],
|
||||
marks=pytest.mark.xfail(strict=True),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x + numpy.array([[1, 0], [2, 0], [3, 1]], dtype=numpy.uint32),
|
||||
@@ -780,11 +774,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
[5, 6],
|
||||
],
|
||||
),
|
||||
# TODO: find a way to support this case
|
||||
# https://github.com/zama-ai/concretefhe-internal/issues/837
|
||||
#
|
||||
# the problem is that compiler doesn't support combining scalars and tensors
|
||||
# but they do support broadcasting, so scalars should be converted to (1,) shaped tensors
|
||||
pytest.param(
|
||||
lambda x, y: x + y,
|
||||
{
|
||||
@@ -811,7 +800,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
[8, 3],
|
||||
[4, 7],
|
||||
],
|
||||
marks=pytest.mark.xfail(strict=True),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x + y,
|
||||
@@ -844,11 +832,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
[5, 9],
|
||||
],
|
||||
),
|
||||
# TODO: find a way to support this case
|
||||
# https://github.com/zama-ai/concretefhe-internal/issues/837
|
||||
#
|
||||
# the problem is that compiler doesn't support combining scalars and tensors
|
||||
# but they do support broadcasting, so scalars should be converted to (1,) shaped tensors
|
||||
pytest.param(
|
||||
lambda x: 100 - x,
|
||||
{
|
||||
@@ -867,7 +850,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
[94, 99],
|
||||
[98, 95],
|
||||
],
|
||||
marks=pytest.mark.xfail(strict=True),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.array([[10, 15], [20, 15], [10, 30]], dtype=numpy.uint32) - x,
|
||||
@@ -888,11 +870,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
[8, 25],
|
||||
],
|
||||
),
|
||||
# TODO: find a way to support this case
|
||||
# https://github.com/zama-ai/concretefhe-internal/issues/837
|
||||
#
|
||||
# the problem is that compiler doesn't support combining scalars and tensors
|
||||
# but they do support broadcasting, so scalars should be converted to (1,) shaped tensors
|
||||
pytest.param(
|
||||
lambda x: x * 2,
|
||||
{
|
||||
@@ -911,7 +888,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
|
||||
[12, 2],
|
||||
[4, 10],
|
||||
],
|
||||
marks=pytest.mark.xfail(strict=True),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x * numpy.array([[1, 2], [2, 1], [3, 1]], dtype=numpy.uint32),
|
||||
@@ -994,7 +970,10 @@ def test_compile_and_run_tensor_correctness(
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
numpy_test_input = (numpy.array(item, dtype=numpy.uint8) for item in test_input)
|
||||
numpy_test_input = (
|
||||
item if isinstance(item, int) else numpy.array(item, dtype=numpy.uint8)
|
||||
for item in test_input
|
||||
)
|
||||
assert numpy.array_equal(
|
||||
circuit.run(*numpy_test_input),
|
||||
numpy.array(expected_output, dtype=numpy.uint8),
|
||||
|
||||
Reference in New Issue
Block a user