mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor(mlir): generate tables before converting nodes to MLIR
- MLIRConverter becomes an abstract base class - pass the needed informations in an extra dict to MLIR converters - NPMLIRConverter handles the specifics of numpy tables generation
This commit is contained in:
@@ -26,7 +26,7 @@ from ..representation.intermediate import Add, Constant, Dot, Mul, Sub, Univaria
|
||||
from ..values import TensorValue
|
||||
|
||||
|
||||
def add(node, preds, ir_to_mlir_node, ctx):
|
||||
def add(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
|
||||
"""Convert an addition intermediate node."""
|
||||
assert_true(len(node.inputs) == 2, "addition should have two inputs")
|
||||
assert_true(len(node.outputs) == 1, "addition should have a single output")
|
||||
@@ -70,7 +70,7 @@ def _add_eint_eint(node, preds, ir_to_mlir_node, ctx):
|
||||
).result
|
||||
|
||||
|
||||
def sub(node, preds, ir_to_mlir_node, ctx):
|
||||
def sub(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
|
||||
"""Convert a subtraction intermediate node."""
|
||||
assert_true(len(node.inputs) == 2, "subtraction should have two inputs")
|
||||
assert_true(len(node.outputs) == 1, "subtraction should have a single output")
|
||||
@@ -94,7 +94,7 @@ def _sub_int_eint(node, preds, ir_to_mlir_node, ctx):
|
||||
).result
|
||||
|
||||
|
||||
def mul(node, preds, ir_to_mlir_node, ctx):
|
||||
def mul(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
|
||||
"""Convert a multiplication intermediate node."""
|
||||
assert_true(len(node.inputs) == 2, "multiplication should have two inputs")
|
||||
assert_true(len(node.outputs) == 1, "multiplication should have a single output")
|
||||
@@ -123,7 +123,7 @@ def _mul_eint_int(node, preds, ir_to_mlir_node, ctx):
|
||||
).result
|
||||
|
||||
|
||||
def constant(node, _, __, ctx):
|
||||
def constant(node, _preds, _ir_to_mlir_node, ctx, _additional_conversion_info=None):
|
||||
"""Convert a constant input."""
|
||||
value = node.outputs[0]
|
||||
|
||||
@@ -164,7 +164,7 @@ def constant(node, _, __, ctx):
|
||||
raise TypeError(f"Don't support {value} constants")
|
||||
|
||||
|
||||
def apply_lut(node, preds, ir_to_mlir_node, ctx):
|
||||
def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info):
|
||||
"""Convert a UnivariateFunction intermediate node."""
|
||||
assert_true(len(node.inputs) == 1, "LUT should have a single input")
|
||||
assert_true(len(node.outputs) == 1, "LUT should have a single output")
|
||||
@@ -181,7 +181,7 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx):
|
||||
|
||||
x_node = preds[0]
|
||||
x = ir_to_mlir_node[x_node]
|
||||
table = node.get_table()
|
||||
table = additional_conversion_info["tables"][node]
|
||||
out_dtype = cast(Integer, node.outputs[0].dtype)
|
||||
# Create table
|
||||
dense_elem = DenseElementsAttr.get(np.array(table, dtype=np.uint64), context=ctx)
|
||||
@@ -196,7 +196,7 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx):
|
||||
).result
|
||||
|
||||
|
||||
def dot(node, preds, ir_to_mlir_node, ctx):
|
||||
def dot(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
|
||||
"""Convert a dot intermediate node."""
|
||||
assert_true(len(node.inputs) == 2, "Dot should have two inputs")
|
||||
assert_true(len(node.outputs) == 1, "Dot should have a single output")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""File containing code to convert a DAG containing ir nodes to the compiler opset."""
|
||||
# pylint: disable=no-name-in-module,no-member
|
||||
from typing import Tuple, cast
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Tuple, cast
|
||||
|
||||
import networkx as nx
|
||||
import zamalang
|
||||
@@ -22,7 +23,7 @@ from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import Input
|
||||
|
||||
|
||||
class MLIRConverter:
|
||||
class MLIRConverter(ABC):
|
||||
"""Converter of the common IR to MLIR."""
|
||||
|
||||
def __init__(self, conversion_functions: dict) -> None:
|
||||
@@ -87,6 +88,18 @@ class MLIRConverter:
|
||||
# unsigned integer are considered signless in the compiler
|
||||
return IntegerType.get_signless(bit_width)
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]:
|
||||
"""Generate the additional_conversion_info dict for the MLIR converter.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): the OPGraph for which we need the conversion infos.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The dict with the additional conversion infos.
|
||||
"""
|
||||
|
||||
def common_value_to_mlir_type(self, value: values.BaseValue) -> MLIRType:
|
||||
"""Convert a common value to its corresponding MLIR Type.
|
||||
|
||||
@@ -125,11 +138,16 @@ class MLIRConverter:
|
||||
"""Convert the graph of IntermediateNode to an MLIR textual representation.
|
||||
|
||||
Args:
|
||||
graph: graph of IntermediateNode to be converted
|
||||
op_graph (OPGraph): graph of IntermediateNode to be converted
|
||||
|
||||
Raises:
|
||||
NotImplementedError: raised if an unknown node type is encountered.
|
||||
|
||||
Returns:
|
||||
textual MLIR representation
|
||||
str: textual MLIR representation
|
||||
"""
|
||||
additional_conversion_info = self._generate_additional_info_dict(op_graph)
|
||||
|
||||
with self.context, Location.unknown():
|
||||
module = Module.create()
|
||||
# collect inputs
|
||||
@@ -162,7 +180,13 @@ class MLIRConverter:
|
||||
idx_to_pred[data["input_idx"]] = pred
|
||||
preds = [idx_to_pred[i] for i in range(len(idx_to_pred))]
|
||||
# convert to mlir
|
||||
result = mlir_op(node, preds, ir_to_mlir_node, self.context)
|
||||
result = mlir_op(
|
||||
node,
|
||||
preds,
|
||||
ir_to_mlir_node,
|
||||
self.context,
|
||||
additional_conversion_info,
|
||||
)
|
||||
ir_to_mlir_node[node] = result
|
||||
|
||||
results = (
|
||||
|
||||
@@ -13,7 +13,7 @@ from ..common.compilation import CompilationArtifacts, CompilationConfiguration
|
||||
from ..common.data_types import Integer
|
||||
from ..common.debugging import get_printable_graph
|
||||
from ..common.fhe_circuit import FHECircuit
|
||||
from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
|
||||
from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS
|
||||
from ..common.mlir.utils import (
|
||||
check_graph_values_compatibility_with_mlir,
|
||||
extend_direct_lookup_tables,
|
||||
@@ -29,6 +29,7 @@ from .np_dtypes_helpers import (
|
||||
get_base_value_for_numpy_or_python_constant_data,
|
||||
get_constructor_for_numpy_or_python_constant_data,
|
||||
)
|
||||
from .np_mlir_converter import NPMLIRConverter
|
||||
|
||||
|
||||
def numpy_max_func(lhs: Any, rhs: Any) -> Any:
|
||||
@@ -281,11 +282,8 @@ def _compile_numpy_function_internal(
|
||||
)
|
||||
|
||||
# Convert graph to an MLIR representation
|
||||
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
|
||||
# Disable numpy warnings during conversion to avoid issues during TLU generation
|
||||
with numpy.errstate(all="ignore"):
|
||||
mlir_result = converter.convert(op_graph)
|
||||
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
mlir_result = converter.convert(op_graph)
|
||||
|
||||
# Show MLIR representation if requested
|
||||
if show_mlir:
|
||||
|
||||
36
concrete/numpy/np_mlir_converter.py
Normal file
36
concrete/numpy/np_mlir_converter.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Numpy-specific MLIR converter."""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy
|
||||
|
||||
from ..common.mlir.mlir_converter import MLIRConverter
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.representation.intermediate import UnivariateFunction
|
||||
|
||||
|
||||
class NPMLIRConverter(MLIRConverter):
|
||||
"""Numpy-specific MLIR converter."""
|
||||
|
||||
@staticmethod
|
||||
def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]:
|
||||
"""Generate the additional_conversion_info dict for the MLIR converter.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): the OPGraph for which we need the conversion infos.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The dict with the additional conversion infos.
|
||||
"""
|
||||
|
||||
additional_conversion_info = {}
|
||||
|
||||
# Disable numpy warnings during conversion to avoid issues during TLU generation
|
||||
with numpy.errstate(all="ignore"):
|
||||
additional_conversion_info["tables"] = {
|
||||
node: node.get_table()
|
||||
for node in op_graph.graph.nodes()
|
||||
if isinstance(node, UnivariateFunction)
|
||||
}
|
||||
|
||||
return additional_conversion_info
|
||||
@@ -63,6 +63,7 @@ def test_fail_tlu_input(input_node):
|
||||
[None],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
@@ -84,4 +85,5 @@ def test_fail_tlu_output(input_node):
|
||||
[None],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
@@ -10,10 +10,11 @@ from zamalang.dialects import hlfhe
|
||||
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.extensions.table import LookupTable
|
||||
from concrete.common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
|
||||
from concrete.common.mlir import V0_OPSET_CONVERSION_FUNCTIONS
|
||||
from concrete.common.values import ClearScalar, EncryptedScalar
|
||||
from concrete.common.values.tensors import ClearTensor, EncryptedTensor
|
||||
from concrete.numpy.compile import compile_numpy_function_into_op_graph
|
||||
from concrete.numpy.np_mlir_converter import NPMLIRConverter
|
||||
|
||||
|
||||
def add(x, y):
|
||||
@@ -219,7 +220,7 @@ def test_mlir_converter(func, args_dict, args_ranges):
|
||||
"""Test the conversion to MLIR by calling the parser from the compiler"""
|
||||
inputset = datagen(*args_ranges)
|
||||
result_graph = compile_numpy_function_into_op_graph(func, args_dict, inputset)
|
||||
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
mlir_result = converter.convert(result_graph)
|
||||
# testing that this doesn't raise an error
|
||||
compiler.round_trip(mlir_result)
|
||||
@@ -261,7 +262,7 @@ def test_mlir_converter_dot_between_vectors(func, args_dict, args_ranges):
|
||||
for data in datagen(*args_ranges)
|
||||
),
|
||||
)
|
||||
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
mlir_result = converter.convert(result_graph)
|
||||
# testing that this doesn't raise an error
|
||||
compiler.round_trip(mlir_result)
|
||||
@@ -281,7 +282,7 @@ def test_mlir_converter_dot_vector_and_constant():
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2,))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)],
|
||||
)
|
||||
left_converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
left_converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
left_mlir = left_converter.convert(left_graph)
|
||||
|
||||
right_graph = compile_numpy_function_into_op_graph(
|
||||
@@ -289,7 +290,7 @@ def test_mlir_converter_dot_vector_and_constant():
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2,))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)],
|
||||
)
|
||||
right_converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
right_converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
right_mlir = right_converter.convert(right_graph)
|
||||
|
||||
# testing that this doesn't raise an error
|
||||
@@ -300,7 +301,7 @@ def test_mlir_converter_dot_vector_and_constant():
|
||||
def test_concrete_encrypted_integer_to_mlir_type():
|
||||
"""Test conversion of EncryptedScalar into MLIR"""
|
||||
value = EncryptedScalar(Integer(7, is_signed=False))
|
||||
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
eint = converter.common_value_to_mlir_type(value)
|
||||
assert eint == hlfhe.EncryptedIntegerType.get(converter.context, 7)
|
||||
|
||||
@@ -309,7 +310,7 @@ def test_concrete_encrypted_integer_to_mlir_type():
|
||||
def test_concrete_clear_integer_to_mlir_type(is_signed):
|
||||
"""Test conversion of ClearScalar into MLIR"""
|
||||
value = ClearScalar(Integer(5, is_signed=is_signed))
|
||||
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
with converter.context:
|
||||
int_mlir = converter.common_value_to_mlir_type(value)
|
||||
if is_signed:
|
||||
@@ -330,7 +331,7 @@ def test_concrete_clear_integer_to_mlir_type(is_signed):
|
||||
def test_concrete_clear_tensor_integer_to_mlir_type(is_signed, shape):
|
||||
"""Test conversion of ClearTensor into MLIR"""
|
||||
value = ClearTensor(Integer(5, is_signed=is_signed), shape)
|
||||
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
with converter.context, Location.unknown():
|
||||
tensor_mlir = converter.common_value_to_mlir_type(value)
|
||||
if is_signed:
|
||||
@@ -355,7 +356,7 @@ def test_concrete_clear_tensor_integer_to_mlir_type(is_signed, shape):
|
||||
def test_concrete_encrypted_tensor_integer_to_mlir_type(shape):
|
||||
"""Test conversion of EncryptedTensor into MLIR"""
|
||||
value = EncryptedTensor(Integer(6, is_signed=False), shape)
|
||||
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
with converter.context, Location.unknown():
|
||||
tensor_mlir = converter.common_value_to_mlir_type(value)
|
||||
element_type = hlfhe.EncryptedIntegerType.get(converter.context, 6)
|
||||
@@ -369,7 +370,7 @@ def test_concrete_encrypted_tensor_integer_to_mlir_type(shape):
|
||||
def test_failing_concrete_to_mlir_type():
|
||||
"""Test failing conversion of an unsupported type into MLIR"""
|
||||
value = "random"
|
||||
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
with pytest.raises(TypeError, match=r"can't convert value of type .* to MLIR type"):
|
||||
converter.common_value_to_mlir_type(value)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user