From 82688206f75933dec590ca67559e5c8672083aad Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Mon, 18 Oct 2021 11:02:42 +0200 Subject: [PATCH] refactor(mlir): generate tables before converting nodes to MLIR - MLIRConverter becomes an abstract base class - pass the needed informations in an extra dict to MLIR converters - NPMLIRConverter handles the specifics of numpy tables generation --- concrete/common/mlir/converters.py | 14 ++++----- concrete/common/mlir/mlir_converter.py | 34 ++++++++++++++++++---- concrete/numpy/compile.py | 10 +++---- concrete/numpy/np_mlir_converter.py | 36 ++++++++++++++++++++++++ tests/common/mlir/test_converters.py | 2 ++ tests/common/mlir/test_mlir_converter.py | 21 +++++++------- 6 files changed, 89 insertions(+), 28 deletions(-) create mode 100644 concrete/numpy/np_mlir_converter.py diff --git a/concrete/common/mlir/converters.py b/concrete/common/mlir/converters.py index 74f08cebb..915b9dc4c 100644 --- a/concrete/common/mlir/converters.py +++ b/concrete/common/mlir/converters.py @@ -26,7 +26,7 @@ from ..representation.intermediate import Add, Constant, Dot, Mul, Sub, Univaria from ..values import TensorValue -def add(node, preds, ir_to_mlir_node, ctx): +def add(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None): """Convert an addition intermediate node.""" assert_true(len(node.inputs) == 2, "addition should have two inputs") assert_true(len(node.outputs) == 1, "addition should have a single output") @@ -70,7 +70,7 @@ def _add_eint_eint(node, preds, ir_to_mlir_node, ctx): ).result -def sub(node, preds, ir_to_mlir_node, ctx): +def sub(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None): """Convert a subtraction intermediate node.""" assert_true(len(node.inputs) == 2, "subtraction should have two inputs") assert_true(len(node.outputs) == 1, "subtraction should have a single output") @@ -94,7 +94,7 @@ def _sub_int_eint(node, preds, ir_to_mlir_node, ctx): ).result -def mul(node, preds, ir_to_mlir_node, ctx): +def mul(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None): """Convert a multiplication intermediate node.""" assert_true(len(node.inputs) == 2, "multiplication should have two inputs") assert_true(len(node.outputs) == 1, "multiplication should have a single output") @@ -123,7 +123,7 @@ def _mul_eint_int(node, preds, ir_to_mlir_node, ctx): ).result -def constant(node, _, __, ctx): +def constant(node, _preds, _ir_to_mlir_node, ctx, _additional_conversion_info=None): """Convert a constant input.""" value = node.outputs[0] @@ -164,7 +164,7 @@ def constant(node, _, __, ctx): raise TypeError(f"Don't support {value} constants") -def apply_lut(node, preds, ir_to_mlir_node, ctx): +def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info): """Convert a UnivariateFunction intermediate node.""" assert_true(len(node.inputs) == 1, "LUT should have a single input") assert_true(len(node.outputs) == 1, "LUT should have a single output") @@ -181,7 +181,7 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx): x_node = preds[0] x = ir_to_mlir_node[x_node] - table = node.get_table() + table = additional_conversion_info["tables"][node] out_dtype = cast(Integer, node.outputs[0].dtype) # Create table dense_elem = DenseElementsAttr.get(np.array(table, dtype=np.uint64), context=ctx) @@ -196,7 +196,7 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx): ).result -def dot(node, preds, ir_to_mlir_node, ctx): +def dot(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None): """Convert a dot intermediate node.""" assert_true(len(node.inputs) == 2, "Dot should have two inputs") assert_true(len(node.outputs) == 1, "Dot should have a single output") diff --git a/concrete/common/mlir/mlir_converter.py b/concrete/common/mlir/mlir_converter.py index 984fc0fc3..ddb3a8fe2 100644 --- a/concrete/common/mlir/mlir_converter.py +++ b/concrete/common/mlir/mlir_converter.py @@ -1,6 +1,7 @@ """File containing code to convert a DAG containing ir nodes to the compiler opset.""" # pylint: disable=no-name-in-module,no-member -from typing import Tuple, cast +from abc import ABC, abstractmethod +from typing import Any, Dict, Tuple, cast import networkx as nx import zamalang @@ -22,7 +23,7 @@ from ..operator_graph import OPGraph from ..representation.intermediate import Input -class MLIRConverter: +class MLIRConverter(ABC): """Converter of the common IR to MLIR.""" def __init__(self, conversion_functions: dict) -> None: @@ -87,6 +88,18 @@ class MLIRConverter: # unsigned integer are considered signless in the compiler return IntegerType.get_signless(bit_width) + @staticmethod + @abstractmethod + def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]: + """Generate the additional_conversion_info dict for the MLIR converter. + + Args: + op_graph (OPGraph): the OPGraph for which we need the conversion infos. + + Returns: + Dict[str, Any]: The dict with the additional conversion infos. + """ + def common_value_to_mlir_type(self, value: values.BaseValue) -> MLIRType: """Convert a common value to its corresponding MLIR Type. @@ -125,11 +138,16 @@ class MLIRConverter: """Convert the graph of IntermediateNode to an MLIR textual representation. Args: - graph: graph of IntermediateNode to be converted + op_graph (OPGraph): graph of IntermediateNode to be converted + + Raises: + NotImplementedError: raised if an unknown node type is encountered. Returns: - textual MLIR representation + str: textual MLIR representation """ + additional_conversion_info = self._generate_additional_info_dict(op_graph) + with self.context, Location.unknown(): module = Module.create() # collect inputs @@ -162,7 +180,13 @@ class MLIRConverter: idx_to_pred[data["input_idx"]] = pred preds = [idx_to_pred[i] for i in range(len(idx_to_pred))] # convert to mlir - result = mlir_op(node, preds, ir_to_mlir_node, self.context) + result = mlir_op( + node, + preds, + ir_to_mlir_node, + self.context, + additional_conversion_info, + ) ir_to_mlir_node[node] = result results = ( diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py index 5604c00f8..5e754ddfd 100644 --- a/concrete/numpy/compile.py +++ b/concrete/numpy/compile.py @@ -13,7 +13,7 @@ from ..common.compilation import CompilationArtifacts, CompilationConfiguration from ..common.data_types import Integer from ..common.debugging import get_printable_graph from ..common.fhe_circuit import FHECircuit -from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter +from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS from ..common.mlir.utils import ( check_graph_values_compatibility_with_mlir, extend_direct_lookup_tables, @@ -29,6 +29,7 @@ from .np_dtypes_helpers import ( get_base_value_for_numpy_or_python_constant_data, get_constructor_for_numpy_or_python_constant_data, ) +from .np_mlir_converter import NPMLIRConverter def numpy_max_func(lhs: Any, rhs: Any) -> Any: @@ -281,11 +282,8 @@ def _compile_numpy_function_internal( ) # Convert graph to an MLIR representation - converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) - - # Disable numpy warnings during conversion to avoid issues during TLU generation - with numpy.errstate(all="ignore"): - mlir_result = converter.convert(op_graph) + converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) + mlir_result = converter.convert(op_graph) # Show MLIR representation if requested if show_mlir: diff --git a/concrete/numpy/np_mlir_converter.py b/concrete/numpy/np_mlir_converter.py new file mode 100644 index 000000000..320a76a4c --- /dev/null +++ b/concrete/numpy/np_mlir_converter.py @@ -0,0 +1,36 @@ +"""Numpy-specific MLIR converter.""" + +from typing import Any, Dict + +import numpy + +from ..common.mlir.mlir_converter import MLIRConverter +from ..common.operator_graph import OPGraph +from ..common.representation.intermediate import UnivariateFunction + + +class NPMLIRConverter(MLIRConverter): + """Numpy-specific MLIR converter.""" + + @staticmethod + def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]: + """Generate the additional_conversion_info dict for the MLIR converter. + + Args: + op_graph (OPGraph): the OPGraph for which we need the conversion infos. + + Returns: + Dict[str, Any]: The dict with the additional conversion infos. + """ + + additional_conversion_info = {} + + # Disable numpy warnings during conversion to avoid issues during TLU generation + with numpy.errstate(all="ignore"): + additional_conversion_info["tables"] = { + node: node.get_table() + for node in op_graph.graph.nodes() + if isinstance(node, UnivariateFunction) + } + + return additional_conversion_info diff --git a/tests/common/mlir/test_converters.py b/tests/common/mlir/test_converters.py index 89cfdfd26..d9f590669 100644 --- a/tests/common/mlir/test_converters.py +++ b/tests/common/mlir/test_converters.py @@ -63,6 +63,7 @@ def test_fail_tlu_input(input_node): [None], None, None, + None, ) @@ -84,4 +85,5 @@ def test_fail_tlu_output(input_node): [None], None, None, + None, ) diff --git a/tests/common/mlir/test_mlir_converter.py b/tests/common/mlir/test_mlir_converter.py index cfb51c91e..3746febee 100644 --- a/tests/common/mlir/test_mlir_converter.py +++ b/tests/common/mlir/test_mlir_converter.py @@ -10,10 +10,11 @@ from zamalang.dialects import hlfhe from concrete.common.data_types.integers import Integer from concrete.common.extensions.table import LookupTable -from concrete.common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter +from concrete.common.mlir import V0_OPSET_CONVERSION_FUNCTIONS from concrete.common.values import ClearScalar, EncryptedScalar from concrete.common.values.tensors import ClearTensor, EncryptedTensor from concrete.numpy.compile import compile_numpy_function_into_op_graph +from concrete.numpy.np_mlir_converter import NPMLIRConverter def add(x, y): @@ -219,7 +220,7 @@ def test_mlir_converter(func, args_dict, args_ranges): """Test the conversion to MLIR by calling the parser from the compiler""" inputset = datagen(*args_ranges) result_graph = compile_numpy_function_into_op_graph(func, args_dict, inputset) - converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) + converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) mlir_result = converter.convert(result_graph) # testing that this doesn't raise an error compiler.round_trip(mlir_result) @@ -261,7 +262,7 @@ def test_mlir_converter_dot_between_vectors(func, args_dict, args_ranges): for data in datagen(*args_ranges) ), ) - converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) + converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) mlir_result = converter.convert(result_graph) # testing that this doesn't raise an error compiler.round_trip(mlir_result) @@ -281,7 +282,7 @@ def test_mlir_converter_dot_vector_and_constant(): {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2,))}, [(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)], ) - left_converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) + left_converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) left_mlir = left_converter.convert(left_graph) right_graph = compile_numpy_function_into_op_graph( @@ -289,7 +290,7 @@ def test_mlir_converter_dot_vector_and_constant(): {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2,))}, [(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)], ) - right_converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) + right_converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) right_mlir = right_converter.convert(right_graph) # testing that this doesn't raise an error @@ -300,7 +301,7 @@ def test_mlir_converter_dot_vector_and_constant(): def test_concrete_encrypted_integer_to_mlir_type(): """Test conversion of EncryptedScalar into MLIR""" value = EncryptedScalar(Integer(7, is_signed=False)) - converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) + converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) eint = converter.common_value_to_mlir_type(value) assert eint == hlfhe.EncryptedIntegerType.get(converter.context, 7) @@ -309,7 +310,7 @@ def test_concrete_encrypted_integer_to_mlir_type(): def test_concrete_clear_integer_to_mlir_type(is_signed): """Test conversion of ClearScalar into MLIR""" value = ClearScalar(Integer(5, is_signed=is_signed)) - converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) + converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) with converter.context: int_mlir = converter.common_value_to_mlir_type(value) if is_signed: @@ -330,7 +331,7 @@ def test_concrete_clear_integer_to_mlir_type(is_signed): def test_concrete_clear_tensor_integer_to_mlir_type(is_signed, shape): """Test conversion of ClearTensor into MLIR""" value = ClearTensor(Integer(5, is_signed=is_signed), shape) - converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) + converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) with converter.context, Location.unknown(): tensor_mlir = converter.common_value_to_mlir_type(value) if is_signed: @@ -355,7 +356,7 @@ def test_concrete_clear_tensor_integer_to_mlir_type(is_signed, shape): def test_concrete_encrypted_tensor_integer_to_mlir_type(shape): """Test conversion of EncryptedTensor into MLIR""" value = EncryptedTensor(Integer(6, is_signed=False), shape) - converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) + converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) with converter.context, Location.unknown(): tensor_mlir = converter.common_value_to_mlir_type(value) element_type = hlfhe.EncryptedIntegerType.get(converter.context, 6) @@ -369,7 +370,7 @@ def test_concrete_encrypted_tensor_integer_to_mlir_type(shape): def test_failing_concrete_to_mlir_type(): """Test failing conversion of an unsupported type into MLIR""" value = "random" - converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) + converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) with pytest.raises(TypeError, match=r"can't convert value of type .* to MLIR type"): converter.common_value_to_mlir_type(value)