refactor(mlir): implement support for operations between tensors and scalars using string processing hacks

Umut
2021-11-22 13:08:44 +03:00
parent a85e4e591a
commit 1d77816aa3
3 changed files with 91 additions and 30 deletions

View File

@@ -5,7 +5,7 @@
# pylint: disable=no-name-in-module
from abc import ABC, abstractmethod
from typing import Any, Dict
from typing import Any, Dict, List
import networkx as nx
import zamalang
@@ -13,7 +13,7 @@ from mlir.dialects import builtin
from mlir.ir import Context, InsertionPoint, Location, Module
from ..operator_graph import OPGraph
from ..representation.intermediate import Input
from ..representation.intermediate import Input, IntermediateNode
from .conversion_helpers import value_to_mlir_type
from .node_converter import IntermediateNodeConverter
@@ -35,6 +35,15 @@ class OPGraphConverter(ABC):
additional_conversion_info = self._generate_additional_info_dict(op_graph)
# { node1: "%arg0", node2: "%0", node3: "%1" }
nodes_to_mlir_names: Dict[IntermediateNode, str] = {}
# { "%arg0": "i5", "%0": "tensor<2x3x!HLFHE.eint<4>>" }
mlir_names_to_mlir_types: Dict[str, str] = {}
# { "%0": ["%c1_i5"] } == for %0 we need to convert %c1_i5 to 1d tensor
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]] = {}
with Context() as ctx, Location.unknown():
zamalang.register_dialects(ctx)
@@ -51,12 +60,24 @@ class OPGraphConverter(ABC):
for arg_num, node in op_graph.input_nodes.items():
ir_to_mlir[node] = arg[arg_num]
mlir_name = f"%arg{arg_num}"
nodes_to_mlir_names[node] = mlir_name
mlir_names_to_mlir_types[mlir_name] = str(parameters[arg_num])
for node in nx.topological_sort(op_graph.graph):
if isinstance(node, Input):
continue
preds = [ir_to_mlir[pred] for pred in op_graph.get_ordered_preds(node)]
node_converter = IntermediateNodeConverter(ctx, op_graph, node, preds)
node_converter = IntermediateNodeConverter(
ctx,
op_graph,
node,
preds,
nodes_to_mlir_names,
mlir_names_to_mlir_types,
scalar_to_1d_tensor_conversion_hacks,
)
ir_to_mlir[node] = node_converter.convert(additional_conversion_info)
results = (
@@ -64,7 +85,39 @@ class OPGraphConverter(ABC):
)
return results
return str(module)
module_lines_after_hacks_are_applied = []
for line in str(module).split("\n"):
mlir_name = line.split("=")[0].strip()
if mlir_name not in scalar_to_1d_tensor_conversion_hacks:
module_lines_after_hacks_are_applied.append(line)
continue
to_be_replaced = scalar_to_1d_tensor_conversion_hacks[mlir_name]
for arg_name in to_be_replaced:
new_name = f"%hack_{mlir_name.replace('%', '')}_{arg_name.replace('%', '')}"
mlir_type = mlir_names_to_mlir_types[arg_name]
hack_line = (
f" {new_name} = tensor.from_elements {arg_name} : tensor<1x{mlir_type}>"
)
module_lines_after_hacks_are_applied.append(hack_line)
line = line.replace(arg_name, new_name)
new_arg_types = []
arg_types = line.split(":")[1].split("->")[0].strip()[1:-1]
for arg in arg_types.split(", "):
if arg.startswith("tensor"):
new_arg_types.append(arg)
else:
new_arg_types.append(f"tensor<1x{arg}>")
line = line.replace(arg_types, ", ".join(new_arg_types))
module_lines_after_hacks_are_applied.append(line)
return "\n".join(module_lines_after_hacks_are_applied)
@staticmethod
@abstractmethod
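
To make the string-processing hack above concrete, here is a minimal, self-contained sketch of what the post-processing does to a single line of the printed module. The operation name, SSA names, and types below are hypothetical; only the rewriting scheme mirrors the code in this commit. A scalar operand of a tensor operation is materialized as a 1d tensor via tensor.from_elements, and the corresponding operand type in the signature is rewritten to match.

# a minimal sketch of the scalar-to-1d-tensor rewrite; the op name "HLFHE.add_eint_int",
# the SSA names, and the types are assumptions for illustration, not taken from this commit
mlir_names_to_mlir_types = {"%c1_i5": "i5"}
scalar_to_1d_tensor_conversion_hacks = {"%1": ["%c1_i5"]}

line = (
    '  %1 = "HLFHE.add_eint_int"(%0, %c1_i5)'
    " : (tensor<2x3x!HLFHE.eint<4>>, i5) -> tensor<2x3x!HLFHE.eint<4>>"
)

mlir_name = line.split("=")[0].strip()  # "%1"
for arg_name in scalar_to_1d_tensor_conversion_hacks[mlir_name]:
    new_name = f"%hack_{mlir_name.replace('%', '')}_{arg_name.replace('%', '')}"
    mlir_type = mlir_names_to_mlir_types[arg_name]
    # emit a tensor.from_elements just before the op and use the 1d tensor instead
    print(f"  {new_name} = tensor.from_elements {arg_name} : tensor<1x{mlir_type}>")
    line = line.replace(arg_name, new_name)

# rewrite the scalar operand types in the signature into 1d tensor types as well
arg_types = line.split(":")[1].split("->")[0].strip()[1:-1]
new_arg_types = [
    arg if arg.startswith("tensor") else f"tensor<1x{arg}>"
    for arg in arg_types.split(", ")
]
print(line.replace(arg_types, ", ".join(new_arg_types)))

Running this prints a tensor.from_elements line followed by the op line with %c1_i5 replaced by a tensor<1xi5> value, which is what the production code stitches back into the module text.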

View File

@@ -50,8 +50,19 @@ class IntermediateNodeConverter:
all_of_the_inputs_are_tensors: bool
one_of_the_inputs_is_a_tensor: bool
nodes_to_mlir_names: Dict[IntermediateNode, str]
mlir_names_to_mlir_types: Dict[str, str]
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]]
def __init__(
self, ctx: Context, op_graph: OPGraph, node: IntermediateNode, preds: List[OpResult]
self,
ctx: Context,
op_graph: OPGraph,
node: IntermediateNode,
preds: List[OpResult],
nodes_to_mlir_names: Dict[IntermediateNode, str],
mlir_names_to_mlir_types: Dict[str, str],
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]],
):
self.ctx = ctx
self.op_graph = op_graph
@@ -75,6 +86,10 @@ class IntermediateNodeConverter:
# this branch is not covered as there are only TensorValues for now
self.all_of_the_inputs_are_tensors = False
self.nodes_to_mlir_names = nodes_to_mlir_names
self.mlir_names_to_mlir_types = mlir_names_to_mlir_types
self.scalar_to_1d_tensor_conversion_hacks = scalar_to_1d_tensor_conversion_hacks
def convert(self, additional_conversion_info: Dict[str, Any]) -> OpResult:
"""Convert an intermediate node to its corresponding MLIR representation.
@@ -111,6 +126,20 @@ class IntermediateNodeConverter:
# this branch is not covered as unsupported operations fail the MLIR compatibility check
raise NotImplementedError(f"{type(self.node)} nodes cannot be converted to MLIR yet")
mlir_name = str(result).replace("Value(", "").split("=", maxsplit=1)[0].strip()
self.nodes_to_mlir_names[self.node] = mlir_name
self.mlir_names_to_mlir_types[mlir_name] = str(result.type)
if isinstance(self.node, (Add, Mul, Sub, Dot)):
if self.one_of_the_inputs_is_a_tensor and not self.all_of_the_inputs_are_tensors:
to_be_converted = []
for (pred, output) in self.op_graph.get_ordered_inputs_of(self.node):
inp = pred.outputs[output]
if isinstance(inp, TensorValue) and inp.is_scalar:
to_be_converted.append(self.nodes_to_mlir_names[pred])
self.scalar_to_1d_tensor_conversion_hacks[mlir_name] = to_be_converted
return result
def convert_add(self) -> OpResult:
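
Note that the mlir_name bookkeeping in convert() above also relies on string processing. A hedged illustration of that extraction, assuming mlir.ir prints a Value roughly as `Value(%0 = ...)` (the exact repr is an assumption, not taken from this commit):

# hypothetical repr of an mlir.ir Value; only the parsing logic mirrors the commit
result_repr = 'Value(%0 = "HLFHE.add_eint_int"(%arg0, %c1_i5) : (tensor<2x3x!HLFHE.eint<4>>, i5) -> tensor<2x3x!HLFHE.eint<4>>)'
mlir_name = result_repr.replace("Value(", "").split("=", maxsplit=1)[0].strip()
assert mlir_name == "%0"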

View File

@@ -736,11 +736,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
@pytest.mark.parametrize(
"function,parameters,inputset,test_input,expected_output",
[
# TODO: find a way to support this case
# https://github.com/zama-ai/concretefhe-internal/issues/837
#
# the problem is that the compiler doesn't support combining scalars and tensors,
# but it does support broadcasting, so scalars should be converted to (1,) shaped tensors
pytest.param(
lambda x: x + 1,
{
@@ -759,7 +754,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
[7, 2],
[3, 6],
],
marks=pytest.mark.xfail(strict=True),
),
pytest.param(
lambda x: x + numpy.array([[1, 0], [2, 0], [3, 1]], dtype=numpy.uint32),
@@ -780,11 +774,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
[5, 6],
],
),
# TODO: find a way to support this case
# https://github.com/zama-ai/concretefhe-internal/issues/837
#
# the problem is that the compiler doesn't support combining scalars and tensors,
# but it does support broadcasting, so scalars should be converted to (1,) shaped tensors
pytest.param(
lambda x, y: x + y,
{
@@ -811,7 +800,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
[8, 3],
[4, 7],
],
marks=pytest.mark.xfail(strict=True),
),
pytest.param(
lambda x, y: x + y,
@@ -844,11 +832,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
[5, 9],
],
),
# TODO: find a way to support this case
# https://github.com/zama-ai/concretefhe-internal/issues/837
#
# the problem is that the compiler doesn't support combining scalars and tensors,
# but it does support broadcasting, so scalars should be converted to (1,) shaped tensors
pytest.param(
lambda x: 100 - x,
{
@@ -867,7 +850,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
[94, 99],
[98, 95],
],
marks=pytest.mark.xfail(strict=True),
),
pytest.param(
lambda x: numpy.array([[10, 15], [20, 15], [10, 30]], dtype=numpy.uint32) - x,
@@ -888,11 +870,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
[8, 25],
],
),
# TODO: find a way to support this case
# https://github.com/zama-ai/concretefhe-internal/issues/837
#
# the problem is that the compiler doesn't support combining scalars and tensors,
# but it does support broadcasting, so scalars should be converted to (1,) shaped tensors
pytest.param(
lambda x: x * 2,
{
@@ -911,7 +888,6 @@ def test_compile_and_run_correctness__for_prog_with_tlu(
[12, 2],
[4, 10],
],
marks=pytest.mark.xfail(strict=True),
),
pytest.param(
lambda x: x * numpy.array([[1, 2], [2, 1], [3, 1]], dtype=numpy.uint32),
@@ -994,7 +970,10 @@ def test_compile_and_run_tensor_correctness(
default_compilation_configuration,
)
numpy_test_input = (numpy.array(item, dtype=numpy.uint8) for item in test_input)
numpy_test_input = (
item if isinstance(item, int) else numpy.array(item, dtype=numpy.uint8)
for item in test_input
)
assert numpy.array_equal(
circuit.run(*numpy_test_input),
numpy.array(expected_output, dtype=numpy.uint8),
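
For clarity, here is a tiny, self-contained illustration of the updated test-input preparation; the inputs below are made up, not taken from the parametrization. Lists are still converted to uint8 arrays, while plain Python ints are passed through unchanged, presumably so that scalar arguments reach the circuit as scalars now that scalar/tensor operations are supported.

import numpy

test_input = (3, [[1, 2], [3, 4]])  # hypothetical mixed scalar/tensor test input
numpy_test_input = tuple(
    item if isinstance(item, int) else numpy.array(item, dtype=numpy.uint8)
    for item in test_input
)
assert numpy_test_input[0] == 3                  # scalar stays a Python int
assert numpy_test_input[1].dtype == numpy.uint8  # list becomes a uint8 array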