refactor(mlir): re-write mlir conversion

This commit is contained in:
Umut
2021-11-11 17:32:38 +03:00
parent 6fec590e65
commit 239f66eb46
15 changed files with 736 additions and 1114 deletions

View File

@@ -100,23 +100,6 @@ def value_is_unsigned_integer(value_to_check: BaseValue) -> bool:
)
def value_is_encrypted_unsigned_integer(value_to_check: BaseValue) -> bool:
    """Check that a value is encrypted and is of type unsigned Integer.

    Args:
        value_to_check (BaseValue): The value to check

    Returns:
        bool: True if the passed value_to_check is encrypted and is of type unsigned Integer
    """
    # guard clauses instead of one compound boolean expression
    if not value_to_check.is_encrypted:
        return False
    if not isinstance(value_to_check.dtype, INTEGER_TYPES):
        return False
    return not cast(Integer, value_to_check.dtype).is_signed
def value_is_encrypted_tensor_integer(value_to_check: BaseValue) -> bool:
"""Check that a value is an encrypted TensorValue of type Integer.
@@ -129,22 +112,6 @@ def value_is_encrypted_tensor_integer(value_to_check: BaseValue) -> bool:
return value_is_tensor_integer(value_to_check) and value_to_check.is_encrypted
def value_is_encrypted_tensor_unsigned_integer(value_to_check: BaseValue) -> bool:
    """Check that a value is an encrypted TensorValue of type unsigned Integer.

    Args:
        value_to_check (BaseValue): The value to check

    Returns:
        bool: True if the passed value_to_check is an encrypted TensorValue of type Integer and
            unsigned
    """
    if not value_is_encrypted_tensor_integer(value_to_check):
        return False
    return not cast(Integer, value_to_check.dtype).is_signed
def value_is_clear_tensor_integer(value_to_check: BaseValue) -> bool:
"""Check that a value is a clear TensorValue of type Integer.

View File

@@ -1,5 +1,3 @@
"""MLIR conversion submodule."""
from .converters import V0_OPSET_CONVERSION_FUNCTIONS
from .mlir_converter import MLIRConverter
"""MLIR conversion module."""
__all__ = ["MLIRConverter", "V0_OPSET_CONVERSION_FUNCTIONS"]
from .graph_converter import OPGraphConverter

View File

@@ -0,0 +1,64 @@
"""Helpers for MLIR conversion functionality."""
# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints
# pylint: disable=no-name-in-module
from typing import Optional
from mlir.ir import Context, IntegerType, RankedTensorType, Type
from zamalang.dialects.hlfhe import EncryptedIntegerType
from ..data_types import Integer
from ..values import BaseValue, TensorValue
# pylint: enable=no-name-in-module
def integer_to_mlir_type(ctx: Context, integer: Integer, is_encrypted: bool) -> Type:
    """Convert an integer to its corresponding MLIR type.

    Note: despite the original ``Optional[Type]`` annotation, this function never returns
    ``None`` — both branches produce a type. NOTE(review): for unsupported bit widths the
    underlying binding presumably raises ``ValueError`` (callers catch ``ValueError``) —
    confirm against the zamalang/MLIR python bindings.

    Args:
        ctx (Context): the MLIR context to perform the conversion
        integer (Integer): the integer to convert
        is_encrypted (bool): whether the integer is encrypted or not

    Returns:
        Type: the MLIR type corresponding to the given integer and encryption status
    """
    bit_width = integer.bit_width
    if is_encrypted:
        return EncryptedIntegerType.get(ctx, bit_width)
    # clear integers are represented as signless MLIR integers
    return IntegerType.get_signless(bit_width)
def value_to_mlir_type(ctx: Context, value: BaseValue) -> Type:
    """Convert a value to its corresponding MLIR type.

    Args:
        ctx (Context): the MLIR context to perform the conversion
        value (BaseValue): the value to convert

    Returns:
        Type: the MLIR type corresponding to given value

    Raises:
        TypeError: if the value is not supported for MLIR conversion
    """
    dtype = value.dtype
    if isinstance(dtype, Integer):
        try:
            result = integer_to_mlir_type(ctx, dtype, value.is_encrypted)
            # non-scalar tensors are wrapped into a ranked tensor type
            if isinstance(value, TensorValue) and not value.is_scalar:
                result = RankedTensorType.get(value.shape, result)
            return result
        except ValueError:
            # fall through to the unified error below
            pass
    raise TypeError(f"{value} is not supported for MLIR conversion")

View File

@@ -1,383 +0,0 @@
"""Converter functions from the common IR to MLIR.
Converter functions all have the same signature `converter(node, preds, ir_to_mlir_node, ctx)`
- `node`: IntermediateNode to be converted
- `preds`: List of predecessors of `node` ordered as operands
- `ir_to_mlir_node`: Dict mapping intermediate nodes to MLIR nodes or values
- `ctx`: MLIR context
"""
from typing import cast
# pylint: disable=no-name-in-module,no-member
import numpy
from mlir.dialects import arith as arith_dialect
from mlir.ir import Attribute, DenseElementsAttr, IntegerAttr, IntegerType, RankedTensorType
from zamalang.dialects import hlfhe, hlfhelinalg
from ..data_types.dtypes_helpers import (
value_is_clear_scalar_integer,
value_is_clear_tensor_integer,
value_is_encrypted_tensor_integer,
value_is_encrypted_unsigned_integer,
value_is_scalar_integer,
value_is_tensor_integer,
)
from ..data_types.integers import Integer
from ..debugging.custom_assert import assert_true
from ..representation.intermediate import Add, Constant, Dot, GenericFunction, Mul, Sub
from ..values import TensorValue
def _convert_scalar_constant_op_to_single_element_tensor_constant_op(operation):
    """Convert a scalar constant operation result to a dense tensor constant operation result.

    see https://github.com/zama-ai/concretefhe-internal/issues/837.

    This is a temporary workaround before the compiler natively supports
    `tensor + scalar`, `tensor - scalar`, `tensor * scalar` operations.

    Example input = `%c3_i4 = arith.constant 3 : i4`
    Example output = `%cst = arith.constant dense<3> : tensor<1xi4>`

    Args:
        operation: operation to convert

    Returns:
        the converted operation
    """
    op_text = str(operation)
    # extract the constant literal sitting between "arith.constant " and " : <type>"
    literal_start = op_text.find("arith.constant") + len("arith.constant") + 1
    literal_end = op_text.find(f": {str(operation.type)}") - 1
    literal = op_text[literal_start:literal_end]

    tensor_type = RankedTensorType.get((1,), operation.type)
    dense_attr = Attribute.parse(f"dense<{literal}> : tensor<1x{str(operation.type)}>")
    return arith_dialect.ConstantOp(tensor_type, dense_attr).result
def add(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
    """Convert an addition intermediate node."""
    assert_true(len(node.inputs) == 2, "addition should have two inputs")
    assert_true(len(node.outputs) == 1, "addition should have a single output")

    uses_tensor = False
    for input_ in node.inputs:
        if value_is_tensor_integer(input_):
            uses_tensor = True
        elif not value_is_scalar_integer(input_):
            raise TypeError(
                f"Don't support addition between {str(node.inputs[0])} and {str(node.inputs[1])}"
            )

    all_encrypted = not (node.inputs[0].is_clear or node.inputs[1].is_clear)
    # the underlying (eint, int) operation expects the encrypted operand first
    ordered_preds = preds[::-1] if node.inputs[0].is_clear else preds

    if uses_tensor:
        if all_encrypted:
            return _linalg_add_eint_eint(node, ordered_preds, ir_to_mlir_node, ctx)
        return _linalg_add_eint_int(node, ordered_preds, ir_to_mlir_node, ctx)
    if all_encrypted:
        return _add_eint_eint(node, ordered_preds, ir_to_mlir_node, ctx)
    return _add_eint_int(node, ordered_preds, ir_to_mlir_node, ctx)
def _add_eint_int(node, preds, ir_to_mlir_node, ctx):
    """Convert an addition intermediate node with (eint, int)."""
    lhs, rhs = (ir_to_mlir_node[pred] for pred in preds)
    result_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width)
    return hlfhe.AddEintIntOp(result_type, lhs, rhs).result
def _add_eint_eint(node, preds, ir_to_mlir_node, ctx):
    """Convert an addition intermediate node with (eint, eint)."""
    lhs, rhs = (ir_to_mlir_node[pred] for pred in preds)
    result_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width)
    return hlfhe.AddEintOp(result_type, lhs, rhs).result
def _linalg_add_eint_int(node, preds, ir_to_mlir_node, ctx):
    """Convert an addition intermediate tensor node with (eint, int)."""
    lhs, rhs = (ir_to_mlir_node[pred] for pred in preds)
    if not str(rhs.type).startswith("tensor"):
        # the linalg op only accepts tensors, so promote the scalar constant to tensor<1x...>
        rhs = _convert_scalar_constant_op_to_single_element_tensor_constant_op(rhs)
    element_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width)
    tensor_type = RankedTensorType.get(node.outputs[0].shape, element_type)
    return hlfhelinalg.AddEintIntOp(tensor_type, lhs, rhs).result
def _linalg_add_eint_eint(node, preds, ir_to_mlir_node, ctx):
    """Convert an addition intermediate tensor node with (eint, eint)."""
    lhs, rhs = (ir_to_mlir_node[pred] for pred in preds)
    element_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width)
    tensor_type = RankedTensorType.get(node.outputs[0].shape, element_type)
    return hlfhelinalg.AddEintOp(tensor_type, lhs, rhs).result
def sub(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
    """Convert a subtraction intermediate node."""
    assert_true(len(node.inputs) == 2, "subtraction should have two inputs")
    assert_true(len(node.outputs) == 1, "subtraction should have a single output")

    lhs_input, rhs_input = node.inputs
    uses_tensor = False

    # lhs must be a clear integer (tensor or scalar)
    if value_is_clear_tensor_integer(lhs_input):
        uses_tensor = True
        lhs_ok = True
    else:
        lhs_ok = value_is_clear_scalar_integer(lhs_input)

    # rhs must be an integer (tensor or scalar)
    if value_is_tensor_integer(rhs_input):
        uses_tensor = True
        rhs_ok = True
    else:
        rhs_ok = value_is_scalar_integer(rhs_input)

    if not (lhs_ok and rhs_ok):
        raise TypeError(
            f"Don't support subtraction between {str(node.inputs[0])} and {str(node.inputs[1])}"
        )

    if uses_tensor:
        return _linalg_sub_int_eint(node, preds, ir_to_mlir_node, ctx)
    return _sub_int_eint(node, preds, ir_to_mlir_node, ctx)
def _sub_int_eint(node, preds, ir_to_mlir_node, ctx):
    """Convert a subtraction intermediate node with (int, eint)."""
    lhs, rhs = (ir_to_mlir_node[pred] for pred in preds)
    result_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width)
    return hlfhe.SubIntEintOp(result_type, lhs, rhs).result
def _linalg_sub_int_eint(node, preds, ir_to_mlir_node, ctx):
    """Convert a subtraction intermediate tensor node with (int, eint)."""
    lhs, rhs = (ir_to_mlir_node[pred] for pred in preds)
    if not str(lhs.type).startswith("tensor"):
        # the linalg op only accepts tensors, so promote the scalar constant to tensor<1x...>
        lhs = _convert_scalar_constant_op_to_single_element_tensor_constant_op(lhs)
    element_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width)
    tensor_type = RankedTensorType.get(node.outputs[0].shape, element_type)
    return hlfhelinalg.SubIntEintOp(tensor_type, lhs, rhs).result
def mul(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
    """Convert a multiplication intermediate node."""
    assert_true(len(node.inputs) == 2, "multiplication should have two inputs")
    assert_true(len(node.outputs) == 1, "multiplication should have a single output")

    uses_tensor = False
    for input_ in node.inputs:
        if value_is_tensor_integer(input_):
            uses_tensor = True
        elif not value_is_scalar_integer(input_):
            raise TypeError(
                f"Don't support multiplication between {str(node.inputs[0])} and {str(node.inputs[1])}"
            )

    # the underlying (eint, int) operation expects the encrypted operand first
    ordered_preds = preds[::-1] if node.inputs[0].is_clear else preds

    if uses_tensor:
        return _linalg_mul_eint_int(node, ordered_preds, ir_to_mlir_node, ctx)
    return _mul_eint_int(node, ordered_preds, ir_to_mlir_node, ctx)
def _mul_eint_int(node, preds, ir_to_mlir_node, ctx):
    """Convert a multiplication intermediate node with (eint, int)."""
    lhs, rhs = (ir_to_mlir_node[pred] for pred in preds)
    result_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width)
    return hlfhe.MulEintIntOp(result_type, lhs, rhs).result
def _linalg_mul_eint_int(node, preds, ir_to_mlir_node, ctx):
    """Convert a multiplication intermediate tensor node with (eint, int).

    (The original docstring said "subtraction ... (int, eint)" — a copy-paste error;
    this function emits `hlfhelinalg.MulEintIntOp`.)
    """
    lhs_node, rhs_node = preds
    lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
    if not str(rhs.type).startswith("tensor"):
        # the linalg op only accepts tensors, so promote the scalar constant to tensor<1x...>
        rhs = _convert_scalar_constant_op_to_single_element_tensor_constant_op(rhs)
    int_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width)
    vec_type = RankedTensorType.get(node.outputs[0].shape, int_type)
    return hlfhelinalg.MulEintIntOp(vec_type, lhs, rhs).result
def constant(node, _preds, _ir_to_mlir_node, ctx, _additional_conversion_info=None):
    """Convert a constant input."""
    value = node.outputs[0]

    if value_is_clear_scalar_integer(value):
        dtype = cast(Integer, cast(TensorValue, value).dtype)
        int_type = IntegerType.get_signless(dtype.bit_width, context=ctx)
        attr = IntegerAttr.get(int_type, node.constant_data)
        return arith_dialect.ConstantOp(int_type, attr).result

    if value_is_clear_tensor_integer(value):
        tensor_value = cast(TensorValue, value)
        dtype = cast(Integer, tensor_value.dtype)
        int_type = IntegerType.get_signless(dtype.bit_width, context=ctx)
        vec_type = RankedTensorType.get(tensor_value.shape, int_type)
        # `Attribute.parse` is used because of limitations in the MLIR module provided by
        # LLVM: `DenseElementsAttr` would be the natural choice, but the python API cannot
        # create dense attributes with custom bit widths (e.g., uint5), so we couldn't use
        # it and instead let the underlying library parse the attribute by itself
        value_attr = Attribute.parse(f"dense<{str(node.constant_data.tolist())}> : {vec_type}")
        return arith_dialect.ConstantOp(vec_type, value_attr).result

    raise TypeError(f"Don't support {value} constants")
def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info):
    """Convert a GenericFunction intermediate node."""
    variable_input_indices = [
        idx for idx, pred in enumerate(preds) if not isinstance(pred, Constant)
    ]
    variable_input_count = len(variable_input_indices)
    assert_true(
        variable_input_count == 1,
        f"LUT should have a single variable input (got {variable_input_count})",
    )

    variable_input_idx = variable_input_indices[0]
    variable_input_value = node.inputs[variable_input_idx]
    assert_true(len(node.outputs) == 1, "LUT should have a single output")

    if not value_is_encrypted_unsigned_integer(variable_input_value):
        raise TypeError(
            f"Only support LUT with encrypted unsigned integers inputs "
            f"(but {variable_input_value} is provided)"
        )
    if not value_is_encrypted_unsigned_integer(node.outputs[0]):
        raise TypeError(
            f"Only support LUT with encrypted unsigned integers outputs "
            f"(but {node.outputs[0]} is provided)"
        )

    x = ir_to_mlir_node[preds[variable_input_idx]]
    tables = additional_conversion_info["tables"][node]

    # TODO: #559 adapt the code to support multi TLUs
    # This cannot be reached today as compilation fails if the intermediate values are not all
    # scalars
    if len(tables) > 1:  # pragma: no cover
        raise RuntimeError(
            "MLIR conversion currently does not support multiple test vectors for LUT"
        )

    table = tables[0][0]
    out_dtype = cast(Integer, node.outputs[0].dtype)

    # materialize the lookup table as a dense 64-bit constant tensor
    dense_elem = DenseElementsAttr.get(numpy.array(table, dtype=numpy.uint64), context=ctx)
    tensor_lut = arith_dialect.ConstantOp(
        RankedTensorType.get([len(table)], IntegerType.get_signless(64, context=ctx)),
        dense_elem,
    ).result

    int_type = hlfhe.EncryptedIntegerType.get(ctx, out_dtype.bit_width)
    if value_is_encrypted_tensor_integer(node.inputs[0]):
        vec_type = RankedTensorType.get(node.outputs[0].shape, int_type)
        return hlfhelinalg.ApplyLookupTableEintOp(vec_type, x, tensor_lut).result
    return hlfhe.ApplyLookupTableEintOp(int_type, x, tensor_lut).result
def dot(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
    """Convert a dot intermediate node."""
    assert_true(len(node.inputs) == 2, "Dot should have two inputs")
    assert_true(len(node.outputs) == 1, "Dot should have a single output")

    # exactly one operand must be an encrypted tensor and the other a clear tensor
    encrypted_then_clear = value_is_encrypted_tensor_integer(
        node.inputs[0]
    ) and value_is_clear_tensor_integer(node.inputs[1])
    clear_then_encrypted = value_is_encrypted_tensor_integer(
        node.inputs[1]
    ) and value_is_clear_tensor_integer(node.inputs[0])

    if not (encrypted_then_clear or clear_then_encrypted):
        raise TypeError(
            f"Don't support dot between {str(node.inputs[0])} and {str(node.inputs[1])}"
        )

    lhs_node, rhs_node = preds
    # need to flip as underlying operation need encrypted first
    if value_is_clear_tensor_integer(node.inputs[0]):
        lhs_node, rhs_node = rhs_node, lhs_node
    lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]

    result_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width)
    return hlfhelinalg.Dot(result_type, lhs, rhs).result
# Mapping from intermediate-representation node types to their MLIR converter functions.
V0_OPSET_CONVERSION_FUNCTIONS = {
    Add: add,
    Sub: sub,
    Mul: mul,
    Constant: constant,
    GenericFunction: apply_lut,
    Dot: dot,
}
# pylint: enable=no-name-in-module,no-member

View File

@@ -0,0 +1,79 @@
"""Module that provides OPGraph conversion functionality."""
# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints
# pylint: disable=no-name-in-module
from abc import ABC, abstractmethod
from typing import Any, Dict
import networkx as nx
import zamalang
from mlir.dialects import builtin
from mlir.ir import Context, InsertionPoint, Location, Module
from ..operator_graph import OPGraph
from ..representation.intermediate import Input
from .conversion_helpers import value_to_mlir_type
from .node_converter import IntermediateNodeConverter
# pylint: enable=no-name-in-module
class OPGraphConverter(ABC):
    """Converter of OPGraph to MLIR."""

    def convert(self, opgraph: OPGraph) -> str:
        """Convert an operation graph to its corresponding MLIR representation.

        Args:
            opgraph (OPGraph): the operation graph to be converted

        Returns:
            str: textual MLIR representation corresponding to opgraph
        """
        # subclass-provided info (e.g. lookup tables) needed by individual node conversions
        additional_conversion_info = self._generate_additional_info_dict(opgraph)

        with Context() as ctx, Location.unknown():
            zamalang.register_dialects(ctx)

            module = Module.create()
            with InsertionPoint(module.body):
                # one MLIR function parameter per ordered graph input
                parameters = [
                    value_to_mlir_type(ctx, input_node.outputs[0])
                    for input_node in opgraph.get_ordered_inputs()
                ]

                @builtin.FuncOp.from_py_func(*parameters)
                def main(*arg):
                    # maps intermediate nodes to their already-converted MLIR values
                    ir_to_mlir = {}
                    for arg_num, node in opgraph.input_nodes.items():
                        ir_to_mlir[node] = arg[arg_num]

                    # topological order guarantees predecessors are converted before use
                    for node in nx.topological_sort(opgraph.graph):
                        if isinstance(node, Input):
                            continue

                        preds = [ir_to_mlir[pred] for pred in opgraph.get_ordered_preds(node)]
                        node_converter = IntermediateNodeConverter(ctx, opgraph, node, preds)
                        ir_to_mlir[node] = node_converter.convert(additional_conversion_info)

                    results = (
                        ir_to_mlir[output_node] for output_node in opgraph.get_ordered_outputs()
                    )
                    return results

        return str(module)

    @staticmethod
    @abstractmethod
    def _generate_additional_info_dict(opgraph: OPGraph) -> Dict[str, Any]:
        """Generate additional conversion info dict for the MLIR converter.

        Args:
            opgraph (OPGraph): the operation graph from which the additional info will be generated

        Returns:
            Dict[str, Any]: dict of additional conversion info
        """

View File

@@ -1,194 +0,0 @@
"""File containing code to convert a DAG containing ir nodes to the compiler opset."""
# pylint: disable=no-name-in-module,no-member
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple, cast
import networkx as nx
import zamalang
from mlir.dialects import builtin
from mlir.ir import Context, InsertionPoint, IntegerType, Location, Module, RankedTensorType
from mlir.ir import Type as MLIRType
from zamalang.dialects import hlfhe
from .. import values
from ..data_types import Integer
from ..data_types.dtypes_helpers import (
value_is_clear_scalar_integer,
value_is_clear_tensor_integer,
value_is_encrypted_scalar_unsigned_integer,
value_is_encrypted_tensor_unsigned_integer,
)
from ..debugging.custom_assert import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import Input
class MLIRConverter(ABC):
    """Converter of the common IR to MLIR."""

    def __init__(self, conversion_functions: dict) -> None:
        """Instantiate a converter with a given set of converters.

        Args:
            conversion_functions (dict): mapping common IR nodes to functions that generate MLIR.
                every function should have 4 arguments:
                - node (IntermediateNode): the node itself to be converted
                - operands (IntermediateNode): predecessors of node ordered as operands
                - ir_to_mlir_node (dict): mapping between IntermediateNode and their equivalent
                    MLIR values
                - context (mlir.Context): the MLIR context being used for the conversion
        """
        self.conversion_functions = conversion_functions
        self._init_context()

    def _init_context(self):
        # create a fresh MLIR context and register the zamalang dialects on it
        self.context = Context()
        zamalang.register_dialects(self.context)

    def _get_tensor_type(
        self,
        bit_width: int,
        is_encrypted: bool,
        is_signed: bool,
        shape: Tuple[int, ...],
    ) -> MLIRType:
        """Get the MLIRType for a tensor element given its properties.

        Args:
            bit_width (int): number of bits used for the scalar
            is_encrypted (bool): is the scalar encrypted or not
            is_signed (bool): is the scalar signed or not
            shape (Tuple[int, ...]): shape of the tensor

        Returns:
            MLIRType: corresponding MLIR type
        """
        element_type = self._get_scalar_integer_type(bit_width, is_encrypted, is_signed)
        return RankedTensorType.get(shape, element_type)

    def _get_scalar_integer_type(
        self, bit_width: int, is_encrypted: bool, is_signed: bool
    ) -> MLIRType:
        """Get the MLIRType for a scalar element given its properties.

        Args:
            bit_width (int): number of bits used for the scalar
            is_encrypted (bool): is the scalar encrypted or not
            is_signed (bool): is the scalar signed or not

        Returns:
            MLIRType: corresponding MLIR type
        """
        if is_encrypted and not is_signed:
            return hlfhe.EncryptedIntegerType.get(self.context, bit_width)
        if is_signed and not is_encrypted:  # clear signed
            return IntegerType.get_signed(bit_width)
        # should be clear unsigned at this point
        assert_true(not is_signed and not is_encrypted)
        # unsigned integer are considered signless in the compiler
        return IntegerType.get_signless(bit_width)

    @staticmethod
    @abstractmethod
    def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]:
        """Generate the additional_conversion_info dict for the MLIR converter.

        Args:
            op_graph (OPGraph): the OPGraph for which we need the conversion infos.

        Returns:
            Dict[str, Any]: The dict with the additional conversion infos.
        """

    def common_value_to_mlir_type(self, value: values.BaseValue) -> MLIRType:
        """Convert a common value to its corresponding MLIR Type.

        Args:
            value: value to convert

        Returns:
            corresponding MLIR type

        Raises:
            TypeError: if the value kind is not supported
        """
        if value_is_encrypted_scalar_unsigned_integer(value):
            return self._get_scalar_integer_type(cast(Integer, value.dtype).bit_width, True, False)
        if value_is_clear_scalar_integer(value):
            dtype = cast(Integer, value.dtype)
            return self._get_scalar_integer_type(
                dtype.bit_width, is_encrypted=False, is_signed=dtype.is_signed
            )
        if value_is_encrypted_tensor_unsigned_integer(value):
            dtype = cast(Integer, value.dtype)
            return self._get_tensor_type(
                dtype.bit_width,
                is_encrypted=True,
                is_signed=False,
                shape=cast(values.TensorValue, value).shape,
            )
        if value_is_clear_tensor_integer(value):
            dtype = cast(Integer, value.dtype)
            return self._get_tensor_type(
                dtype.bit_width,
                is_encrypted=False,
                is_signed=dtype.is_signed,
                shape=cast(values.TensorValue, value).shape,
            )
        raise TypeError(f"can't convert value of type {type(value)} to MLIR type")

    def convert(self, op_graph: OPGraph) -> str:
        """Convert the graph of IntermediateNode to an MLIR textual representation.

        Args:
            op_graph (OPGraph): graph of IntermediateNode to be converted

        Raises:
            NotImplementedError: raised if an unknown node type is encountered.

        Returns:
            str: textual MLIR representation
        """
        additional_conversion_info = self._generate_additional_info_dict(op_graph)
        with self.context, Location.unknown():
            module = Module.create()
            # collect inputs
            with InsertionPoint(module.body):
                func_types = [
                    self.common_value_to_mlir_type(input_node.inputs[0])
                    for input_node in op_graph.get_ordered_inputs()
                ]

                @builtin.FuncOp.from_py_func(*func_types)
                def main(*arg):
                    # maps intermediate nodes to their already-converted MLIR values
                    ir_to_mlir_node = {}
                    for arg_num, node in op_graph.input_nodes.items():
                        ir_to_mlir_node[node] = arg[arg_num]
                    # topological order guarantees predecessors are converted before use
                    for node in nx.topological_sort(op_graph.graph):
                        if isinstance(node, Input):
                            continue
                        mlir_op = self.conversion_functions.get(type(node), None)
                        if mlir_op is None:  # pragma: no cover
                            raise NotImplementedError(
                                f"we don't yet support conversion to MLIR of computations using"
                                f"{type(node)}"
                            )
                        preds = op_graph.get_ordered_preds(node)
                        # convert to mlir
                        result = mlir_op(
                            node,
                            preds,
                            ir_to_mlir_node,
                            self.context,
                            additional_conversion_info,
                        )
                        ir_to_mlir_node[node] = result
                    results = (
                        ir_to_mlir_node[output_node]
                        for output_node in op_graph.get_ordered_outputs()
                    )
                    return results

        return module.__str__()
# pylint: enable=no-name-in-module,no-member

View File

@@ -0,0 +1,325 @@
"""Module that provides IntermediateNode conversion functionality."""
# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints
# pylint: disable=no-name-in-module
from typing import Any, Dict, List, cast
import numpy
from mlir.dialects import arith
from mlir.ir import (
Attribute,
Context,
DenseElementsAttr,
IntegerAttr,
IntegerType,
OpResult,
RankedTensorType,
)
from zamalang.dialects import hlfhe, hlfhelinalg
from ..data_types import Integer
from ..debugging import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import (
Add,
Constant,
Dot,
GenericFunction,
IntermediateNode,
Mul,
Sub,
)
from ..values import TensorValue
from .conversion_helpers import value_to_mlir_type
# pylint: enable=no-name-in-module
class IntermediateNodeConverter:
"""Converter of IntermediateNode to MLIR."""
ctx: Context
opgraph: OPGraph
node: IntermediateNode
preds: List[OpResult]
all_of_the_inputs_are_encrypted: bool
all_of_the_inputs_are_tensors: bool
one_of_the_inputs_is_a_tensor: bool
def __init__(
self, ctx: Context, opgraph: OPGraph, node: IntermediateNode, preds: List[OpResult]
):
self.ctx = ctx
self.opgraph = opgraph
self.node = node
self.preds = preds
self.all_of_the_inputs_are_encrypted = True
self.all_of_the_inputs_are_tensors = True
self.one_of_the_inputs_is_a_tensor = False
for inp in node.inputs:
if inp.is_clear:
self.all_of_the_inputs_are_encrypted = False
if isinstance(inp, TensorValue):
if inp.is_scalar:
self.all_of_the_inputs_are_tensors = False
else:
self.one_of_the_inputs_is_a_tensor = True
else: # pragma: no cover
# this branch is not covered as there are only TensorValues for now
self.all_of_the_inputs_are_tensors = False
def convert(self, additional_conversion_info: Dict[str, Any]) -> OpResult:
"""Convert an intermediate node to its corresponding MLIR representation.
Args:
additional_conversion_info (Dict[str, Any]):
external info that the converted node might need
Returns:
str: textual MLIR representation corresponding to self.node
"""
if isinstance(self.node, Add):
return self.convert_add()
if isinstance(self.node, Constant):
return self.convert_constant()
if isinstance(self.node, Dot):
return self.convert_dot()
if isinstance(self.node, GenericFunction):
return self.convert_generic_function(additional_conversion_info)
if isinstance(self.node, Mul):
return self.convert_mul()
if isinstance(self.node, Sub):
return self.convert_sub()
# this statement is not covered as unsupported opeations fail on check mlir compatibility
raise NotImplementedError(
f"{type(self.node)} nodes cannot be converted to MLIR yet"
) # pragma: no cover
def convert_add(self) -> OpResult:
"""Convert an Add node to its corresponding MLIR representation.
Returns:
str: textual MLIR representation corresponding to self.node
"""
assert_true(len(self.node.inputs) == 2)
assert_true(len(self.node.outputs) == 1)
resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
preds = self.preds
if self.all_of_the_inputs_are_encrypted:
if self.one_of_the_inputs_is_a_tensor:
result = hlfhelinalg.AddEintOp(resulting_type, *preds).result
else:
result = hlfhe.AddEintOp(resulting_type, *preds).result
else:
if self.node.inputs[0].is_clear: # pragma: no cover
# this branch is not covered as it's impossible to get into due to how tracing works
# however, it doesn't hurt to keep it as an extra measure
preds = preds[::-1]
if self.one_of_the_inputs_is_a_tensor:
result = hlfhelinalg.AddEintIntOp(resulting_type, *preds).result
else:
result = hlfhe.AddEintIntOp(resulting_type, *preds).result
return result
def convert_constant(self) -> OpResult:
"""Convert a Constant node to its corresponding MLIR representation.
Returns:
str: textual MLIR representation corresponding to self.node
"""
assert_true(len(self.node.inputs) == 0)
assert_true(len(self.node.outputs) == 1)
value = self.node.outputs[0]
if not isinstance(value, TensorValue): # pragma: no cover
# this branch is not covered as there are only TensorValues for now
raise NotImplementedError(f"{value} constants cannot be converted to MLIR yet")
resulting_type = value_to_mlir_type(self.ctx, value)
data = cast(Constant, self.node).constant_data
if value.is_scalar:
attr = IntegerAttr.get(resulting_type, data)
else:
# usage of `Attribute.parse` is the result of some limitations in the MLIR module
# provided by LLVM
# what should have been used is `DenseElementsAttr` but it's impossible to assign
# custom bit-widths using it (e.g., uint5)
# since we coudn't create a `DenseElementsAttr` with a custom bit width using python api
# we use `Attribute.parse` to let the underlying library do it by itself
attr = Attribute.parse(f"dense<{str(data.tolist())}> : {resulting_type}")
return arith.ConstantOp(resulting_type, attr).result
def convert_dot(self) -> OpResult:
"""Convert a Dot node to its corresponding MLIR representation.
Returns:
str: textual MLIR representation corresponding to self.node
"""
assert_true(len(self.node.inputs) == 2)
assert_true(len(self.node.outputs) == 1)
if self.all_of_the_inputs_are_encrypted:
lhs = self.node.inputs[0]
rhs = self.node.inputs[1]
raise NotImplementedError(
f"Dot product between {lhs} and {rhs} cannot be converted to MLIR yet",
)
resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
preds = self.preds
if self.node.inputs[0].is_clear:
preds = preds[::-1]
if self.all_of_the_inputs_are_tensors:
# numpy.dot(x, y) where x and y are both vectors = regular dot product
result = hlfhelinalg.Dot(resulting_type, *preds).result
elif not self.one_of_the_inputs_is_a_tensor:
# numpy.dot(x, y) where x and y are both scalars = x * y
result = hlfhe.MulEintIntOp(resulting_type, *preds).result
else:
# numpy.dot(x, y) where one of x or y is a scalar and the other one is a vector = x * y
result = hlfhelinalg.MulEintIntOp(resulting_type, *preds).result
return result
def convert_generic_function(self, additional_conversion_info: Dict[str, Any]) -> OpResult:
"""Convert a GenericFunction node to its corresponding MLIR representation.
Returns:
str: textual MLIR representation corresponding to self.node
"""
variable_input_indices = [
idx
for idx, inp in enumerate(self.opgraph.get_ordered_preds(self.node))
if not isinstance(inp, Constant)
]
if len(variable_input_indices) != 1: # pragma: no cover
# this branch is not covered as it's impossible to get into due to how tracing works
# however, it doesn't hurt to keep it as an extra measure
raise NotImplementedError(
"Table lookups with more than one variable input cannot be converted to MLIR yet"
)
variable_input_index = variable_input_indices[0]
assert_true(len(self.node.outputs) == 1)
value = self.node.inputs[variable_input_index]
assert_true(value.is_encrypted)
if not isinstance(value.dtype, Integer) or value.dtype.is_signed: # pragma: no cover
# this branch is not covered as it's impossible to get into due to how compilation works
# however, it doesn't hurt to keep it as an extra measure
raise NotImplementedError(f"Table lookup on {value} cannot be converted to MLIR yet")
tables = additional_conversion_info["tables"][self.node]
# TODO: #559 adapt the code to support multi TLUs
# This cannot be reached today as compilation fails
# if the intermediate values are not all scalars
if len(tables) > 1: # pragma: no cover
raise NotImplementedError("Multi table lookups cannot be converted to MLIR yet")
table = tables[0][0]
lut_size = len(table)
lut_type = RankedTensorType.get([lut_size], IntegerType.get_signless(64, context=self.ctx))
lut_attr = DenseElementsAttr.get(numpy.array(table, dtype=numpy.uint64), context=self.ctx)
lut = arith.ConstantOp(lut_type, lut_attr).result
resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
pred = self.preds[variable_input_index]
if self.one_of_the_inputs_is_a_tensor:
result = hlfhelinalg.ApplyLookupTableEintOp(resulting_type, pred, lut).result
else:
result = hlfhe.ApplyLookupTableEintOp(resulting_type, pred, lut).result
return result
def convert_mul(self) -> OpResult:
    """Convert a Mul node to its corresponding MLIR representation.

    Returns:
        OpResult: the MLIR operation result corresponding to self.node
    """
    assert_true(len(self.node.inputs) == 2)
    assert_true(len(self.node.outputs) == 1)

    if self.all_of_the_inputs_are_encrypted:
        lhs, rhs = self.node.inputs
        raise NotImplementedError(
            f"Multiplication between {lhs} and {rhs} cannot be converted to MLIR yet",
        )

    resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])

    operands = self.preds
    if self.node.inputs[0].is_clear:  # pragma: no cover
        # unreachable through tracing, kept as an extra safety measure:
        # the MLIR op expects the encrypted operand first, so swap them
        operands = operands[::-1]

    op = hlfhelinalg.MulEintIntOp if self.one_of_the_inputs_is_a_tensor else hlfhe.MulEintIntOp
    return op(resulting_type, *operands).result
def convert_sub(self) -> OpResult:
    """Convert a Sub node to its corresponding MLIR representation.

    Returns:
        OpResult: the MLIR operation result corresponding to self.node
    """
    assert_true(len(self.node.inputs) == 2)
    assert_true(len(self.node.outputs) == 1)

    lhs, rhs = self.node.inputs
    # only (clear - encrypted) has a dedicated MLIR op today
    if not (lhs.is_clear and rhs.is_encrypted):
        raise NotImplementedError(
            f"Subtraction of {rhs} from {lhs} cannot be converted to MLIR yet",
        )

    resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
    operands = self.preds

    op = hlfhelinalg.SubIntEintOp if self.one_of_the_inputs_is_a_tensor else hlfhe.SubIntEintOp
    return op(resulting_type, *operands).result

View File

@@ -475,19 +475,20 @@ class Dot(IntermediateNode):
assert_true(
all(
isinstance(input_value, TensorValue) and input_value.ndim == 1
isinstance(input_value, TensorValue) and input_value.ndim <= 1
for input_value in self.inputs
),
f"Dot only supports two vectors ({TensorValue.__name__} with ndim == 1)",
f"Dot only supports two scalars or vectors ({TensorValue.__name__} with ndim up to 1)",
)
lhs = cast(TensorValue, self.inputs[0])
rhs = cast(TensorValue, self.inputs[1])
assert_true(
lhs.shape[0] == rhs.shape[0],
f"Dot between vectors of shapes {lhs.shape} and {rhs.shape} is not supported",
)
if lhs.ndim == 1 and rhs.ndim == 1:
assert_true(
lhs.shape[0] == rhs.shape[0],
f"Dot between vectors of shapes {lhs.shape} and {rhs.shape} is not supported",
)
output_scalar_value = (
EncryptedScalar if (lhs.is_encrypted or rhs.is_encrypted) else ClearScalar

View File

@@ -14,7 +14,6 @@ from ..common.data_types import Integer
from ..common.debugging import format_operation_graph
from ..common.debugging.custom_assert import assert_true
from ..common.fhe_circuit import FHECircuit
from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS
from ..common.mlir.utils import (
check_graph_values_compatibility_with_mlir,
extend_direct_lookup_tables,
@@ -312,7 +311,7 @@ def _compile_numpy_function_internal(
prepare_op_graph_for_mlir(op_graph)
# Convert graph to an MLIR representation
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
converter = NPMLIRConverter()
mlir_result = converter.convert(op_graph)
# Show MLIR representation if requested

View File

@@ -8,7 +8,7 @@ from typing import Any, DefaultDict, Dict, List, Tuple
import numpy
from ..common.debugging import assert_true
from ..common.mlir.mlir_converter import MLIRConverter
from ..common.mlir.graph_converter import OPGraphConverter
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import GenericFunction, IntermediateNode
@@ -67,27 +67,18 @@ def generate_deduplicated_tables(
)
class NPMLIRConverter(MLIRConverter):
class NPMLIRConverter(OPGraphConverter):
"""Numpy-specific MLIR converter."""
@staticmethod
def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]:
"""Generate the additional_conversion_info dict for the MLIR converter.
Args:
op_graph (OPGraph): the OPGraph for which we need the conversion infos.
Returns:
Dict[str, Any]: The dict with the additional conversion infos.
"""
def _generate_additional_info_dict(opgraph: OPGraph) -> Dict[str, Any]:
additional_conversion_info = {}
# Disable numpy warnings during conversion to avoid issues during TLU generation
with numpy.errstate(all="ignore"):
additional_conversion_info["tables"] = {
node: generate_deduplicated_tables(node, op_graph.get_ordered_preds(node))
for node in op_graph.graph.nodes()
node: generate_deduplicated_tables(node, opgraph.get_ordered_preds(node))
for node in opgraph.graph.nodes()
if isinstance(node, GenericFunction)
}

View File

@@ -0,0 +1,115 @@
"""Test file for MLIR conversion helpers."""
# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints
# pylint: disable=no-name-in-module
import pytest
import zamalang
from mlir.ir import Context, Location
from concrete.common.data_types import Float, SignedInteger, UnsignedInteger
from concrete.common.mlir.conversion_helpers import integer_to_mlir_type, value_to_mlir_type
from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor
# pylint: enable=no-name-in-module
@pytest.mark.parametrize(
"integer,is_encrypted,expected_mlir_type_str",
[
pytest.param(SignedInteger(5), False, "i5"),
pytest.param(UnsignedInteger(5), False, "i5"),
pytest.param(SignedInteger(32), False, "i32"),
pytest.param(UnsignedInteger(32), False, "i32"),
pytest.param(SignedInteger(5), True, "!HLFHE.eint<5>"),
pytest.param(UnsignedInteger(5), True, "!HLFHE.eint<5>"),
],
)
def test_integer_to_mlir_type(integer, is_encrypted, expected_mlir_type_str):
    """Check that integers convert to the expected MLIR type string."""
    with Context() as ctx, Location.unknown():
        zamalang.register_dialects(ctx)
        mlir_type = integer_to_mlir_type(ctx, integer, is_encrypted)
        assert str(mlir_type) == expected_mlir_type_str
@pytest.mark.parametrize(
"integer,is_encrypted,expected_error_message",
[
pytest.param(SignedInteger(32), True, "can't create eint with the given width"),
pytest.param(UnsignedInteger(32), True, "can't create eint with the given width"),
],
)
def test_fail_integer_to_mlir_type(integer, is_encrypted, expected_error_message):
    """Check that unsupported integers are rejected with the expected error."""
    with pytest.raises(ValueError) as excinfo, Context() as ctx, Location.unknown():
        zamalang.register_dialects(ctx)
        integer_to_mlir_type(ctx, integer, is_encrypted)
    assert str(excinfo.value) == expected_error_message
@pytest.mark.parametrize(
"value,expected_mlir_type_str",
[
pytest.param(ClearScalar(SignedInteger(5)), "i5"),
pytest.param(ClearTensor(SignedInteger(5), shape=(2, 3)), "tensor<2x3xi5>"),
pytest.param(EncryptedScalar(SignedInteger(5)), "!HLFHE.eint<5>"),
pytest.param(EncryptedTensor(SignedInteger(5), shape=(2, 3)), "tensor<2x3x!HLFHE.eint<5>>"),
pytest.param(ClearScalar(UnsignedInteger(5)), "i5"),
pytest.param(ClearTensor(UnsignedInteger(5), shape=(2, 3)), "tensor<2x3xi5>"),
pytest.param(EncryptedScalar(UnsignedInteger(5)), "!HLFHE.eint<5>"),
pytest.param(
EncryptedTensor(UnsignedInteger(5), shape=(2, 3)), "tensor<2x3x!HLFHE.eint<5>>"
),
],
)
def test_value_to_mlir_type(value, expected_mlir_type_str):
    """Check that values convert to the expected MLIR type string."""
    with Context() as ctx, Location.unknown():
        zamalang.register_dialects(ctx)
        mlir_type = value_to_mlir_type(ctx, value)
        assert str(mlir_type) == expected_mlir_type_str
@pytest.mark.parametrize(
"value,expected_error_message",
[
pytest.param(
ClearScalar(Float(32)),
"ClearScalar<float32> is not supported for MLIR conversion",
),
pytest.param(
ClearTensor(Float(32), shape=(2, 3)),
"ClearTensor<float32, shape=(2, 3)> is not supported for MLIR conversion",
),
pytest.param(
EncryptedScalar(Float(32)),
"EncryptedScalar<float32> is not supported for MLIR conversion",
),
pytest.param(
EncryptedTensor(Float(32), shape=(2, 3)),
"EncryptedTensor<float32, shape=(2, 3)> is not supported for MLIR conversion",
),
pytest.param(
EncryptedScalar(UnsignedInteger(32)),
"EncryptedScalar<uint32> is not supported for MLIR conversion",
),
pytest.param(
EncryptedTensor(UnsignedInteger(32), shape=(2, 3)),
"EncryptedTensor<uint32, shape=(2, 3)> is not supported for MLIR conversion",
),
],
)
def test_fail_value_to_mlir_type(value, expected_error_message):
    """Check that unsupported values are rejected with the expected error."""
    with pytest.raises(TypeError) as excinfo, Context() as ctx, Location.unknown():
        zamalang.register_dialects(ctx)
        value_to_mlir_type(ctx, value)
    assert str(excinfo.value) == expected_error_message

View File

@@ -1,81 +0,0 @@
"""Test converter functions"""
import pytest
from concrete.common.data_types.floats import Float
from concrete.common.data_types.integers import Integer
from concrete.common.mlir.converters import add, apply_lut, constant, dot, mul, sub
from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar
class MockNode:
    """Minimal stand-in for an intermediate node, exposing only inputs and outputs."""

    def __init__(self, inputs_n=5, outputs_n=5, inputs=None, outputs=None):
        # When explicit inputs/outputs are not given, fall back to placeholder
        # lists of the requested lengths.
        self.inputs = inputs if inputs is not None else [None] * inputs_n
        self.outputs = outputs if outputs is not None else [None] * outputs_n
@pytest.mark.parametrize("converter", [add, sub, mul, dot])
def test_failing_converter(converter):
"""Test failing converter"""
with pytest.raises(TypeError, match=r"Don't support .* between .* and .*"):
converter(MockNode(2, 1), None, None, None)
def test_fail_non_integer_const():
"""Test failing constant converter with non-integer"""
with pytest.raises(TypeError, match=r"Don't support .* constants"):
constant(MockNode(outputs=[ClearScalar(Float(32))]), None, None, None)
with pytest.raises(TypeError, match=r"Don't support .* constants"):
constant(MockNode(outputs=[ClearTensor(Float(32), shape=(2,))]), None, None, None)
@pytest.mark.parametrize(
"input_node",
[
ClearScalar(Integer(8, True)),
ClearScalar(Integer(8, False)),
EncryptedScalar(Integer(8, True)),
],
)
def test_fail_tlu_input(input_node):
"""Test failing LUT converter with invalid input"""
with pytest.raises(
TypeError, match=r"Only support LUT with encrypted unsigned integers inputs"
):
apply_lut(
MockNode(inputs=[input_node], outputs=[EncryptedScalar(Integer(8, False))]),
[None],
None,
None,
None,
)
@pytest.mark.parametrize(
"input_node",
[
ClearScalar(Integer(8, True)),
ClearScalar(Integer(8, False)),
EncryptedScalar(Integer(8, True)),
],
)
def test_fail_tlu_output(input_node):
"""Test failing LUT converter with invalid output"""
with pytest.raises(
TypeError, match=r"Only support LUT with encrypted unsigned integers outputs"
):
apply_lut(
MockNode(inputs=[EncryptedScalar(Integer(8, False))], outputs=[input_node]),
[None],
None,
None,
None,
)

View File

@@ -1,392 +0,0 @@
"""Test file for conversion to MLIR"""
# pylint: disable=no-name-in-module,no-member
import itertools
import numpy
import pytest
from mlir.ir import IntegerType, Location, RankedTensorType, UnrankedTensorType
from zamalang import compiler
from zamalang.dialects import hlfhe
from concrete.common.data_types.integers import Integer
from concrete.common.extensions.table import LookupTable
from concrete.common.mlir import V0_OPSET_CONVERSION_FUNCTIONS
from concrete.common.values import ClearScalar, EncryptedScalar
from concrete.common.values.tensors import ClearTensor, EncryptedTensor
from concrete.numpy.compile import compile_numpy_function_into_op_graph, prepare_op_graph_for_mlir
from concrete.numpy.np_mlir_converter import NPMLIRConverter
def add(x, y):
"""Test simple add"""
return x + y
def constant_add(x):
"""Test constant add"""
return x + 5
def sub(x, y):
"""Test simple sub"""
return x - y
def constant_sub(x):
"""Test constant sub"""
return 12 - x
def mul(x, y):
"""Test simple mul"""
return x * y
def constant_mul(x):
"""Test constant mul"""
return x * 2
def sub_add_mul(x, y, z):
    """Test combination of ops"""
    product = x * z
    difference = z - y
    return difference + product
def ret_multiple(x, y, z):
"""Test return of multiple values"""
return x, y, z
def ret_multiple_different_order(x, y, z):
"""Test return of multiple values in a different order from input"""
return y, z, x
def lut(x):
"""Test lookup table"""
table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7])
return table[x]
# TODO: remove workaround #359
def lut_more_bits_than_table_length(x, y):
"""Test lookup table when bit_width support longer LUT"""
table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7])
return table[x] + y
# TODO: remove workaround #359
def lut_less_bits_than_table_length(x):
"""Test lookup table when bit_width support smaller LUT"""
table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7, 3, 6, 0, 2, 1, 4, 5, 7])
return table[x]
def dot(x, y):
    """Test dot"""
    result = numpy.dot(x, y)
    return result
def datagen(*args):
    """Yield every combination of the given ranges as an argument tuple.

    Args:
        *args: iterables (typically ranges) whose Cartesian product is generated.

    Yields:
        tuple: one tuple per element of the Cartesian product.
    """
    # itertools.product already yields exactly the tuples we need;
    # delegating with `yield from` avoids the redundant manual loop.
    yield from itertools.product(*args)
@pytest.mark.parametrize(
"func, args_dict, args_ranges",
[
(
add,
{
"x": EncryptedScalar(Integer(64, is_signed=False)),
"y": ClearScalar(Integer(32, is_signed=False)),
},
(range(0, 8), range(1, 4)),
),
(
constant_add,
{
"x": EncryptedScalar(Integer(64, is_signed=False)),
},
(range(0, 10),),
),
(
add,
{
"x": ClearScalar(Integer(32, is_signed=False)),
"y": EncryptedScalar(Integer(64, is_signed=False)),
},
(range(0, 8), range(1, 4)),
),
(
add,
{
"x": EncryptedScalar(Integer(7, is_signed=False)),
"y": EncryptedScalar(Integer(7, is_signed=False)),
},
(range(7, 15), range(1, 5)),
),
(
sub,
{
"x": ClearScalar(Integer(8, is_signed=False)),
"y": EncryptedScalar(Integer(7, is_signed=False)),
},
(range(5, 10), range(2, 6)),
),
(
constant_sub,
{
"x": EncryptedScalar(Integer(64, is_signed=False)),
},
(range(0, 10),),
),
(
mul,
{
"x": EncryptedScalar(Integer(7, is_signed=False)),
"y": ClearScalar(Integer(8, is_signed=False)),
},
(range(1, 5), range(2, 8)),
),
(
constant_mul,
{
"x": EncryptedScalar(Integer(64, is_signed=False)),
},
(range(0, 10),),
),
(
mul,
{
"x": ClearScalar(Integer(8, is_signed=False)),
"y": EncryptedScalar(Integer(7, is_signed=False)),
},
(range(1, 5), range(2, 8)),
),
(
sub_add_mul,
{
"x": EncryptedScalar(Integer(7, is_signed=False)),
"y": EncryptedScalar(Integer(7, is_signed=False)),
"z": ClearScalar(Integer(7, is_signed=False)),
},
(range(0, 8), range(1, 5), range(5, 12)),
),
(
ret_multiple,
{
"x": EncryptedScalar(Integer(7, is_signed=False)),
"y": EncryptedScalar(Integer(7, is_signed=False)),
"z": ClearScalar(Integer(7, is_signed=False)),
},
(range(1, 5), range(1, 5), range(1, 5)),
),
(
ret_multiple_different_order,
{
"x": EncryptedScalar(Integer(7, is_signed=False)),
"y": EncryptedScalar(Integer(7, is_signed=False)),
"z": ClearScalar(Integer(7, is_signed=False)),
},
(range(1, 5), range(1, 5), range(1, 5)),
),
(
lut,
{
"x": EncryptedScalar(Integer(3, is_signed=False)),
},
(range(0, 8),),
),
(
lut_more_bits_than_table_length,
{
"x": EncryptedScalar(Integer(64, is_signed=False)),
"y": EncryptedScalar(Integer(64, is_signed=False)),
},
(range(0, 8), range(0, 16)),
),
(
lut_less_bits_than_table_length,
{
"x": EncryptedScalar(Integer(3, is_signed=False)),
},
(range(0, 8),),
),
],
)
def test_mlir_converter(func, args_dict, args_ranges, default_compilation_configuration):
"""Test the conversion to MLIR by calling the parser from the compiler"""
inputset = datagen(*args_ranges)
result_graph = compile_numpy_function_into_op_graph(
func,
args_dict,
inputset,
default_compilation_configuration,
)
prepare_op_graph_for_mlir(result_graph)
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
mlir_result = converter.convert(result_graph)
# testing that this doesn't raise an error
compiler.round_trip(mlir_result)
@pytest.mark.parametrize(
"func, args_dict, args_ranges",
[
(
dot,
{
"x": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)),
"y": ClearTensor(Integer(64, is_signed=False), shape=(4,)),
},
(range(0, 4), range(0, 4)),
),
(
dot,
{
"x": ClearTensor(Integer(64, is_signed=False), shape=(4,)),
"y": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)),
},
(range(0, 4), range(0, 4)),
),
],
)
def test_mlir_converter_dot_between_vectors(
func, args_dict, args_ranges, default_compilation_configuration
):
"""Test the conversion to MLIR by calling the parser from the compiler"""
assert len(args_dict["x"].shape) == 1
assert len(args_dict["y"].shape) == 1
n = args_dict["x"].shape[0]
result_graph = compile_numpy_function_into_op_graph(
func,
args_dict,
(
(numpy.array([data[0]] * n), numpy.array([data[1]] * n))
for data in datagen(*args_ranges)
),
default_compilation_configuration,
)
prepare_op_graph_for_mlir(result_graph)
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
mlir_result = converter.convert(result_graph)
# testing that this doesn't raise an error
compiler.round_trip(mlir_result)
def test_mlir_converter_dot_vector_and_constant(default_compilation_configuration):
"""Test the conversion to MLIR by calling the parser from the compiler"""
def left_dot_with_constant(x):
return numpy.dot(x, numpy.array([1, 2]))
def right_dot_with_constant(x):
return numpy.dot(numpy.array([1, 2]), x)
left_graph = compile_numpy_function_into_op_graph(
left_dot_with_constant,
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2,))},
[(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)],
default_compilation_configuration,
)
prepare_op_graph_for_mlir(left_graph)
left_converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
left_mlir = left_converter.convert(left_graph)
right_graph = compile_numpy_function_into_op_graph(
right_dot_with_constant,
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2,))},
[(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)],
default_compilation_configuration,
)
prepare_op_graph_for_mlir(right_graph)
right_converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
right_mlir = right_converter.convert(right_graph)
# testing that this doesn't raise an error
compiler.round_trip(left_mlir)
compiler.round_trip(right_mlir)
def test_concrete_encrypted_integer_to_mlir_type():
"""Test conversion of EncryptedScalar into MLIR"""
value = EncryptedScalar(Integer(7, is_signed=False))
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
eint = converter.common_value_to_mlir_type(value)
assert eint == hlfhe.EncryptedIntegerType.get(converter.context, 7)
@pytest.mark.parametrize("is_signed", [True, False])
def test_concrete_clear_integer_to_mlir_type(is_signed):
"""Test conversion of ClearScalar into MLIR"""
value = ClearScalar(Integer(5, is_signed=is_signed))
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
with converter.context:
int_mlir = converter.common_value_to_mlir_type(value)
if is_signed:
assert int_mlir == IntegerType.get_signed(5)
else:
assert int_mlir == IntegerType.get_signless(5)
@pytest.mark.parametrize("is_signed", [True, False])
@pytest.mark.parametrize(
"shape",
[
(5,),
(5, 8),
(-1, 5),
],
)
def test_concrete_clear_tensor_integer_to_mlir_type(is_signed, shape):
"""Test conversion of ClearTensor into MLIR"""
value = ClearTensor(Integer(5, is_signed=is_signed), shape)
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
with converter.context, Location.unknown():
tensor_mlir = converter.common_value_to_mlir_type(value)
if is_signed:
element_type = IntegerType.get_signed(5)
else:
element_type = IntegerType.get_signless(5)
if shape is None:
expected_type = UnrankedTensorType.get(element_type)
else:
expected_type = RankedTensorType.get(shape, element_type)
assert tensor_mlir == expected_type
@pytest.mark.parametrize(
"shape",
[
(5,),
(5, 8),
(-1, 5),
],
)
def test_concrete_encrypted_tensor_integer_to_mlir_type(shape):
"""Test conversion of EncryptedTensor into MLIR"""
value = EncryptedTensor(Integer(6, is_signed=False), shape)
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
with converter.context, Location.unknown():
tensor_mlir = converter.common_value_to_mlir_type(value)
element_type = hlfhe.EncryptedIntegerType.get(converter.context, 6)
if shape is None:
expected_type = UnrankedTensorType.get(element_type)
else:
expected_type = RankedTensorType.get(shape, element_type)
assert tensor_mlir == expected_type
def test_failing_concrete_to_mlir_type():
"""Test failing conversion of an unsupported type into MLIR"""
value = "random"
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
with pytest.raises(TypeError, match=r"can't convert value of type .* to MLIR type"):
converter.common_value_to_mlir_type(value)
# pylint: enable=no-name-in-module,no-member

View File

@@ -0,0 +1,85 @@
"""Test file for intermediate node to MLIR converter."""
import random
import numpy
import pytest
from concrete.common.data_types import UnsignedInteger
from concrete.common.values import EncryptedScalar, EncryptedTensor
from concrete.numpy import compile_numpy_function
@pytest.mark.parametrize(
"function_to_compile,parameters,inputset,expected_error_type,expected_error_message",
[
pytest.param(
lambda x, y: x * y,
{
"x": EncryptedScalar(UnsignedInteger(3)),
"y": EncryptedScalar(UnsignedInteger(3)),
},
[(random.randint(0, 7), random.randint(0, 7)) for _ in range(10)] + [(7, 7)],
NotImplementedError,
"Multiplication "
"between "
"EncryptedScalar<uint6> "
"and "
"EncryptedScalar<uint6> "
"cannot be converted to MLIR yet",
),
pytest.param(
lambda x, y: x - y,
{
"x": EncryptedScalar(UnsignedInteger(3)),
"y": EncryptedScalar(UnsignedInteger(3)),
},
[(random.randint(5, 7), random.randint(0, 5)) for _ in range(10)],
NotImplementedError,
"Subtraction "
"of "
"EncryptedScalar<uint3> "
"from "
"EncryptedScalar<uint3> "
"cannot be converted to MLIR yet",
),
pytest.param(
lambda x, y: numpy.dot(x, y),
{
"x": EncryptedTensor(UnsignedInteger(3), shape=(2,)),
"y": EncryptedTensor(UnsignedInteger(3), shape=(2,)),
},
[
(
numpy.random.randint(0, 2 ** 3, size=(2,)),
numpy.random.randint(0, 2 ** 3, size=(2,)),
)
for _ in range(10)
]
+ [(numpy.array([7, 7]), numpy.array([7, 7]))],
NotImplementedError,
"Dot product "
"between "
"EncryptedTensor<uint7, shape=(2,)> "
"and "
"EncryptedTensor<uint7, shape=(2,)> "
"cannot be converted to MLIR yet",
),
],
)
def test_fail_node_conversion(
    function_to_compile,
    parameters,
    inputset,
    expected_error_type,
    expected_error_message,
    default_compilation_configuration,
):
    """Check that unconvertible intermediate nodes fail compilation with the right error."""
    with pytest.raises(expected_error_type) as excinfo:
        compile_numpy_function(
            function_to_compile,
            parameters,
            inputset,
            default_compilation_configuration,
        )
    assert str(excinfo.value) == expected_error_message

View File

@@ -519,6 +519,8 @@ def test_compile_function_multiple_outputs(
pytest.param(lambda x, y: 100 - y + x, ((0, 20), (0, 20)), ["x", "y"]),
pytest.param(lambda x, y: 50 - y * 2 + x, ((0, 20), (0, 20)), ["x", "y"]),
pytest.param(lambda x: -x + 50, ((0, 20),), ["x"]),
pytest.param(lambda x: numpy.dot(x, 2), ((0, 20),), ["x"]),
pytest.param(lambda x: numpy.dot(2, x), ((0, 20),), ["x"]),
],
)
def test_compile_and_run_correctness(
@@ -548,6 +550,11 @@ def test_compile_and_run_correctness(
@pytest.mark.parametrize(
"function,parameters,inputset,test_input,expected_output",
[
# TODO: find a way to support this case
# https://github.com/zama-ai/concretefhe-internal/issues/837
#
# the problem is that compiler doesn't support combining scalars and tensors
# but they do support broadcasting, so scalars should be converted to (1,) shaped tensors
pytest.param(
lambda x: x + 1,
{
@@ -566,6 +573,7 @@ def test_compile_and_run_correctness(
[7, 2],
[3, 6],
],
marks=pytest.mark.xfail(strict=True),
),
pytest.param(
lambda x: x + numpy.array([[1, 0], [2, 0], [3, 1]], dtype=numpy.uint32),
@@ -590,9 +598,7 @@ def test_compile_and_run_correctness(
# https://github.com/zama-ai/concretefhe-internal/issues/837
#
# the problem is that compiler doesn't support combining scalars and tensors
# but they do support broadcasting, so scalars can be converted to (1,) shaped tensors
# this is easy with known constants but weird with variable things such as another input
# there is tensor.from_elements but I coudn't figure out how to use it in the python API
# but they do support broadcasting, so scalars should be converted to (1,) shaped tensors
pytest.param(
lambda x, y: x + y,
{
@@ -619,7 +625,7 @@ def test_compile_and_run_correctness(
[8, 3],
[4, 7],
],
marks=pytest.mark.xfail(),
marks=pytest.mark.xfail(strict=True),
),
pytest.param(
lambda x, y: x + y,
@@ -652,6 +658,11 @@ def test_compile_and_run_correctness(
[5, 9],
],
),
# TODO: find a way to support this case
# https://github.com/zama-ai/concretefhe-internal/issues/837
#
# the problem is that compiler doesn't support combining scalars and tensors
# but they do support broadcasting, so scalars should be converted to (1,) shaped tensors
pytest.param(
lambda x: 100 - x,
{
@@ -670,6 +681,7 @@ def test_compile_and_run_correctness(
[94, 99],
[98, 95],
],
marks=pytest.mark.xfail(strict=True),
),
pytest.param(
lambda x: numpy.array([[10, 15], [20, 15], [10, 30]], dtype=numpy.uint32) - x,
@@ -690,6 +702,11 @@ def test_compile_and_run_correctness(
[8, 25],
],
),
# TODO: find a way to support this case
# https://github.com/zama-ai/concretefhe-internal/issues/837
#
# the problem is that compiler doesn't support combining scalars and tensors
# but they do support broadcasting, so scalars should be converted to (1,) shaped tensors
pytest.param(
lambda x: x * 2,
{
@@ -708,6 +725,7 @@ def test_compile_and_run_correctness(
[12, 2],
[4, 10],
],
marks=pytest.mark.xfail(strict=True),
),
pytest.param(
lambda x: x * numpy.array([[1, 2], [2, 1], [3, 1]], dtype=numpy.uint32),
@@ -747,6 +765,36 @@ def test_compile_and_run_correctness(
[0, 2],
],
),
# TODO: find a way to support this case
# https://github.com/zama-ai/concretefhe-internal/issues/837
#
# the problem is that compiler doesn't support combining scalars and tensors
# but they do support broadcasting, so scalars should be converted to (1,) shaped tensors
pytest.param(
lambda x: numpy.dot(x, 2),
{
"x": EncryptedTensor(UnsignedInteger(3), shape=(3,)),
},
[(numpy.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)],
([2, 7, 1],),
[4, 14, 2],
marks=pytest.mark.xfail(strict=True),
),
# TODO: find a way to support this case
# https://github.com/zama-ai/concretefhe-internal/issues/837
#
# the problem is that compiler doesn't support combining scalars and tensors
# but they do support broadcasting, so scalars should be converted to (1,) shaped tensors
pytest.param(
lambda x: numpy.dot(2, x),
{
"x": EncryptedTensor(UnsignedInteger(3), shape=(3,)),
},
[(numpy.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)],
([2, 7, 1],),
[4, 14, 2],
marks=pytest.mark.xfail(strict=True),
),
],
)
def test_compile_and_run_tensor_correctness(
@@ -874,7 +922,7 @@ def test_compile_and_run_constant_dot_correctness(
default_compilation_configuration,
)
right_circuit = compile_numpy_function(
left,
right,
{"x": EncryptedTensor(Integer(64, False), shape)},
inputset,
default_compilation_configuration,