diff --git a/concrete/common/data_types/dtypes_helpers.py b/concrete/common/data_types/dtypes_helpers.py index 5cfaf322b..77430f9b7 100644 --- a/concrete/common/data_types/dtypes_helpers.py +++ b/concrete/common/data_types/dtypes_helpers.py @@ -100,23 +100,6 @@ def value_is_unsigned_integer(value_to_check: BaseValue) -> bool: ) -def value_is_encrypted_unsigned_integer(value_to_check: BaseValue) -> bool: - """Check that a value is encrypted and is of type unsigned Integer. - - Args: - value_to_check (BaseValue): The value to check - - Returns: - bool: True if the passed value_to_check is encrypted and is of type unsigned Integer - """ - - return ( - value_to_check.is_encrypted - and isinstance(value_to_check.dtype, INTEGER_TYPES) - and not cast(Integer, value_to_check.dtype).is_signed - ) - - def value_is_encrypted_tensor_integer(value_to_check: BaseValue) -> bool: """Check that a value is an encrypted TensorValue of type Integer. @@ -129,22 +112,6 @@ def value_is_encrypted_tensor_integer(value_to_check: BaseValue) -> bool: return value_is_tensor_integer(value_to_check) and value_to_check.is_encrypted -def value_is_encrypted_tensor_unsigned_integer(value_to_check: BaseValue) -> bool: - """Check that a value is an encrypted TensorValue of type unsigned Integer. - - Args: - value_to_check (BaseValue): The value to check - - Returns: - bool: True if the passed value_to_check is an encrypted TensorValue of type Integer and - unsigned - """ - return ( - value_is_encrypted_tensor_integer(value_to_check) - and not cast(Integer, value_to_check.dtype).is_signed - ) - - def value_is_clear_tensor_integer(value_to_check: BaseValue) -> bool: """Check that a value is a clear TensorValue of type Integer. 
def integer_to_mlir_type(ctx: Context, integer: Integer, is_encrypted: bool) -> Type:
    """Convert an integer data type to its corresponding MLIR type.

    Args:
        ctx (Context): the MLIR context to perform the conversion
        integer (Integer): the integer to convert
        is_encrypted (bool): whether the integer is encrypted or not

    Returns:
        Type: the MLIR type corresponding to the given integer and encryption status

    Note:
        unsupported configurations are not reported with a `None` result; a `ValueError`
        raised by the underlying MLIR bindings is allowed to propagate instead
        (callers such as `value_to_mlir_type` rely on catching it)
    """

    bit_width = integer.bit_width

    if is_encrypted:
        # encrypted integers are represented with the HLFHE dialect type
        return EncryptedIntegerType.get(ctx, bit_width)

    # clear integers are represented as signless MLIR integers
    return IntegerType.get_signless(bit_width)


def value_to_mlir_type(ctx: Context, value: BaseValue) -> Type:
    """Convert a value to its corresponding MLIR type.

    Args:
        ctx (Context): the MLIR context to perform the conversion
        value (BaseValue): the value to convert

    Returns:
        Type: the MLIR type corresponding to the given value

    Raises:
        TypeError: if the value is not supported for MLIR conversion
    """

    dtype = value.dtype
    if isinstance(dtype, Integer):
        try:
            mlir_type = integer_to_mlir_type(ctx, dtype, value.is_encrypted)
            # non-scalar tensors are wrapped into a ranked tensor type
            if isinstance(value, TensorValue) and not value.is_scalar:
                mlir_type = RankedTensorType.get(value.shape, mlir_type)
            return mlir_type
        except ValueError:
            pass  # fall through to the TypeError below

    raise TypeError(f"{value} is not supported for MLIR conversion")
a scalar constant operation result to a dense tensor constant operation result. - - see https://github.com/zama-ai/concretefhe-internal/issues/837. - - This is a temporary workaround before the compiler natively supports - `tensor + scalar`, `tensor - scalar`, `tensor * scalar` operations. - - Example input = `%c3_i4 = arith.constant 3 : i4` - Example output = `%cst = arith.constant dense<3> : tensor<1xi4>` - - Args: - operation: operation to convert - - Returns: - the converted operation - """ - - operation_str = str(operation) - - constant_start_location = operation_str.find("arith.constant") + len("arith.constant") + 1 - constant_end_location = operation_str.find(f": {str(operation.type)}") - 1 - constant_value = operation_str[constant_start_location:constant_end_location] - - resulting_type = RankedTensorType.get((1,), operation.type) - value_attr = Attribute.parse(f"dense<{constant_value}> : tensor<1x{str(operation.type)}>") - - return arith_dialect.ConstantOp(resulting_type, value_attr).result - - -def add(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None): - """Convert an addition intermediate node.""" - assert_true(len(node.inputs) == 2, "addition should have two inputs") - assert_true(len(node.outputs) == 1, "addition should have a single output") - - is_convertible = True - one_of_the_inputs_is_a_tensor = False - both_of_the_inputs_are_encrypted = True - ordered_preds = preds - - for input_ in node.inputs: - if value_is_tensor_integer(input_): - one_of_the_inputs_is_a_tensor = True - elif not value_is_scalar_integer(input_): - is_convertible = False - - if not is_convertible: - raise TypeError( - f"Don't support addition between {str(node.inputs[0])} and {str(node.inputs[1])}" - ) - - if node.inputs[1].is_clear: - both_of_the_inputs_are_encrypted = False - if node.inputs[0].is_clear: - both_of_the_inputs_are_encrypted = False - ordered_preds = preds[::-1] - - if one_of_the_inputs_is_a_tensor: - if both_of_the_inputs_are_encrypted: - 
return _linalg_add_eint_eint(node, ordered_preds, ir_to_mlir_node, ctx) - return _linalg_add_eint_int(node, ordered_preds, ir_to_mlir_node, ctx) - - if both_of_the_inputs_are_encrypted: - return _add_eint_eint(node, ordered_preds, ir_to_mlir_node, ctx) - return _add_eint_int(node, ordered_preds, ir_to_mlir_node, ctx) - - -def _add_eint_int(node, preds, ir_to_mlir_node, ctx): - """Convert an addition intermediate node with (eint, int).""" - lhs_node, rhs_node = preds - lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node] - return hlfhe.AddEintIntOp( - hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width), - lhs, - rhs, - ).result - - -def _add_eint_eint(node, preds, ir_to_mlir_node, ctx): - """Convert an addition intermediate node with (eint, eint).""" - lhs_node, rhs_node = preds - lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node] - return hlfhe.AddEintOp( - hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width), - lhs, - rhs, - ).result - - -def _linalg_add_eint_int(node, preds, ir_to_mlir_node, ctx): - """Convert an addition intermediate tensor node with (eint, int).""" - lhs_node, rhs_node = preds - lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node] - - if not str(rhs.type).startswith("tensor"): - rhs = _convert_scalar_constant_op_to_single_element_tensor_constant_op(rhs) - - int_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width) - vec_type = RankedTensorType.get(node.outputs[0].shape, int_type) - - return hlfhelinalg.AddEintIntOp(vec_type, lhs, rhs).result - - -def _linalg_add_eint_eint(node, preds, ir_to_mlir_node, ctx): - """Convert an addition intermediate tensor node with (eint, eint).""" - lhs_node, rhs_node = preds - lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node] - - int_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width) - vec_type = RankedTensorType.get(node.outputs[0].shape, int_type) - - return 
hlfhelinalg.AddEintOp(vec_type, lhs, rhs).result - - -def sub(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None): - """Convert a subtraction intermediate node.""" - assert_true(len(node.inputs) == 2, "subtraction should have two inputs") - assert_true(len(node.outputs) == 1, "subtraction should have a single output") - - is_convertible = True - one_of_the_inputs_is_a_tensor = False - - if value_is_clear_tensor_integer(node.inputs[0]): - one_of_the_inputs_is_a_tensor = True - elif not value_is_clear_scalar_integer(node.inputs[0]): - is_convertible = False - - if value_is_tensor_integer(node.inputs[1]): - one_of_the_inputs_is_a_tensor = True - elif not value_is_scalar_integer(node.inputs[1]): - is_convertible = False - - if not is_convertible: - raise TypeError( - f"Don't support subtraction between {str(node.inputs[0])} and {str(node.inputs[1])}" - ) - - if one_of_the_inputs_is_a_tensor: - return _linalg_sub_int_eint(node, preds, ir_to_mlir_node, ctx) - return _sub_int_eint(node, preds, ir_to_mlir_node, ctx) - - -def _sub_int_eint(node, preds, ir_to_mlir_node, ctx): - """Convert a subtraction intermediate node with (int, eint).""" - lhs_node, rhs_node = preds - lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node] - return hlfhe.SubIntEintOp( - hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width), - lhs, - rhs, - ).result - - -def _linalg_sub_int_eint(node, preds, ir_to_mlir_node, ctx): - """Convert a subtraction intermediate node with (int, eint).""" - lhs_node, rhs_node = preds - lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node] - - if not str(lhs.type).startswith("tensor"): - lhs = _convert_scalar_constant_op_to_single_element_tensor_constant_op(lhs) - - int_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width) - vec_type = RankedTensorType.get(node.outputs[0].shape, int_type) - - return hlfhelinalg.SubIntEintOp(vec_type, lhs, rhs).result - - -def mul(node, preds, ir_to_mlir_node, 
ctx, _additional_conversion_info=None): - """Convert a multiplication intermediate node.""" - assert_true(len(node.inputs) == 2, "multiplication should have two inputs") - assert_true(len(node.outputs) == 1, "multiplication should have a single output") - - is_convertible = True - one_of_the_inputs_is_a_tensor = False - ordered_preds = preds - - for input_ in node.inputs: - if value_is_tensor_integer(input_): - one_of_the_inputs_is_a_tensor = True - elif not value_is_scalar_integer(input_): - is_convertible = False - - if not is_convertible: - raise TypeError( - f"Don't support multiplication between {str(node.inputs[0])} and {str(node.inputs[1])}" - ) - - if node.inputs[0].is_clear: - ordered_preds = preds[::-1] - - if one_of_the_inputs_is_a_tensor: - return _linalg_mul_eint_int(node, ordered_preds, ir_to_mlir_node, ctx) - return _mul_eint_int(node, ordered_preds, ir_to_mlir_node, ctx) - - -def _mul_eint_int(node, preds, ir_to_mlir_node, ctx): - """Convert a multiplication intermediate node with (eint, int).""" - lhs_node, rhs_node = preds - lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node] - return hlfhe.MulEintIntOp( - hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width), - lhs, - rhs, - ).result - - -def _linalg_mul_eint_int(node, preds, ir_to_mlir_node, ctx): - """Convert a subtraction intermediate node with (int, eint).""" - lhs_node, rhs_node = preds - lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node] - - if not str(rhs.type).startswith("tensor"): - rhs = _convert_scalar_constant_op_to_single_element_tensor_constant_op(rhs) - - int_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width) - vec_type = RankedTensorType.get(node.outputs[0].shape, int_type) - - return hlfhelinalg.MulEintIntOp(vec_type, lhs, rhs).result - - -def constant(node, _preds, _ir_to_mlir_node, ctx, _additional_conversion_info=None): - """Convert a constant input.""" - value = node.outputs[0] - - if 
value_is_clear_scalar_integer(value): - value = cast(TensorValue, value) - - dtype = cast(Integer, value.dtype) - data = node.constant_data - - int_type = IntegerType.get_signless(dtype.bit_width, context=ctx) - return arith_dialect.ConstantOp(int_type, IntegerAttr.get(int_type, data)).result - - if value_is_clear_tensor_integer(value): - value = cast(TensorValue, value) - - dtype = cast(Integer, value.dtype) - data = node.constant_data - - int_type = IntegerType.get_signless(dtype.bit_width, context=ctx) - vec_type = RankedTensorType.get(value.shape, int_type) - - # usage of `Attribute.parse` is the result of some limitations in the MLIR module - # provided by LLVM - - # `DenseElementsAttr` should have been used instead but it's impossible to assign - # custom bit-widths using it (e.g., uint5) - - # since we coudn't create a `DenseElementsAttr` with a custom bit width using python api - # we use `Attribute.parse` to let the underlying library do it by itself - - value_attr = Attribute.parse(f"dense<{str(data.tolist())}> : {vec_type}") - return arith_dialect.ConstantOp(vec_type, value_attr).result - - raise TypeError(f"Don't support {value} constants") - - -def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info): - """Convert a GenericFunction intermediate node.""" - - variable_input_indices = [ - idx for idx, pred in enumerate(preds) if not isinstance(pred, Constant) - ] - - assert_true( - (non_constant_pred_count := len(variable_input_indices)) == 1, - f"LUT should have a single variable input (got {non_constant_pred_count})", - ) - - variable_input_idx = variable_input_indices[0] - variable_input_value = node.inputs[variable_input_idx] - - assert_true(len(node.outputs) == 1, "LUT should have a single output") - if not value_is_encrypted_unsigned_integer(variable_input_value): - raise TypeError( - f"Only support LUT with encrypted unsigned integers inputs " - f"(but {variable_input_value} is provided)" - ) - if not 
value_is_encrypted_unsigned_integer(node.outputs[0]): - raise TypeError( - f"Only support LUT with encrypted unsigned integers outputs " - f"(but {node.outputs[0]} is provided)" - ) - - x_node = preds[variable_input_idx] - x = ir_to_mlir_node[x_node] - tables = additional_conversion_info["tables"][node] - - # TODO: #559 adapt the code to support multi TLUs - # This cannot be reached today as compilation fails if the intermediate values are not all - # scalars - if len(tables) > 1: # pragma: no cover - raise RuntimeError( - "MLIR conversion currently does not support multiple test vectors for LUT" - ) - - table = tables[0][0] - - out_dtype = cast(Integer, node.outputs[0].dtype) - # Create table - dense_elem = DenseElementsAttr.get(numpy.array(table, dtype=numpy.uint64), context=ctx) - tensor_lut = arith_dialect.ConstantOp( - RankedTensorType.get([len(table)], IntegerType.get_signless(64, context=ctx)), - dense_elem, - ).result - - int_type = hlfhe.EncryptedIntegerType.get(ctx, out_dtype.bit_width) - - if value_is_encrypted_tensor_integer(node.inputs[0]): - vec_type = RankedTensorType.get(node.outputs[0].shape, int_type) - return hlfhelinalg.ApplyLookupTableEintOp(vec_type, x, tensor_lut).result - return hlfhe.ApplyLookupTableEintOp(int_type, x, tensor_lut).result - - -def dot(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None): - """Convert a dot intermediate node.""" - assert_true(len(node.inputs) == 2, "Dot should have two inputs") - assert_true(len(node.outputs) == 1, "Dot should have a single output") - if not ( - ( - value_is_encrypted_tensor_integer(node.inputs[0]) - and value_is_clear_tensor_integer(node.inputs[1]) - ) - or ( - value_is_encrypted_tensor_integer(node.inputs[1]) - and value_is_clear_tensor_integer(node.inputs[0]) - ) - ): - raise TypeError( - f"Don't support dot between {str(node.inputs[0])} and {str(node.inputs[1])}" - ) - lhs_node, rhs_node = preds - # need to flip as underlying operation need encrypted first - if 
class OPGraphConverter(ABC):
    """Abstract base class that turns an OPGraph into its textual MLIR form."""

    def convert(self, opgraph: OPGraph) -> str:
        """Convert an operation graph to its corresponding MLIR representation.

        Args:
            opgraph (OPGraph): the operation graph to be converted

        Returns:
            str: textual MLIR representation corresponding to opgraph
        """

        extra_info = self._generate_additional_info_dict(opgraph)

        with Context() as ctx, Location.unknown():
            zamalang.register_dialects(ctx)

            module = Module.create()
            with InsertionPoint(module.body):
                # one MLIR argument type per ordered graph input
                input_types = [
                    value_to_mlir_type(ctx, input_node.outputs[0])
                    for input_node in opgraph.get_ordered_inputs()
                ]

                # NOTE: the python function name below becomes the name of the
                # generated MLIR function, so it must stay `main`
                @builtin.FuncOp.from_py_func(*input_types)
                def main(*args):
                    mlir_values = {}
                    for position, input_node in opgraph.input_nodes.items():
                        mlir_values[input_node] = args[position]

                    # convert the remaining nodes in dependency order
                    for node in nx.topological_sort(opgraph.graph):
                        if isinstance(node, Input):
                            continue

                        operands = [
                            mlir_values[pred] for pred in opgraph.get_ordered_preds(node)
                        ]
                        converter = IntermediateNodeConverter(ctx, opgraph, node, operands)
                        mlir_values[node] = converter.convert(extra_info)

                    return (
                        mlir_values[output_node]
                        for output_node in opgraph.get_ordered_outputs()
                    )

        return str(module)

    @staticmethod
    @abstractmethod
    def _generate_additional_info_dict(opgraph: OPGraph) -> Dict[str, Any]:
        """Generate additional conversion info dict for the MLIR converter.

        Args:
            opgraph (OPGraph): the operation graph from which the additional info will be generated

        Returns:
            Dict[str, Any]: dict of additional conversion info
        """
- every function should have 4 arguments: - - node (IntermediateNode): the node itself to be converted - - operands (IntermediateNode): predecessors of node ordered as operands - - ir_to_mlir_node (dict): mapping between IntermediateNode and their equivalent - MLIR values - - context (mlir.Context): the MLIR context being used for the conversion - """ - self.conversion_functions = conversion_functions - self._init_context() - - def _init_context(self): - self.context = Context() - zamalang.register_dialects(self.context) - - def _get_tensor_type( - self, - bit_width: int, - is_encrypted: bool, - is_signed: bool, - shape: Tuple[int, ...], - ) -> MLIRType: - """Get the MLIRType for a tensor element given its properties. - - Args: - bit_width (int): number of bits used for the scalar - is_encrypted (bool): is the scalar encrypted or not - is_signed (bool): is the scalar signed or not - shape (Tuple[int, ...]): shape of the tensor - - Returns: - MLIRType: corresponding MLIR type - """ - element_type = self._get_scalar_integer_type(bit_width, is_encrypted, is_signed) - return RankedTensorType.get(shape, element_type) - - def _get_scalar_integer_type( - self, bit_width: int, is_encrypted: bool, is_signed: bool - ) -> MLIRType: - """Get the MLIRType for a scalar element given its properties. 
- - Args: - bit_width (int): number of bits used for the scalar - is_encrypted (bool): is the scalar encrypted or not - is_signed (bool): is the scalar signed or not - - Returns: - MLIRType: corresponding MLIR type - """ - if is_encrypted and not is_signed: - return hlfhe.EncryptedIntegerType.get(self.context, bit_width) - if is_signed and not is_encrypted: # clear signed - return IntegerType.get_signed(bit_width) - # should be clear unsigned at this point - assert_true(not is_signed and not is_encrypted) - # unsigned integer are considered signless in the compiler - return IntegerType.get_signless(bit_width) - - @staticmethod - @abstractmethod - def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]: - """Generate the additional_conversion_info dict for the MLIR converter. - - Args: - op_graph (OPGraph): the OPGraph for which we need the conversion infos. - - Returns: - Dict[str, Any]: The dict with the additional conversion infos. - """ - - def common_value_to_mlir_type(self, value: values.BaseValue) -> MLIRType: - """Convert a common value to its corresponding MLIR Type. 
- - Args: - value: value to convert - - Returns: - corresponding MLIR type - """ - if value_is_encrypted_scalar_unsigned_integer(value): - return self._get_scalar_integer_type(cast(Integer, value.dtype).bit_width, True, False) - if value_is_clear_scalar_integer(value): - dtype = cast(Integer, value.dtype) - return self._get_scalar_integer_type( - dtype.bit_width, is_encrypted=False, is_signed=dtype.is_signed - ) - if value_is_encrypted_tensor_unsigned_integer(value): - dtype = cast(Integer, value.dtype) - return self._get_tensor_type( - dtype.bit_width, - is_encrypted=True, - is_signed=False, - shape=cast(values.TensorValue, value).shape, - ) - if value_is_clear_tensor_integer(value): - dtype = cast(Integer, value.dtype) - return self._get_tensor_type( - dtype.bit_width, - is_encrypted=False, - is_signed=dtype.is_signed, - shape=cast(values.TensorValue, value).shape, - ) - raise TypeError(f"can't convert value of type {type(value)} to MLIR type") - - def convert(self, op_graph: OPGraph) -> str: - """Convert the graph of IntermediateNode to an MLIR textual representation. - - Args: - op_graph (OPGraph): graph of IntermediateNode to be converted - - Raises: - NotImplementedError: raised if an unknown node type is encountered. 
- - Returns: - str: textual MLIR representation - """ - additional_conversion_info = self._generate_additional_info_dict(op_graph) - - with self.context, Location.unknown(): - module = Module.create() - # collect inputs - with InsertionPoint(module.body): - func_types = [ - self.common_value_to_mlir_type(input_node.inputs[0]) - for input_node in op_graph.get_ordered_inputs() - ] - - @builtin.FuncOp.from_py_func(*func_types) - def main(*arg): - ir_to_mlir_node = {} - for arg_num, node in op_graph.input_nodes.items(): - ir_to_mlir_node[node] = arg[arg_num] - for node in nx.topological_sort(op_graph.graph): - if isinstance(node, Input): - continue - mlir_op = self.conversion_functions.get(type(node), None) - if mlir_op is None: # pragma: no cover - raise NotImplementedError( - f"we don't yet support conversion to MLIR of computations using" - f"{type(node)}" - ) - preds = op_graph.get_ordered_preds(node) - # convert to mlir - result = mlir_op( - node, - preds, - ir_to_mlir_node, - self.context, - additional_conversion_info, - ) - ir_to_mlir_node[node] = result - - results = ( - ir_to_mlir_node[output_node] - for output_node in op_graph.get_ordered_outputs() - ) - return results - - return module.__str__() - - -# pylint: enable=no-name-in-module,no-member diff --git a/concrete/common/mlir/node_converter.py b/concrete/common/mlir/node_converter.py new file mode 100644 index 000000000..ff1cc1cd5 --- /dev/null +++ b/concrete/common/mlir/node_converter.py @@ -0,0 +1,325 @@ +"""Module that provides IntermediateNode conversion functionality.""" + +# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints + +# pylint: disable=no-name-in-module + +from typing import Any, Dict, List, cast + +import numpy +from mlir.dialects import arith +from mlir.ir import ( + Attribute, + Context, + DenseElementsAttr, + IntegerAttr, + IntegerType, + OpResult, + RankedTensorType, +) +from zamalang.dialects import hlfhe, hlfhelinalg + +from ..data_types import 
Integer +from ..debugging import assert_true +from ..operator_graph import OPGraph +from ..representation.intermediate import ( + Add, + Constant, + Dot, + GenericFunction, + IntermediateNode, + Mul, + Sub, +) +from ..values import TensorValue +from .conversion_helpers import value_to_mlir_type + +# pylint: enable=no-name-in-module + + +class IntermediateNodeConverter: + """Converter of IntermediateNode to MLIR.""" + + ctx: Context + opgraph: OPGraph + node: IntermediateNode + preds: List[OpResult] + + all_of_the_inputs_are_encrypted: bool + all_of_the_inputs_are_tensors: bool + one_of_the_inputs_is_a_tensor: bool + + def __init__( + self, ctx: Context, opgraph: OPGraph, node: IntermediateNode, preds: List[OpResult] + ): + self.ctx = ctx + self.opgraph = opgraph + self.node = node + self.preds = preds + + self.all_of_the_inputs_are_encrypted = True + self.all_of_the_inputs_are_tensors = True + self.one_of_the_inputs_is_a_tensor = False + + for inp in node.inputs: + if inp.is_clear: + self.all_of_the_inputs_are_encrypted = False + + if isinstance(inp, TensorValue): + if inp.is_scalar: + self.all_of_the_inputs_are_tensors = False + else: + self.one_of_the_inputs_is_a_tensor = True + else: # pragma: no cover + # this branch is not covered as there are only TensorValues for now + self.all_of_the_inputs_are_tensors = False + + def convert(self, additional_conversion_info: Dict[str, Any]) -> OpResult: + """Convert an intermediate node to its corresponding MLIR representation. 
+ + Args: + additional_conversion_info (Dict[str, Any]): + external info that the converted node might need + + Returns: + str: textual MLIR representation corresponding to self.node + """ + + if isinstance(self.node, Add): + return self.convert_add() + + if isinstance(self.node, Constant): + return self.convert_constant() + + if isinstance(self.node, Dot): + return self.convert_dot() + + if isinstance(self.node, GenericFunction): + return self.convert_generic_function(additional_conversion_info) + + if isinstance(self.node, Mul): + return self.convert_mul() + + if isinstance(self.node, Sub): + return self.convert_sub() + + # this statement is not covered as unsupported opeations fail on check mlir compatibility + raise NotImplementedError( + f"{type(self.node)} nodes cannot be converted to MLIR yet" + ) # pragma: no cover + + def convert_add(self) -> OpResult: + """Convert an Add node to its corresponding MLIR representation. + + Returns: + str: textual MLIR representation corresponding to self.node + """ + + assert_true(len(self.node.inputs) == 2) + assert_true(len(self.node.outputs) == 1) + + resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0]) + preds = self.preds + + if self.all_of_the_inputs_are_encrypted: + if self.one_of_the_inputs_is_a_tensor: + result = hlfhelinalg.AddEintOp(resulting_type, *preds).result + else: + result = hlfhe.AddEintOp(resulting_type, *preds).result + else: + if self.node.inputs[0].is_clear: # pragma: no cover + # this branch is not covered as it's impossible to get into due to how tracing works + # however, it doesn't hurt to keep it as an extra measure + preds = preds[::-1] + + if self.one_of_the_inputs_is_a_tensor: + result = hlfhelinalg.AddEintIntOp(resulting_type, *preds).result + else: + result = hlfhe.AddEintIntOp(resulting_type, *preds).result + + return result + + def convert_constant(self) -> OpResult: + """Convert a Constant node to its corresponding MLIR representation. 
+ + Returns: + str: textual MLIR representation corresponding to self.node + """ + + assert_true(len(self.node.inputs) == 0) + assert_true(len(self.node.outputs) == 1) + + value = self.node.outputs[0] + if not isinstance(value, TensorValue): # pragma: no cover + # this branch is not covered as there are only TensorValues for now + raise NotImplementedError(f"{value} constants cannot be converted to MLIR yet") + + resulting_type = value_to_mlir_type(self.ctx, value) + data = cast(Constant, self.node).constant_data + + if value.is_scalar: + attr = IntegerAttr.get(resulting_type, data) + else: + # usage of `Attribute.parse` is the result of some limitations in the MLIR module + # provided by LLVM + + # what should have been used is `DenseElementsAttr` but it's impossible to assign + # custom bit-widths using it (e.g., uint5) + + # since we coudn't create a `DenseElementsAttr` with a custom bit width using python api + # we use `Attribute.parse` to let the underlying library do it by itself + + attr = Attribute.parse(f"dense<{str(data.tolist())}> : {resulting_type}") + + return arith.ConstantOp(resulting_type, attr).result + + def convert_dot(self) -> OpResult: + """Convert a Dot node to its corresponding MLIR representation. 
+ + Returns: + str: textual MLIR representation corresponding to self.node + """ + + assert_true(len(self.node.inputs) == 2) + assert_true(len(self.node.outputs) == 1) + + if self.all_of_the_inputs_are_encrypted: + lhs = self.node.inputs[0] + rhs = self.node.inputs[1] + raise NotImplementedError( + f"Dot product between {lhs} and {rhs} cannot be converted to MLIR yet", + ) + + resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0]) + preds = self.preds + + if self.node.inputs[0].is_clear: + preds = preds[::-1] + + if self.all_of_the_inputs_are_tensors: + # numpy.dot(x, y) where x and y are both vectors = regular dot product + result = hlfhelinalg.Dot(resulting_type, *preds).result + + elif not self.one_of_the_inputs_is_a_tensor: + # numpy.dot(x, y) where x and y are both scalars = x * y + result = hlfhe.MulEintIntOp(resulting_type, *preds).result + + else: + # numpy.dot(x, y) where one of x or y is a scalar and the other one is a vector = x * y + result = hlfhelinalg.MulEintIntOp(resulting_type, *preds).result + + return result + + def convert_generic_function(self, additional_conversion_info: Dict[str, Any]) -> OpResult: + """Convert a GenericFunction node to its corresponding MLIR representation. 
+ + Returns: + str: textual MLIR representation corresponding to self.node + """ + + variable_input_indices = [ + idx + for idx, inp in enumerate(self.opgraph.get_ordered_preds(self.node)) + if not isinstance(inp, Constant) + ] + if len(variable_input_indices) != 1: # pragma: no cover + # this branch is not covered as it's impossible to get into due to how tracing works + # however, it doesn't hurt to keep it as an extra measure + raise NotImplementedError( + "Table lookups with more than one variable input cannot be converted to MLIR yet" + ) + variable_input_index = variable_input_indices[0] + + assert_true(len(self.node.outputs) == 1) + + value = self.node.inputs[variable_input_index] + assert_true(value.is_encrypted) + + if not isinstance(value.dtype, Integer) or value.dtype.is_signed: # pragma: no cover + # this branch is not covered as it's impossible to get into due to how compilation works + # however, it doesn't hurt to keep it as an extra measure + raise NotImplementedError(f"Table lookup on {value} cannot be converted to MLIR yet") + + tables = additional_conversion_info["tables"][self.node] + + # TODO: #559 adapt the code to support multi TLUs + # This cannot be reached today as compilation fails + # if the intermediate values are not all scalars + if len(tables) > 1: # pragma: no cover + raise NotImplementedError("Multi table lookups cannot be converted to MLIR yet") + + table = tables[0][0] + + lut_size = len(table) + lut_type = RankedTensorType.get([lut_size], IntegerType.get_signless(64, context=self.ctx)) + lut_attr = DenseElementsAttr.get(numpy.array(table, dtype=numpy.uint64), context=self.ctx) + lut = arith.ConstantOp(lut_type, lut_attr).result + + resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0]) + pred = self.preds[variable_input_index] + + if self.one_of_the_inputs_is_a_tensor: + result = hlfhelinalg.ApplyLookupTableEintOp(resulting_type, pred, lut).result + else: + result = hlfhe.ApplyLookupTableEintOp(resulting_type, pred, 
lut).result + + return result + + def convert_mul(self) -> OpResult: + """Convert a Mul node to its corresponding MLIR representation. + + Returns: + str: textual MLIR representation corresponding to self.node + """ + + assert_true(len(self.node.inputs) == 2) + assert_true(len(self.node.outputs) == 1) + + if self.all_of_the_inputs_are_encrypted: + lhs = self.node.inputs[0] + rhs = self.node.inputs[1] + raise NotImplementedError( + f"Multiplication between {lhs} and {rhs} cannot be converted to MLIR yet", + ) + + resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0]) + preds = self.preds + + if self.node.inputs[0].is_clear: # pragma: no cover + # this branch is not covered as it's impossible to get into due to how tracing works + # however, it doesn't hurt to keep it as an extra measure + preds = preds[::-1] + + if self.one_of_the_inputs_is_a_tensor: + result = hlfhelinalg.MulEintIntOp(resulting_type, *preds).result + else: + result = hlfhe.MulEintIntOp(resulting_type, *preds).result + + return result + + def convert_sub(self) -> OpResult: + """Convert a Sub node to its corresponding MLIR representation. 
+ + Returns: + str: textual MLIR representation corresponding to self.node + """ + + assert_true(len(self.node.inputs) == 2) + assert_true(len(self.node.outputs) == 1) + + lhs = self.node.inputs[0] + rhs = self.node.inputs[1] + if not (lhs.is_clear and rhs.is_encrypted): + raise NotImplementedError( + f"Subtraction of {rhs} from {lhs} cannot be converted to MLIR yet", + ) + + resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0]) + preds = self.preds + + if self.one_of_the_inputs_is_a_tensor: + result = hlfhelinalg.SubIntEintOp(resulting_type, *preds).result + else: + result = hlfhe.SubIntEintOp(resulting_type, *preds).result + + return result diff --git a/concrete/common/representation/intermediate.py b/concrete/common/representation/intermediate.py index 6c6362122..8c49ee8de 100644 --- a/concrete/common/representation/intermediate.py +++ b/concrete/common/representation/intermediate.py @@ -475,19 +475,20 @@ class Dot(IntermediateNode): assert_true( all( - isinstance(input_value, TensorValue) and input_value.ndim == 1 + isinstance(input_value, TensorValue) and input_value.ndim <= 1 for input_value in self.inputs ), - f"Dot only supports two vectors ({TensorValue.__name__} with ndim == 1)", + f"Dot only supports two scalars or vectors ({TensorValue.__name__} with ndim up to 1)", ) lhs = cast(TensorValue, self.inputs[0]) rhs = cast(TensorValue, self.inputs[1]) - assert_true( - lhs.shape[0] == rhs.shape[0], - f"Dot between vectors of shapes {lhs.shape} and {rhs.shape} is not supported", - ) + if lhs.ndim == 1 and rhs.ndim == 1: + assert_true( + lhs.shape[0] == rhs.shape[0], + f"Dot between vectors of shapes {lhs.shape} and {rhs.shape} is not supported", + ) output_scalar_value = ( EncryptedScalar if (lhs.is_encrypted or rhs.is_encrypted) else ClearScalar diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py index 2414c88cb..d08531cdc 100644 --- a/concrete/numpy/compile.py +++ b/concrete/numpy/compile.py @@ -14,7 +14,6 @@ from 
..common.data_types import Integer from ..common.debugging import format_operation_graph from ..common.debugging.custom_assert import assert_true from ..common.fhe_circuit import FHECircuit -from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS from ..common.mlir.utils import ( check_graph_values_compatibility_with_mlir, extend_direct_lookup_tables, @@ -312,7 +311,7 @@ def _compile_numpy_function_internal( prepare_op_graph_for_mlir(op_graph) # Convert graph to an MLIR representation - converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) + converter = NPMLIRConverter() mlir_result = converter.convert(op_graph) # Show MLIR representation if requested diff --git a/concrete/numpy/np_mlir_converter.py b/concrete/numpy/np_mlir_converter.py index 5b2f7e668..da85bf836 100644 --- a/concrete/numpy/np_mlir_converter.py +++ b/concrete/numpy/np_mlir_converter.py @@ -8,7 +8,7 @@ from typing import Any, DefaultDict, Dict, List, Tuple import numpy from ..common.debugging import assert_true -from ..common.mlir.mlir_converter import MLIRConverter +from ..common.mlir.graph_converter import OPGraphConverter from ..common.operator_graph import OPGraph from ..common.representation.intermediate import GenericFunction, IntermediateNode @@ -67,27 +67,18 @@ def generate_deduplicated_tables( ) -class NPMLIRConverter(MLIRConverter): +class NPMLIRConverter(OPGraphConverter): """Numpy-specific MLIR converter.""" @staticmethod - def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]: - """Generate the additional_conversion_info dict for the MLIR converter. - - Args: - op_graph (OPGraph): the OPGraph for which we need the conversion infos. - - Returns: - Dict[str, Any]: The dict with the additional conversion infos. 
- """ - + def _generate_additional_info_dict(opgraph: OPGraph) -> Dict[str, Any]: additional_conversion_info = {} # Disable numpy warnings during conversion to avoid issues during TLU generation with numpy.errstate(all="ignore"): additional_conversion_info["tables"] = { - node: generate_deduplicated_tables(node, op_graph.get_ordered_preds(node)) - for node in op_graph.graph.nodes() + node: generate_deduplicated_tables(node, opgraph.get_ordered_preds(node)) + for node in opgraph.graph.nodes() if isinstance(node, GenericFunction) } diff --git a/tests/common/mlir/test_conversion_helpers.py b/tests/common/mlir/test_conversion_helpers.py new file mode 100644 index 000000000..6da1fcba7 --- /dev/null +++ b/tests/common/mlir/test_conversion_helpers.py @@ -0,0 +1,115 @@ +"""Test file for MLIR conversion helpers.""" + +# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints + +# pylint: disable=no-name-in-module + +import pytest +import zamalang +from mlir.ir import Context, Location + +from concrete.common.data_types import Float, SignedInteger, UnsignedInteger +from concrete.common.mlir.conversion_helpers import integer_to_mlir_type, value_to_mlir_type +from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor + +# pylint: enable=no-name-in-module + + +@pytest.mark.parametrize( + "integer,is_encrypted,expected_mlir_type_str", + [ + pytest.param(SignedInteger(5), False, "i5"), + pytest.param(UnsignedInteger(5), False, "i5"), + pytest.param(SignedInteger(32), False, "i32"), + pytest.param(UnsignedInteger(32), False, "i32"), + pytest.param(SignedInteger(5), True, "!HLFHE.eint<5>"), + pytest.param(UnsignedInteger(5), True, "!HLFHE.eint<5>"), + ], +) +def test_integer_to_mlir_type(integer, is_encrypted, expected_mlir_type_str): + """Test function for integer to MLIR type conversion.""" + + with Context() as ctx, Location.unknown(): + zamalang.register_dialects(ctx) + assert str(integer_to_mlir_type(ctx, 
integer, is_encrypted)) == expected_mlir_type_str + + +@pytest.mark.parametrize( + "integer,is_encrypted,expected_error_message", + [ + pytest.param(SignedInteger(32), True, "can't create eint with the given width"), + pytest.param(UnsignedInteger(32), True, "can't create eint with the given width"), + ], +) +def test_fail_integer_to_mlir_type(integer, is_encrypted, expected_error_message): + """Test function for failed integer to MLIR type conversion.""" + + with pytest.raises(ValueError) as excinfo: + with Context() as ctx, Location.unknown(): + zamalang.register_dialects(ctx) + integer_to_mlir_type(ctx, integer, is_encrypted) + + assert str(excinfo.value) == expected_error_message + + +@pytest.mark.parametrize( + "value,expected_mlir_type_str", + [ + pytest.param(ClearScalar(SignedInteger(5)), "i5"), + pytest.param(ClearTensor(SignedInteger(5), shape=(2, 3)), "tensor<2x3xi5>"), + pytest.param(EncryptedScalar(SignedInteger(5)), "!HLFHE.eint<5>"), + pytest.param(EncryptedTensor(SignedInteger(5), shape=(2, 3)), "tensor<2x3x!HLFHE.eint<5>>"), + pytest.param(ClearScalar(UnsignedInteger(5)), "i5"), + pytest.param(ClearTensor(UnsignedInteger(5), shape=(2, 3)), "tensor<2x3xi5>"), + pytest.param(EncryptedScalar(UnsignedInteger(5)), "!HLFHE.eint<5>"), + pytest.param( + EncryptedTensor(UnsignedInteger(5), shape=(2, 3)), "tensor<2x3x!HLFHE.eint<5>>" + ), + ], +) +def test_value_to_mlir_type(value, expected_mlir_type_str): + """Test function for value to MLIR type conversion.""" + + with Context() as ctx, Location.unknown(): + zamalang.register_dialects(ctx) + assert str(value_to_mlir_type(ctx, value)) == expected_mlir_type_str + + +@pytest.mark.parametrize( + "value,expected_error_message", + [ + pytest.param( + ClearScalar(Float(32)), + "ClearScalar is not supported for MLIR conversion", + ), + pytest.param( + ClearTensor(Float(32), shape=(2, 3)), + "ClearTensor is not supported for MLIR conversion", + ), + pytest.param( + EncryptedScalar(Float(32)), + "EncryptedScalar is 
not supported for MLIR conversion", + ), + pytest.param( + EncryptedTensor(Float(32), shape=(2, 3)), + "EncryptedTensor is not supported for MLIR conversion", + ), + pytest.param( + EncryptedScalar(UnsignedInteger(32)), + "EncryptedScalar is not supported for MLIR conversion", + ), + pytest.param( + EncryptedTensor(UnsignedInteger(32), shape=(2, 3)), + "EncryptedTensor is not supported for MLIR conversion", + ), + ], +) +def test_fail_value_to_mlir_type(value, expected_error_message): + """Test function for failed value to MLIR type conversion.""" + + with pytest.raises(TypeError) as excinfo: + with Context() as ctx, Location.unknown(): + zamalang.register_dialects(ctx) + value_to_mlir_type(ctx, value) + + assert str(excinfo.value) == expected_error_message diff --git a/tests/common/mlir/test_converters.py b/tests/common/mlir/test_converters.py deleted file mode 100644 index 8a292e6b5..000000000 --- a/tests/common/mlir/test_converters.py +++ /dev/null @@ -1,81 +0,0 @@ -"""Test converter functions""" -import pytest - -from concrete.common.data_types.floats import Float -from concrete.common.data_types.integers import Integer -from concrete.common.mlir.converters import add, apply_lut, constant, dot, mul, sub -from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar - - -class MockNode: - """Mocking an intermediate node""" - - def __init__(self, inputs_n=5, outputs_n=5, inputs=None, outputs=None): - if inputs is None: - self.inputs = [None for i in range(inputs_n)] - else: - self.inputs = inputs - if outputs is None: - self.outputs = [None for i in range(outputs_n)] - else: - self.outputs = outputs - - -@pytest.mark.parametrize("converter", [add, sub, mul, dot]) -def test_failing_converter(converter): - """Test failing converter""" - with pytest.raises(TypeError, match=r"Don't support .* between .* and .*"): - converter(MockNode(2, 1), None, None, None) - - -def test_fail_non_integer_const(): - """Test failing constant converter with 
non-integer""" - with pytest.raises(TypeError, match=r"Don't support .* constants"): - constant(MockNode(outputs=[ClearScalar(Float(32))]), None, None, None) - - with pytest.raises(TypeError, match=r"Don't support .* constants"): - constant(MockNode(outputs=[ClearTensor(Float(32), shape=(2,))]), None, None, None) - - -@pytest.mark.parametrize( - "input_node", - [ - ClearScalar(Integer(8, True)), - ClearScalar(Integer(8, False)), - EncryptedScalar(Integer(8, True)), - ], -) -def test_fail_tlu_input(input_node): - """Test failing LUT converter with invalid input""" - with pytest.raises( - TypeError, match=r"Only support LUT with encrypted unsigned integers inputs" - ): - apply_lut( - MockNode(inputs=[input_node], outputs=[EncryptedScalar(Integer(8, False))]), - [None], - None, - None, - None, - ) - - -@pytest.mark.parametrize( - "input_node", - [ - ClearScalar(Integer(8, True)), - ClearScalar(Integer(8, False)), - EncryptedScalar(Integer(8, True)), - ], -) -def test_fail_tlu_output(input_node): - """Test failing LUT converter with invalid output""" - with pytest.raises( - TypeError, match=r"Only support LUT with encrypted unsigned integers outputs" - ): - apply_lut( - MockNode(inputs=[EncryptedScalar(Integer(8, False))], outputs=[input_node]), - [None], - None, - None, - None, - ) diff --git a/tests/common/mlir/test_mlir_converter.py b/tests/common/mlir/test_mlir_converter.py deleted file mode 100644 index f5f772aba..000000000 --- a/tests/common/mlir/test_mlir_converter.py +++ /dev/null @@ -1,392 +0,0 @@ -"""Test file for conversion to MLIR""" -# pylint: disable=no-name-in-module,no-member -import itertools - -import numpy -import pytest -from mlir.ir import IntegerType, Location, RankedTensorType, UnrankedTensorType -from zamalang import compiler -from zamalang.dialects import hlfhe - -from concrete.common.data_types.integers import Integer -from concrete.common.extensions.table import LookupTable -from concrete.common.mlir import V0_OPSET_CONVERSION_FUNCTIONS -from 
concrete.common.values import ClearScalar, EncryptedScalar -from concrete.common.values.tensors import ClearTensor, EncryptedTensor -from concrete.numpy.compile import compile_numpy_function_into_op_graph, prepare_op_graph_for_mlir -from concrete.numpy.np_mlir_converter import NPMLIRConverter - - -def add(x, y): - """Test simple add""" - return x + y - - -def constant_add(x): - """Test constant add""" - return x + 5 - - -def sub(x, y): - """Test simple sub""" - return x - y - - -def constant_sub(x): - """Test constant sub""" - return 12 - x - - -def mul(x, y): - """Test simple mul""" - return x * y - - -def constant_mul(x): - """Test constant mul""" - return x * 2 - - -def sub_add_mul(x, y, z): - """Test combination of ops""" - return z - y + x * z - - -def ret_multiple(x, y, z): - """Test return of multiple values""" - return x, y, z - - -def ret_multiple_different_order(x, y, z): - """Test return of multiple values in a different order from input""" - return y, z, x - - -def lut(x): - """Test lookup table""" - table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7]) - return table[x] - - -# TODO: remove workaround #359 -def lut_more_bits_than_table_length(x, y): - """Test lookup table when bit_width support longer LUT""" - table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7]) - return table[x] + y - - -# TODO: remove workaround #359 -def lut_less_bits_than_table_length(x): - """Test lookup table when bit_width support smaller LUT""" - table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7, 3, 6, 0, 2, 1, 4, 5, 7]) - return table[x] - - -def dot(x, y): - """Test dot""" - return numpy.dot(x, y) - - -def datagen(*args): - """Generate data from ranges""" - for prod in itertools.product(*args): - yield prod - - -@pytest.mark.parametrize( - "func, args_dict, args_ranges", - [ - ( - add, - { - "x": EncryptedScalar(Integer(64, is_signed=False)), - "y": ClearScalar(Integer(32, is_signed=False)), - }, - (range(0, 8), range(1, 4)), - ), - ( - constant_add, - { - "x": EncryptedScalar(Integer(64, 
is_signed=False)), - }, - (range(0, 10),), - ), - ( - add, - { - "x": ClearScalar(Integer(32, is_signed=False)), - "y": EncryptedScalar(Integer(64, is_signed=False)), - }, - (range(0, 8), range(1, 4)), - ), - ( - add, - { - "x": EncryptedScalar(Integer(7, is_signed=False)), - "y": EncryptedScalar(Integer(7, is_signed=False)), - }, - (range(7, 15), range(1, 5)), - ), - ( - sub, - { - "x": ClearScalar(Integer(8, is_signed=False)), - "y": EncryptedScalar(Integer(7, is_signed=False)), - }, - (range(5, 10), range(2, 6)), - ), - ( - constant_sub, - { - "x": EncryptedScalar(Integer(64, is_signed=False)), - }, - (range(0, 10),), - ), - ( - mul, - { - "x": EncryptedScalar(Integer(7, is_signed=False)), - "y": ClearScalar(Integer(8, is_signed=False)), - }, - (range(1, 5), range(2, 8)), - ), - ( - constant_mul, - { - "x": EncryptedScalar(Integer(64, is_signed=False)), - }, - (range(0, 10),), - ), - ( - mul, - { - "x": ClearScalar(Integer(8, is_signed=False)), - "y": EncryptedScalar(Integer(7, is_signed=False)), - }, - (range(1, 5), range(2, 8)), - ), - ( - sub_add_mul, - { - "x": EncryptedScalar(Integer(7, is_signed=False)), - "y": EncryptedScalar(Integer(7, is_signed=False)), - "z": ClearScalar(Integer(7, is_signed=False)), - }, - (range(0, 8), range(1, 5), range(5, 12)), - ), - ( - ret_multiple, - { - "x": EncryptedScalar(Integer(7, is_signed=False)), - "y": EncryptedScalar(Integer(7, is_signed=False)), - "z": ClearScalar(Integer(7, is_signed=False)), - }, - (range(1, 5), range(1, 5), range(1, 5)), - ), - ( - ret_multiple_different_order, - { - "x": EncryptedScalar(Integer(7, is_signed=False)), - "y": EncryptedScalar(Integer(7, is_signed=False)), - "z": ClearScalar(Integer(7, is_signed=False)), - }, - (range(1, 5), range(1, 5), range(1, 5)), - ), - ( - lut, - { - "x": EncryptedScalar(Integer(3, is_signed=False)), - }, - (range(0, 8),), - ), - ( - lut_more_bits_than_table_length, - { - "x": EncryptedScalar(Integer(64, is_signed=False)), - "y": EncryptedScalar(Integer(64, 
is_signed=False)), - }, - (range(0, 8), range(0, 16)), - ), - ( - lut_less_bits_than_table_length, - { - "x": EncryptedScalar(Integer(3, is_signed=False)), - }, - (range(0, 8),), - ), - ], -) -def test_mlir_converter(func, args_dict, args_ranges, default_compilation_configuration): - """Test the conversion to MLIR by calling the parser from the compiler""" - inputset = datagen(*args_ranges) - result_graph = compile_numpy_function_into_op_graph( - func, - args_dict, - inputset, - default_compilation_configuration, - ) - prepare_op_graph_for_mlir(result_graph) - converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) - mlir_result = converter.convert(result_graph) - # testing that this doesn't raise an error - compiler.round_trip(mlir_result) - - -@pytest.mark.parametrize( - "func, args_dict, args_ranges", - [ - ( - dot, - { - "x": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)), - "y": ClearTensor(Integer(64, is_signed=False), shape=(4,)), - }, - (range(0, 4), range(0, 4)), - ), - ( - dot, - { - "x": ClearTensor(Integer(64, is_signed=False), shape=(4,)), - "y": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)), - }, - (range(0, 4), range(0, 4)), - ), - ], -) -def test_mlir_converter_dot_between_vectors( - func, args_dict, args_ranges, default_compilation_configuration -): - """Test the conversion to MLIR by calling the parser from the compiler""" - assert len(args_dict["x"].shape) == 1 - assert len(args_dict["y"].shape) == 1 - - n = args_dict["x"].shape[0] - - result_graph = compile_numpy_function_into_op_graph( - func, - args_dict, - ( - (numpy.array([data[0]] * n), numpy.array([data[1]] * n)) - for data in datagen(*args_ranges) - ), - default_compilation_configuration, - ) - prepare_op_graph_for_mlir(result_graph) - converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) - mlir_result = converter.convert(result_graph) - # testing that this doesn't raise an error - compiler.round_trip(mlir_result) - - -def 
test_mlir_converter_dot_vector_and_constant(default_compilation_configuration): - """Test the conversion to MLIR by calling the parser from the compiler""" - - def left_dot_with_constant(x): - return numpy.dot(x, numpy.array([1, 2])) - - def right_dot_with_constant(x): - return numpy.dot(numpy.array([1, 2]), x) - - left_graph = compile_numpy_function_into_op_graph( - left_dot_with_constant, - {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2,))}, - [(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)], - default_compilation_configuration, - ) - prepare_op_graph_for_mlir(left_graph) - left_converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) - left_mlir = left_converter.convert(left_graph) - - right_graph = compile_numpy_function_into_op_graph( - right_dot_with_constant, - {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2,))}, - [(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)], - default_compilation_configuration, - ) - prepare_op_graph_for_mlir(right_graph) - right_converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) - right_mlir = right_converter.convert(right_graph) - - # testing that this doesn't raise an error - compiler.round_trip(left_mlir) - compiler.round_trip(right_mlir) - - -def test_concrete_encrypted_integer_to_mlir_type(): - """Test conversion of EncryptedScalar into MLIR""" - value = EncryptedScalar(Integer(7, is_signed=False)) - converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) - eint = converter.common_value_to_mlir_type(value) - assert eint == hlfhe.EncryptedIntegerType.get(converter.context, 7) - - -@pytest.mark.parametrize("is_signed", [True, False]) -def test_concrete_clear_integer_to_mlir_type(is_signed): - """Test conversion of ClearScalar into MLIR""" - value = ClearScalar(Integer(5, is_signed=is_signed)) - converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) - with converter.context: - int_mlir = converter.common_value_to_mlir_type(value) - if is_signed: - assert 
int_mlir == IntegerType.get_signed(5) - else: - assert int_mlir == IntegerType.get_signless(5) - - -@pytest.mark.parametrize("is_signed", [True, False]) -@pytest.mark.parametrize( - "shape", - [ - (5,), - (5, 8), - (-1, 5), - ], -) -def test_concrete_clear_tensor_integer_to_mlir_type(is_signed, shape): - """Test conversion of ClearTensor into MLIR""" - value = ClearTensor(Integer(5, is_signed=is_signed), shape) - converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) - with converter.context, Location.unknown(): - tensor_mlir = converter.common_value_to_mlir_type(value) - if is_signed: - element_type = IntegerType.get_signed(5) - else: - element_type = IntegerType.get_signless(5) - if shape is None: - expected_type = UnrankedTensorType.get(element_type) - else: - expected_type = RankedTensorType.get(shape, element_type) - assert tensor_mlir == expected_type - - -@pytest.mark.parametrize( - "shape", - [ - (5,), - (5, 8), - (-1, 5), - ], -) -def test_concrete_encrypted_tensor_integer_to_mlir_type(shape): - """Test conversion of EncryptedTensor into MLIR""" - value = EncryptedTensor(Integer(6, is_signed=False), shape) - converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) - with converter.context, Location.unknown(): - tensor_mlir = converter.common_value_to_mlir_type(value) - element_type = hlfhe.EncryptedIntegerType.get(converter.context, 6) - if shape is None: - expected_type = UnrankedTensorType.get(element_type) - else: - expected_type = RankedTensorType.get(shape, element_type) - assert tensor_mlir == expected_type - - -def test_failing_concrete_to_mlir_type(): - """Test failing conversion of an unsupported type into MLIR""" - value = "random" - converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) - with pytest.raises(TypeError, match=r"can't convert value of type .* to MLIR type"): - converter.common_value_to_mlir_type(value) - - -# pylint: enable=no-name-in-module,no-member diff --git a/tests/common/mlir/test_node_converter.py 
b/tests/common/mlir/test_node_converter.py new file mode 100644 index 000000000..366c1ec05 --- /dev/null +++ b/tests/common/mlir/test_node_converter.py @@ -0,0 +1,85 @@ +"""Test file for intermediate node to MLIR converter.""" + +import random + +import numpy +import pytest + +from concrete.common.data_types import UnsignedInteger +from concrete.common.values import EncryptedScalar, EncryptedTensor +from concrete.numpy import compile_numpy_function + + +@pytest.mark.parametrize( + "function_to_compile,parameters,inputset,expected_error_type,expected_error_message", + [ + pytest.param( + lambda x, y: x * y, + { + "x": EncryptedScalar(UnsignedInteger(3)), + "y": EncryptedScalar(UnsignedInteger(3)), + }, + [(random.randint(0, 7), random.randint(0, 7)) for _ in range(10)] + [(7, 7)], + NotImplementedError, + "Multiplication " + "between " + "EncryptedScalar " + "and " + "EncryptedScalar " + "cannot be converted to MLIR yet", + ), + pytest.param( + lambda x, y: x - y, + { + "x": EncryptedScalar(UnsignedInteger(3)), + "y": EncryptedScalar(UnsignedInteger(3)), + }, + [(random.randint(5, 7), random.randint(0, 5)) for _ in range(10)], + NotImplementedError, + "Subtraction " + "of " + "EncryptedScalar " + "from " + "EncryptedScalar " + "cannot be converted to MLIR yet", + ), + pytest.param( + lambda x, y: numpy.dot(x, y), + { + "x": EncryptedTensor(UnsignedInteger(3), shape=(2,)), + "y": EncryptedTensor(UnsignedInteger(3), shape=(2,)), + }, + [ + ( + numpy.random.randint(0, 2 ** 3, size=(2,)), + numpy.random.randint(0, 2 ** 3, size=(2,)), + ) + for _ in range(10) + ] + + [(numpy.array([7, 7]), numpy.array([7, 7]))], + NotImplementedError, + "Dot product " + "between " + "EncryptedTensor " + "and " + "EncryptedTensor " + "cannot be converted to MLIR yet", + ), + ], +) +def test_fail_node_conversion( + function_to_compile, + parameters, + inputset, + expected_error_type, + expected_error_message, + default_compilation_configuration, +): + """Test function for failed 
intermediate node conversion.""" + + with pytest.raises(expected_error_type) as excinfo: + compile_numpy_function( + function_to_compile, parameters, inputset, default_compilation_configuration + ) + + assert str(excinfo.value) == expected_error_message diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 6dc1fe5b7..40be43ddb 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -519,6 +519,8 @@ def test_compile_function_multiple_outputs( pytest.param(lambda x, y: 100 - y + x, ((0, 20), (0, 20)), ["x", "y"]), pytest.param(lambda x, y: 50 - y * 2 + x, ((0, 20), (0, 20)), ["x", "y"]), pytest.param(lambda x: -x + 50, ((0, 20),), ["x"]), + pytest.param(lambda x: numpy.dot(x, 2), ((0, 20),), ["x"]), + pytest.param(lambda x: numpy.dot(2, x), ((0, 20),), ["x"]), ], ) def test_compile_and_run_correctness( @@ -548,6 +550,11 @@ def test_compile_and_run_correctness( @pytest.mark.parametrize( "function,parameters,inputset,test_input,expected_output", [ + # TODO: find a way to support this case + # https://github.com/zama-ai/concretefhe-internal/issues/837 + # + # the problem is that compiler doesn't support combining scalars and tensors + # but they do support broadcasting, so scalars should be converted to (1,) shaped tensors pytest.param( lambda x: x + 1, { @@ -566,6 +573,7 @@ def test_compile_and_run_correctness( [7, 2], [3, 6], ], + marks=pytest.mark.xfail(strict=True), ), pytest.param( lambda x: x + numpy.array([[1, 0], [2, 0], [3, 1]], dtype=numpy.uint32), @@ -590,9 +598,7 @@ def test_compile_and_run_correctness( # https://github.com/zama-ai/concretefhe-internal/issues/837 # # the problem is that compiler doesn't support combining scalars and tensors - # but they do support broadcasting, so scalars can be converted to (1,) shaped tensors - # this is easy with known constants but weird with variable things such as another input - # there is tensor.from_elements but I coudn't figure out how to use it in the python API + # but 
they do support broadcasting, so scalars should be converted to (1,) shaped tensors pytest.param( lambda x, y: x + y, { @@ -619,7 +625,7 @@ def test_compile_and_run_correctness( [8, 3], [4, 7], ], - marks=pytest.mark.xfail(), + marks=pytest.mark.xfail(strict=True), ), pytest.param( lambda x, y: x + y, @@ -652,6 +658,11 @@ def test_compile_and_run_correctness( [5, 9], ], ), + # TODO: find a way to support this case + # https://github.com/zama-ai/concretefhe-internal/issues/837 + # + # the problem is that compiler doesn't support combining scalars and tensors + # but they do support broadcasting, so scalars should be converted to (1,) shaped tensors pytest.param( lambda x: 100 - x, { @@ -670,6 +681,7 @@ def test_compile_and_run_correctness( [94, 99], [98, 95], ], + marks=pytest.mark.xfail(strict=True), ), pytest.param( lambda x: numpy.array([[10, 15], [20, 15], [10, 30]], dtype=numpy.uint32) - x, @@ -690,6 +702,11 @@ def test_compile_and_run_correctness( [8, 25], ], ), + # TODO: find a way to support this case + # https://github.com/zama-ai/concretefhe-internal/issues/837 + # + # the problem is that compiler doesn't support combining scalars and tensors + # but they do support broadcasting, so scalars should be converted to (1,) shaped tensors pytest.param( lambda x: x * 2, { @@ -708,6 +725,7 @@ def test_compile_and_run_correctness( [12, 2], [4, 10], ], + marks=pytest.mark.xfail(strict=True), ), pytest.param( lambda x: x * numpy.array([[1, 2], [2, 1], [3, 1]], dtype=numpy.uint32), @@ -747,6 +765,36 @@ def test_compile_and_run_correctness( [0, 2], ], ), + # TODO: find a way to support this case + # https://github.com/zama-ai/concretefhe-internal/issues/837 + # + # the problem is that compiler doesn't support combining scalars and tensors + # but they do support broadcasting, so scalars should be converted to (1,) shaped tensors + pytest.param( + lambda x: numpy.dot(x, 2), + { + "x": EncryptedTensor(UnsignedInteger(3), shape=(3,)), + }, + [(numpy.random.randint(0, 2 ** 
3, size=(3,)),) for _ in range(10)], + ([2, 7, 1],), + [4, 14, 2], + marks=pytest.mark.xfail(strict=True), + ), + # TODO: find a way to support this case + # https://github.com/zama-ai/concretefhe-internal/issues/837 + # + # the problem is that compiler doesn't support combining scalars and tensors + # but they do support broadcasting, so scalars should be converted to (1,) shaped tensors + pytest.param( + lambda x: numpy.dot(2, x), + { + "x": EncryptedTensor(UnsignedInteger(3), shape=(3,)), + }, + [(numpy.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)], + ([2, 7, 1],), + [4, 14, 2], + marks=pytest.mark.xfail(strict=True), + ), ], ) def test_compile_and_run_tensor_correctness( @@ -874,7 +922,7 @@ def test_compile_and_run_constant_dot_correctness( default_compilation_configuration, ) right_circuit = compile_numpy_function( - left, + right, {"x": EncryptedTensor(Integer(64, False), shape)}, inputset, default_compilation_configuration,