refactor(mlir): generate tables before converting nodes to MLIR

- MLIRConverter becomes an abstract base class
- pass the needed informations in an extra dict to MLIR converters
- NPMLIRConverter handles the specifics of numpy tables generation
This commit is contained in:
Arthur Meyre
2021-10-18 11:02:42 +02:00
parent a8aafcb70a
commit 82688206f7
6 changed files with 89 additions and 28 deletions

View File

@@ -26,7 +26,7 @@ from ..representation.intermediate import Add, Constant, Dot, Mul, Sub, Univaria
from ..values import TensorValue
def add(node, preds, ir_to_mlir_node, ctx):
def add(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
"""Convert an addition intermediate node."""
assert_true(len(node.inputs) == 2, "addition should have two inputs")
assert_true(len(node.outputs) == 1, "addition should have a single output")
@@ -70,7 +70,7 @@ def _add_eint_eint(node, preds, ir_to_mlir_node, ctx):
).result
def sub(node, preds, ir_to_mlir_node, ctx):
def sub(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
"""Convert a subtraction intermediate node."""
assert_true(len(node.inputs) == 2, "subtraction should have two inputs")
assert_true(len(node.outputs) == 1, "subtraction should have a single output")
@@ -94,7 +94,7 @@ def _sub_int_eint(node, preds, ir_to_mlir_node, ctx):
).result
def mul(node, preds, ir_to_mlir_node, ctx):
def mul(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
"""Convert a multiplication intermediate node."""
assert_true(len(node.inputs) == 2, "multiplication should have two inputs")
assert_true(len(node.outputs) == 1, "multiplication should have a single output")
@@ -123,7 +123,7 @@ def _mul_eint_int(node, preds, ir_to_mlir_node, ctx):
).result
def constant(node, _, __, ctx):
def constant(node, _preds, _ir_to_mlir_node, ctx, _additional_conversion_info=None):
"""Convert a constant input."""
value = node.outputs[0]
@@ -164,7 +164,7 @@ def constant(node, _, __, ctx):
raise TypeError(f"Don't support {value} constants")
def apply_lut(node, preds, ir_to_mlir_node, ctx):
def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info):
"""Convert a UnivariateFunction intermediate node."""
assert_true(len(node.inputs) == 1, "LUT should have a single input")
assert_true(len(node.outputs) == 1, "LUT should have a single output")
@@ -181,7 +181,7 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx):
x_node = preds[0]
x = ir_to_mlir_node[x_node]
table = node.get_table()
table = additional_conversion_info["tables"][node]
out_dtype = cast(Integer, node.outputs[0].dtype)
# Create table
dense_elem = DenseElementsAttr.get(np.array(table, dtype=np.uint64), context=ctx)
@@ -196,7 +196,7 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx):
).result
def dot(node, preds, ir_to_mlir_node, ctx):
def dot(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
"""Convert a dot intermediate node."""
assert_true(len(node.inputs) == 2, "Dot should have two inputs")
assert_true(len(node.outputs) == 1, "Dot should have a single output")

View File

@@ -1,6 +1,7 @@
"""File containing code to convert a DAG containing ir nodes to the compiler opset."""
# pylint: disable=no-name-in-module,no-member
from typing import Tuple, cast
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple, cast
import networkx as nx
import zamalang
@@ -22,7 +23,7 @@ from ..operator_graph import OPGraph
from ..representation.intermediate import Input
class MLIRConverter:
class MLIRConverter(ABC):
"""Converter of the common IR to MLIR."""
def __init__(self, conversion_functions: dict) -> None:
@@ -87,6 +88,18 @@ class MLIRConverter:
# unsigned integer are considered signless in the compiler
return IntegerType.get_signless(bit_width)
@staticmethod
@abstractmethod
def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]:
"""Generate the additional_conversion_info dict for the MLIR converter.
Args:
op_graph (OPGraph): the OPGraph for which we need the conversion infos.
Returns:
Dict[str, Any]: The dict with the additional conversion infos.
"""
def common_value_to_mlir_type(self, value: values.BaseValue) -> MLIRType:
"""Convert a common value to its corresponding MLIR Type.
@@ -125,11 +138,16 @@ class MLIRConverter:
"""Convert the graph of IntermediateNode to an MLIR textual representation.
Args:
graph: graph of IntermediateNode to be converted
op_graph (OPGraph): graph of IntermediateNode to be converted
Raises:
NotImplementedError: raised if an unknown node type is encountered.
Returns:
textual MLIR representation
str: textual MLIR representation
"""
additional_conversion_info = self._generate_additional_info_dict(op_graph)
with self.context, Location.unknown():
module = Module.create()
# collect inputs
@@ -162,7 +180,13 @@ class MLIRConverter:
idx_to_pred[data["input_idx"]] = pred
preds = [idx_to_pred[i] for i in range(len(idx_to_pred))]
# convert to mlir
result = mlir_op(node, preds, ir_to_mlir_node, self.context)
result = mlir_op(
node,
preds,
ir_to_mlir_node,
self.context,
additional_conversion_info,
)
ir_to_mlir_node[node] = result
results = (

View File

@@ -13,7 +13,7 @@ from ..common.compilation import CompilationArtifacts, CompilationConfiguration
from ..common.data_types import Integer
from ..common.debugging import get_printable_graph
from ..common.fhe_circuit import FHECircuit
from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS
from ..common.mlir.utils import (
check_graph_values_compatibility_with_mlir,
extend_direct_lookup_tables,
@@ -29,6 +29,7 @@ from .np_dtypes_helpers import (
get_base_value_for_numpy_or_python_constant_data,
get_constructor_for_numpy_or_python_constant_data,
)
from .np_mlir_converter import NPMLIRConverter
def numpy_max_func(lhs: Any, rhs: Any) -> Any:
@@ -281,11 +282,8 @@ def _compile_numpy_function_internal(
)
# Convert graph to an MLIR representation
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
# Disable numpy warnings during conversion to avoid issues during TLU generation
with numpy.errstate(all="ignore"):
mlir_result = converter.convert(op_graph)
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
mlir_result = converter.convert(op_graph)
# Show MLIR representation if requested
if show_mlir:

View File

@@ -0,0 +1,36 @@
"""Numpy-specific MLIR converter."""
from typing import Any, Dict
import numpy
from ..common.mlir.mlir_converter import MLIRConverter
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import UnivariateFunction
class NPMLIRConverter(MLIRConverter):
"""Numpy-specific MLIR converter."""
@staticmethod
def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]:
"""Generate the additional_conversion_info dict for the MLIR converter.
Args:
op_graph (OPGraph): the OPGraph for which we need the conversion infos.
Returns:
Dict[str, Any]: The dict with the additional conversion infos.
"""
additional_conversion_info = {}
# Disable numpy warnings during conversion to avoid issues during TLU generation
with numpy.errstate(all="ignore"):
additional_conversion_info["tables"] = {
node: node.get_table()
for node in op_graph.graph.nodes()
if isinstance(node, UnivariateFunction)
}
return additional_conversion_info

View File

@@ -63,6 +63,7 @@ def test_fail_tlu_input(input_node):
[None],
None,
None,
None,
)
@@ -84,4 +85,5 @@ def test_fail_tlu_output(input_node):
[None],
None,
None,
None,
)

View File

@@ -10,10 +10,11 @@ from zamalang.dialects import hlfhe
from concrete.common.data_types.integers import Integer
from concrete.common.extensions.table import LookupTable
from concrete.common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
from concrete.common.mlir import V0_OPSET_CONVERSION_FUNCTIONS
from concrete.common.values import ClearScalar, EncryptedScalar
from concrete.common.values.tensors import ClearTensor, EncryptedTensor
from concrete.numpy.compile import compile_numpy_function_into_op_graph
from concrete.numpy.np_mlir_converter import NPMLIRConverter
def add(x, y):
@@ -219,7 +220,7 @@ def test_mlir_converter(func, args_dict, args_ranges):
"""Test the conversion to MLIR by calling the parser from the compiler"""
inputset = datagen(*args_ranges)
result_graph = compile_numpy_function_into_op_graph(func, args_dict, inputset)
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
mlir_result = converter.convert(result_graph)
# testing that this doesn't raise an error
compiler.round_trip(mlir_result)
@@ -261,7 +262,7 @@ def test_mlir_converter_dot_between_vectors(func, args_dict, args_ranges):
for data in datagen(*args_ranges)
),
)
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
mlir_result = converter.convert(result_graph)
# testing that this doesn't raise an error
compiler.round_trip(mlir_result)
@@ -281,7 +282,7 @@ def test_mlir_converter_dot_vector_and_constant():
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2,))},
[(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)],
)
left_converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
left_converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
left_mlir = left_converter.convert(left_graph)
right_graph = compile_numpy_function_into_op_graph(
@@ -289,7 +290,7 @@ def test_mlir_converter_dot_vector_and_constant():
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2,))},
[(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)],
)
right_converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
right_converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
right_mlir = right_converter.convert(right_graph)
# testing that this doesn't raise an error
@@ -300,7 +301,7 @@ def test_mlir_converter_dot_vector_and_constant():
def test_concrete_encrypted_integer_to_mlir_type():
"""Test conversion of EncryptedScalar into MLIR"""
value = EncryptedScalar(Integer(7, is_signed=False))
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
eint = converter.common_value_to_mlir_type(value)
assert eint == hlfhe.EncryptedIntegerType.get(converter.context, 7)
@@ -309,7 +310,7 @@ def test_concrete_encrypted_integer_to_mlir_type():
def test_concrete_clear_integer_to_mlir_type(is_signed):
"""Test conversion of ClearScalar into MLIR"""
value = ClearScalar(Integer(5, is_signed=is_signed))
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
with converter.context:
int_mlir = converter.common_value_to_mlir_type(value)
if is_signed:
@@ -330,7 +331,7 @@ def test_concrete_clear_integer_to_mlir_type(is_signed):
def test_concrete_clear_tensor_integer_to_mlir_type(is_signed, shape):
"""Test conversion of ClearTensor into MLIR"""
value = ClearTensor(Integer(5, is_signed=is_signed), shape)
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
with converter.context, Location.unknown():
tensor_mlir = converter.common_value_to_mlir_type(value)
if is_signed:
@@ -355,7 +356,7 @@ def test_concrete_clear_tensor_integer_to_mlir_type(is_signed, shape):
def test_concrete_encrypted_tensor_integer_to_mlir_type(shape):
"""Test conversion of EncryptedTensor into MLIR"""
value = EncryptedTensor(Integer(6, is_signed=False), shape)
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
with converter.context, Location.unknown():
tensor_mlir = converter.common_value_to_mlir_type(value)
element_type = hlfhe.EncryptedIntegerType.get(converter.context, 6)
@@ -369,7 +370,7 @@ def test_concrete_encrypted_tensor_integer_to_mlir_type(shape):
def test_failing_concrete_to_mlir_type():
"""Test failing conversion of an unsupported type into MLIR"""
value = "random"
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
with pytest.raises(TypeError, match=r"can't convert value of type .* to MLIR type"):
converter.common_value_to_mlir_type(value)