mirror of https://github.com/zama-ai/concrete.git (synced 2026-02-09 03:55:04 -05:00)

refactor(mlir): re-write mlir conversion
@@ -100,23 +100,6 @@ def value_is_unsigned_integer(value_to_check: BaseValue) -> bool:
    )


def value_is_encrypted_unsigned_integer(value_to_check: BaseValue) -> bool:
    """Check that a value is encrypted and is of type unsigned Integer.

    Args:
        value_to_check (BaseValue): The value to check

    Returns:
        bool: True if the passed value_to_check is encrypted and is of type unsigned Integer
    """

    return (
        value_to_check.is_encrypted
        and isinstance(value_to_check.dtype, INTEGER_TYPES)
        and not cast(Integer, value_to_check.dtype).is_signed
    )


def value_is_encrypted_tensor_integer(value_to_check: BaseValue) -> bool:
    """Check that a value is an encrypted TensorValue of type Integer.

@@ -129,22 +112,6 @@ def value_is_encrypted_tensor_integer(value_to_check: BaseValue) -> bool:
    return value_is_tensor_integer(value_to_check) and value_to_check.is_encrypted


def value_is_encrypted_tensor_unsigned_integer(value_to_check: BaseValue) -> bool:
    """Check that a value is an encrypted TensorValue of type unsigned Integer.

    Args:
        value_to_check (BaseValue): The value to check

    Returns:
        bool: True if the passed value_to_check is an encrypted TensorValue of type Integer and
        unsigned
    """
    return (
        value_is_encrypted_tensor_integer(value_to_check)
        and not cast(Integer, value_to_check.dtype).is_signed
    )


def value_is_clear_tensor_integer(value_to_check: BaseValue) -> bool:
    """Check that a value is a clear TensorValue of type Integer.
@@ -1,5 +1,3 @@
"""MLIR conversion submodule."""
from .converters import V0_OPSET_CONVERSION_FUNCTIONS
from .mlir_converter import MLIRConverter
"""MLIR conversion module."""

__all__ = ["MLIRConverter", "V0_OPSET_CONVERSION_FUNCTIONS"]
from .graph_converter import OPGraphConverter

concrete/common/mlir/conversion_helpers.py (new file, 64 lines)
@@ -0,0 +1,64 @@
"""Helpers for MLIR conversion functionality."""

# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints
# pylint: disable=no-name-in-module

from typing import Optional

from mlir.ir import Context, IntegerType, RankedTensorType, Type
from zamalang.dialects.hlfhe import EncryptedIntegerType

from ..data_types import Integer
from ..values import BaseValue, TensorValue

# pylint: enable=no-name-in-module


def integer_to_mlir_type(ctx: Context, integer: Integer, is_encrypted: bool) -> Optional[Type]:
    """Convert an integer to its corresponding MLIR type.

    Args:
        ctx (Context): the MLIR context to perform the conversion
        integer (Integer): the integer to convert
        is_encrypted (bool): whether the integer is encrypted or not

    Returns:
        Type:
            the MLIR type corresponding to the given integer and encryption status
            if it's supported, otherwise None
    """

    bit_width = integer.bit_width

    if is_encrypted:
        result = EncryptedIntegerType.get(ctx, bit_width)
    else:
        result = IntegerType.get_signless(bit_width)

    return result


def value_to_mlir_type(ctx: Context, value: BaseValue) -> Type:
    """Convert a value to its corresponding MLIR type.

    Args:
        ctx (Context): the MLIR context to perform the conversion
        value (BaseValue): the value to convert

    Returns:
        Type: the MLIR type corresponding to the given value
    """

    dtype = value.dtype
    if isinstance(dtype, Integer):
        try:
            mlir_type = integer_to_mlir_type(ctx, dtype, value.is_encrypted)
            if isinstance(value, TensorValue):
                if not value.is_scalar:
                    mlir_type = RankedTensorType.get(value.shape, mlir_type)
            return mlir_type
        except ValueError:
            pass  # the error below will be raised

    raise TypeError(f"{value} is not supported for MLIR conversion")
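For orientation, a short sketch of these helpers in action (the printed type strings follow the tests added later in this commit):

    import zamalang
    from mlir.ir import Context, Location

    from concrete.common.data_types import SignedInteger, UnsignedInteger
    from concrete.common.mlir.conversion_helpers import value_to_mlir_type
    from concrete.common.values import ClearTensor, EncryptedScalar

    with Context() as ctx, Location.unknown():
        zamalang.register_dialects(ctx)
        # encrypted 5-bit scalar -> "!HLFHE.eint<5>"
        print(value_to_mlir_type(ctx, EncryptedScalar(UnsignedInteger(5))))
        # clear 2x3 tensor of 5-bit signed integers -> "tensor<2x3xi5>"
        print(value_to_mlir_type(ctx, ClearTensor(SignedInteger(5), shape=(2, 3))))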
@@ -1,383 +0,0 @@
"""Converter functions from the common IR to MLIR.

Converter functions all have the same signature `converter(node, preds, ir_to_mlir_node, ctx)`:
- `node`: the IntermediateNode to be converted
- `preds`: list of predecessors of `node`, ordered as operands
- `ir_to_mlir_node`: dict mapping intermediate nodes to MLIR nodes or values
- `ctx`: the MLIR context
"""
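For orientation, a minimal sketch of a function honoring this signature (the unary node and `NegEintOp` are hypothetical, purely to illustrate the calling convention):

    def negate(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
        """Hedged sketch: convert a hypothetical unary negation node."""
        (operand_node,) = preds  # predecessors of `node`, ordered as operands
        operand = ir_to_mlir_node[operand_node]  # its already-converted MLIR value
        result_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width)
        return hlfhe.NegEintOp(result_type, operand).result  # `NegEintOp` is hypothetical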
from typing import cast

# pylint: disable=no-name-in-module,no-member
import numpy
from mlir.dialects import arith as arith_dialect
from mlir.ir import Attribute, DenseElementsAttr, IntegerAttr, IntegerType, RankedTensorType
from zamalang.dialects import hlfhe, hlfhelinalg

from ..data_types.dtypes_helpers import (
    value_is_clear_scalar_integer,
    value_is_clear_tensor_integer,
    value_is_encrypted_tensor_integer,
    value_is_encrypted_unsigned_integer,
    value_is_scalar_integer,
    value_is_tensor_integer,
)
from ..data_types.integers import Integer
from ..debugging.custom_assert import assert_true
from ..representation.intermediate import Add, Constant, Dot, GenericFunction, Mul, Sub
from ..values import TensorValue


def _convert_scalar_constant_op_to_single_element_tensor_constant_op(operation):
    """Convert a scalar constant operation result to a dense tensor constant operation result.

    See https://github.com/zama-ai/concretefhe-internal/issues/837.

    This is a temporary workaround until the compiler natively supports
    `tensor + scalar`, `tensor - scalar`, and `tensor * scalar` operations.

    Example input:  `%c3_i4 = arith.constant 3 : i4`
    Example output: `%cst = arith.constant dense<3> : tensor<1xi4>`

    Args:
        operation: the operation to convert

    Returns:
        the converted operation
    """

    operation_str = str(operation)

    constant_start_location = operation_str.find("arith.constant") + len("arith.constant") + 1
    constant_end_location = operation_str.find(f": {str(operation.type)}") - 1
    constant_value = operation_str[constant_start_location:constant_end_location]

    resulting_type = RankedTensorType.get((1,), operation.type)
    value_attr = Attribute.parse(f"dense<{constant_value}> : tensor<1x{str(operation.type)}>")

    return arith_dialect.ConstantOp(resulting_type, value_attr).result


def add(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
    """Convert an addition intermediate node."""
    assert_true(len(node.inputs) == 2, "addition should have two inputs")
    assert_true(len(node.outputs) == 1, "addition should have a single output")

    is_convertible = True
    one_of_the_inputs_is_a_tensor = False
    both_of_the_inputs_are_encrypted = True
    ordered_preds = preds

    for input_ in node.inputs:
        if value_is_tensor_integer(input_):
            one_of_the_inputs_is_a_tensor = True
        elif not value_is_scalar_integer(input_):
            is_convertible = False

    if not is_convertible:
        raise TypeError(
            f"Don't support addition between {str(node.inputs[0])} and {str(node.inputs[1])}"
        )

    if node.inputs[1].is_clear:
        both_of_the_inputs_are_encrypted = False
    if node.inputs[0].is_clear:
        both_of_the_inputs_are_encrypted = False
        ordered_preds = preds[::-1]

    if one_of_the_inputs_is_a_tensor:
        if both_of_the_inputs_are_encrypted:
            return _linalg_add_eint_eint(node, ordered_preds, ir_to_mlir_node, ctx)
        return _linalg_add_eint_int(node, ordered_preds, ir_to_mlir_node, ctx)

    if both_of_the_inputs_are_encrypted:
        return _add_eint_eint(node, ordered_preds, ir_to_mlir_node, ctx)
    return _add_eint_int(node, ordered_preds, ir_to_mlir_node, ctx)


def _add_eint_int(node, preds, ir_to_mlir_node, ctx):
    """Convert an addition intermediate node with (eint, int)."""
    lhs_node, rhs_node = preds
    lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
    return hlfhe.AddEintIntOp(
        hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width),
        lhs,
        rhs,
    ).result


def _add_eint_eint(node, preds, ir_to_mlir_node, ctx):
    """Convert an addition intermediate node with (eint, eint)."""
    lhs_node, rhs_node = preds
    lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
    return hlfhe.AddEintOp(
        hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width),
        lhs,
        rhs,
    ).result


def _linalg_add_eint_int(node, preds, ir_to_mlir_node, ctx):
    """Convert an addition intermediate tensor node with (eint, int)."""
    lhs_node, rhs_node = preds
    lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]

    if not str(rhs.type).startswith("tensor"):
        rhs = _convert_scalar_constant_op_to_single_element_tensor_constant_op(rhs)

    int_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width)
    vec_type = RankedTensorType.get(node.outputs[0].shape, int_type)

    return hlfhelinalg.AddEintIntOp(vec_type, lhs, rhs).result


def _linalg_add_eint_eint(node, preds, ir_to_mlir_node, ctx):
    """Convert an addition intermediate tensor node with (eint, eint)."""
    lhs_node, rhs_node = preds
    lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]

    int_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width)
    vec_type = RankedTensorType.get(node.outputs[0].shape, int_type)

    return hlfhelinalg.AddEintOp(vec_type, lhs, rhs).result


def sub(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
    """Convert a subtraction intermediate node."""
    assert_true(len(node.inputs) == 2, "subtraction should have two inputs")
    assert_true(len(node.outputs) == 1, "subtraction should have a single output")

    is_convertible = True
    one_of_the_inputs_is_a_tensor = False

    if value_is_clear_tensor_integer(node.inputs[0]):
        one_of_the_inputs_is_a_tensor = True
    elif not value_is_clear_scalar_integer(node.inputs[0]):
        is_convertible = False

    if value_is_tensor_integer(node.inputs[1]):
        one_of_the_inputs_is_a_tensor = True
    elif not value_is_scalar_integer(node.inputs[1]):
        is_convertible = False

    if not is_convertible:
        raise TypeError(
            f"Don't support subtraction between {str(node.inputs[0])} and {str(node.inputs[1])}"
        )

    if one_of_the_inputs_is_a_tensor:
        return _linalg_sub_int_eint(node, preds, ir_to_mlir_node, ctx)
    return _sub_int_eint(node, preds, ir_to_mlir_node, ctx)


def _sub_int_eint(node, preds, ir_to_mlir_node, ctx):
    """Convert a subtraction intermediate node with (int, eint)."""
    lhs_node, rhs_node = preds
    lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
    return hlfhe.SubIntEintOp(
        hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width),
        lhs,
        rhs,
    ).result


def _linalg_sub_int_eint(node, preds, ir_to_mlir_node, ctx):
    """Convert a subtraction intermediate tensor node with (int, eint)."""
    lhs_node, rhs_node = preds
    lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]

    if not str(lhs.type).startswith("tensor"):
        lhs = _convert_scalar_constant_op_to_single_element_tensor_constant_op(lhs)

    int_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width)
    vec_type = RankedTensorType.get(node.outputs[0].shape, int_type)

    return hlfhelinalg.SubIntEintOp(vec_type, lhs, rhs).result


def mul(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
    """Convert a multiplication intermediate node."""
    assert_true(len(node.inputs) == 2, "multiplication should have two inputs")
    assert_true(len(node.outputs) == 1, "multiplication should have a single output")

    is_convertible = True
    one_of_the_inputs_is_a_tensor = False
    ordered_preds = preds

    for input_ in node.inputs:
        if value_is_tensor_integer(input_):
            one_of_the_inputs_is_a_tensor = True
        elif not value_is_scalar_integer(input_):
            is_convertible = False

    if not is_convertible:
        raise TypeError(
            f"Don't support multiplication between {str(node.inputs[0])} and {str(node.inputs[1])}"
        )

    if node.inputs[0].is_clear:
        ordered_preds = preds[::-1]

    if one_of_the_inputs_is_a_tensor:
        return _linalg_mul_eint_int(node, ordered_preds, ir_to_mlir_node, ctx)
    return _mul_eint_int(node, ordered_preds, ir_to_mlir_node, ctx)


def _mul_eint_int(node, preds, ir_to_mlir_node, ctx):
    """Convert a multiplication intermediate node with (eint, int)."""
    lhs_node, rhs_node = preds
    lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
    return hlfhe.MulEintIntOp(
        hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width),
        lhs,
        rhs,
    ).result


def _linalg_mul_eint_int(node, preds, ir_to_mlir_node, ctx):
    """Convert a multiplication intermediate tensor node with (eint, int)."""
    lhs_node, rhs_node = preds
    lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]

    if not str(rhs.type).startswith("tensor"):
        rhs = _convert_scalar_constant_op_to_single_element_tensor_constant_op(rhs)

    int_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width)
    vec_type = RankedTensorType.get(node.outputs[0].shape, int_type)

    return hlfhelinalg.MulEintIntOp(vec_type, lhs, rhs).result


def constant(node, _preds, _ir_to_mlir_node, ctx, _additional_conversion_info=None):
    """Convert a constant input."""
    value = node.outputs[0]

    if value_is_clear_scalar_integer(value):
        value = cast(TensorValue, value)

        dtype = cast(Integer, value.dtype)
        data = node.constant_data

        int_type = IntegerType.get_signless(dtype.bit_width, context=ctx)
        return arith_dialect.ConstantOp(int_type, IntegerAttr.get(int_type, data)).result

    if value_is_clear_tensor_integer(value):
        value = cast(TensorValue, value)

        dtype = cast(Integer, value.dtype)
        data = node.constant_data

        int_type = IntegerType.get_signless(dtype.bit_width, context=ctx)
        vec_type = RankedTensorType.get(value.shape, int_type)

        # usage of `Attribute.parse` is the result of some limitations in the MLIR module
        # provided by LLVM

        # `DenseElementsAttr` should have been used instead, but it's impossible to assign
        # custom bit widths using it (e.g., uint5)

        # since we couldn't create a `DenseElementsAttr` with a custom bit width using the
        # python api, we use `Attribute.parse` to let the underlying library do it by itself

        value_attr = Attribute.parse(f"dense<{str(data.tolist())}> : {vec_type}")
        return arith_dialect.ConstantOp(vec_type, value_attr).result

    raise TypeError(f"Don't support {value} constants")


def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info):
    """Convert a GenericFunction intermediate node."""

    variable_input_indices = [
        idx for idx, pred in enumerate(preds) if not isinstance(pred, Constant)
    ]

    assert_true(
        (non_constant_pred_count := len(variable_input_indices)) == 1,
        f"LUT should have a single variable input (got {non_constant_pred_count})",
    )

    variable_input_idx = variable_input_indices[0]
    variable_input_value = node.inputs[variable_input_idx]

    assert_true(len(node.outputs) == 1, "LUT should have a single output")
    if not value_is_encrypted_unsigned_integer(variable_input_value):
        raise TypeError(
            f"Only support LUT with encrypted unsigned integers inputs "
            f"(but {variable_input_value} is provided)"
        )
    if not value_is_encrypted_unsigned_integer(node.outputs[0]):
        raise TypeError(
            f"Only support LUT with encrypted unsigned integers outputs "
            f"(but {node.outputs[0]} is provided)"
        )

    x_node = preds[variable_input_idx]
    x = ir_to_mlir_node[x_node]
    tables = additional_conversion_info["tables"][node]

    # TODO: #559 adapt the code to support multi TLUs
    # This cannot be reached today as compilation fails if the intermediate values are not all
    # scalars
    if len(tables) > 1:  # pragma: no cover
        raise RuntimeError(
            "MLIR conversion currently does not support multiple test vectors for LUT"
        )

    table = tables[0][0]

    out_dtype = cast(Integer, node.outputs[0].dtype)
    # Create the table
    dense_elem = DenseElementsAttr.get(numpy.array(table, dtype=numpy.uint64), context=ctx)
    tensor_lut = arith_dialect.ConstantOp(
        RankedTensorType.get([len(table)], IntegerType.get_signless(64, context=ctx)),
        dense_elem,
    ).result

    int_type = hlfhe.EncryptedIntegerType.get(ctx, out_dtype.bit_width)

    if value_is_encrypted_tensor_integer(node.inputs[0]):
        vec_type = RankedTensorType.get(node.outputs[0].shape, int_type)
        return hlfhelinalg.ApplyLookupTableEintOp(vec_type, x, tensor_lut).result
    return hlfhe.ApplyLookupTableEintOp(int_type, x, tensor_lut).result


def dot(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
    """Convert a dot intermediate node."""
    assert_true(len(node.inputs) == 2, "Dot should have two inputs")
    assert_true(len(node.outputs) == 1, "Dot should have a single output")
    if not (
        (
            value_is_encrypted_tensor_integer(node.inputs[0])
            and value_is_clear_tensor_integer(node.inputs[1])
        )
        or (
            value_is_encrypted_tensor_integer(node.inputs[1])
            and value_is_clear_tensor_integer(node.inputs[0])
        )
    ):
        raise TypeError(
            f"Don't support dot between {str(node.inputs[0])} and {str(node.inputs[1])}"
        )
    lhs_node, rhs_node = preds
    # need to flip since the underlying operation needs the encrypted operand first
    if value_is_clear_tensor_integer(node.inputs[0]):
        lhs_node, rhs_node = rhs_node, lhs_node
    lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
    return hlfhelinalg.Dot(
        hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].dtype.bit_width),
        lhs,
        rhs,
    ).result


V0_OPSET_CONVERSION_FUNCTIONS = {
    Add: add,
    Sub: sub,
    Mul: mul,
    Constant: constant,
    GenericFunction: apply_lut,
    Dot: dot,
}

# pylint: enable=no-name-in-module,no-member
concrete/common/mlir/graph_converter.py (new file, 79 lines)
@@ -0,0 +1,79 @@
"""Module that provides OPGraph conversion functionality."""

# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints
# pylint: disable=no-name-in-module

from abc import ABC, abstractmethod
from typing import Any, Dict

import networkx as nx
import zamalang
from mlir.dialects import builtin
from mlir.ir import Context, InsertionPoint, Location, Module

from ..operator_graph import OPGraph
from ..representation.intermediate import Input
from .conversion_helpers import value_to_mlir_type
from .node_converter import IntermediateNodeConverter

# pylint: enable=no-name-in-module


class OPGraphConverter(ABC):
    """Converter of OPGraph to MLIR."""

    def convert(self, opgraph: OPGraph) -> str:
        """Convert an operation graph to its corresponding MLIR representation.

        Args:
            opgraph (OPGraph): the operation graph to be converted

        Returns:
            str: textual MLIR representation corresponding to opgraph
        """

        additional_conversion_info = self._generate_additional_info_dict(opgraph)

        with Context() as ctx, Location.unknown():
            zamalang.register_dialects(ctx)

            module = Module.create()
            with InsertionPoint(module.body):
                parameters = [
                    value_to_mlir_type(ctx, input_node.outputs[0])
                    for input_node in opgraph.get_ordered_inputs()
                ]

                @builtin.FuncOp.from_py_func(*parameters)
                def main(*arg):
                    ir_to_mlir = {}
                    for arg_num, node in opgraph.input_nodes.items():
                        ir_to_mlir[node] = arg[arg_num]

                    for node in nx.topological_sort(opgraph.graph):
                        if isinstance(node, Input):
                            continue

                        preds = [ir_to_mlir[pred] for pred in opgraph.get_ordered_preds(node)]
                        node_converter = IntermediateNodeConverter(ctx, opgraph, node, preds)
                        ir_to_mlir[node] = node_converter.convert(additional_conversion_info)

                    results = (
                        ir_to_mlir[output_node] for output_node in opgraph.get_ordered_outputs()
                    )
                    return results

        return str(module)

    @staticmethod
    @abstractmethod
    def _generate_additional_info_dict(opgraph: OPGraph) -> Dict[str, Any]:
        """Generate the additional conversion info dict for the MLIR converter.

        Args:
            opgraph (OPGraph): the operation graph from which the additional info will be generated

        Returns:
            Dict[str, Any]: dict of additional conversion info
        """
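As a quick orientation, a minimal sketch of how this abstract class is meant to be used (mirroring the NPMLIRConverter subclass further down; the trivial subclass here is illustrative only and would suit graphs without table lookups):

    class TrivialConverter(OPGraphConverter):
        """Illustrative subclass that needs no extra conversion info."""

        @staticmethod
        def _generate_additional_info_dict(opgraph: OPGraph) -> Dict[str, Any]:
            return {}  # real subclasses add e.g. a "tables" entry here

    # mlir_text = TrivialConverter().convert(some_opgraph)  # yields textual MLIR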
@@ -1,194 +0,0 @@
"""File containing code to convert a DAG containing IR nodes to the compiler opset."""
# pylint: disable=no-name-in-module,no-member
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple, cast

import networkx as nx
import zamalang
from mlir.dialects import builtin
from mlir.ir import Context, InsertionPoint, IntegerType, Location, Module, RankedTensorType
from mlir.ir import Type as MLIRType
from zamalang.dialects import hlfhe

from .. import values
from ..data_types import Integer
from ..data_types.dtypes_helpers import (
    value_is_clear_scalar_integer,
    value_is_clear_tensor_integer,
    value_is_encrypted_scalar_unsigned_integer,
    value_is_encrypted_tensor_unsigned_integer,
)
from ..debugging.custom_assert import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import Input


class MLIRConverter(ABC):
    """Converter of the common IR to MLIR."""

    def __init__(self, conversion_functions: dict) -> None:
        """Instantiate a converter with a given set of converters.

        Args:
            conversion_functions (dict): mapping from common IR nodes to functions that
                generate MLIR. Every function should have 4 arguments:
                - node (IntermediateNode): the node to be converted
                - operands (IntermediateNode): predecessors of the node, ordered as operands
                - ir_to_mlir_node (dict): mapping between IntermediateNodes and their
                  equivalent MLIR values
                - context (mlir.Context): the MLIR context being used for the conversion
        """
        self.conversion_functions = conversion_functions
        self._init_context()

    def _init_context(self):
        self.context = Context()
        zamalang.register_dialects(self.context)

    def _get_tensor_type(
        self,
        bit_width: int,
        is_encrypted: bool,
        is_signed: bool,
        shape: Tuple[int, ...],
    ) -> MLIRType:
        """Get the MLIRType for a tensor element given its properties.

        Args:
            bit_width (int): number of bits used for the scalar
            is_encrypted (bool): whether the scalar is encrypted
            is_signed (bool): whether the scalar is signed
            shape (Tuple[int, ...]): shape of the tensor

        Returns:
            MLIRType: corresponding MLIR type
        """
        element_type = self._get_scalar_integer_type(bit_width, is_encrypted, is_signed)
        return RankedTensorType.get(shape, element_type)

    def _get_scalar_integer_type(
        self, bit_width: int, is_encrypted: bool, is_signed: bool
    ) -> MLIRType:
        """Get the MLIRType for a scalar element given its properties.

        Args:
            bit_width (int): number of bits used for the scalar
            is_encrypted (bool): whether the scalar is encrypted
            is_signed (bool): whether the scalar is signed

        Returns:
            MLIRType: corresponding MLIR type
        """
        if is_encrypted and not is_signed:
            return hlfhe.EncryptedIntegerType.get(self.context, bit_width)
        if is_signed and not is_encrypted:  # clear signed
            return IntegerType.get_signed(bit_width)
        # should be clear unsigned at this point
        assert_true(not is_signed and not is_encrypted)
        # unsigned integers are considered signless in the compiler
        return IntegerType.get_signless(bit_width)

    @staticmethod
    @abstractmethod
    def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]:
        """Generate the additional_conversion_info dict for the MLIR converter.

        Args:
            op_graph (OPGraph): the OPGraph for which we need the conversion info.

        Returns:
            Dict[str, Any]: the dict with the additional conversion info.
        """

    def common_value_to_mlir_type(self, value: values.BaseValue) -> MLIRType:
        """Convert a common value to its corresponding MLIR Type.

        Args:
            value: the value to convert

        Returns:
            the corresponding MLIR type
        """
        if value_is_encrypted_scalar_unsigned_integer(value):
            return self._get_scalar_integer_type(cast(Integer, value.dtype).bit_width, True, False)
        if value_is_clear_scalar_integer(value):
            dtype = cast(Integer, value.dtype)
            return self._get_scalar_integer_type(
                dtype.bit_width, is_encrypted=False, is_signed=dtype.is_signed
            )
        if value_is_encrypted_tensor_unsigned_integer(value):
            dtype = cast(Integer, value.dtype)
            return self._get_tensor_type(
                dtype.bit_width,
                is_encrypted=True,
                is_signed=False,
                shape=cast(values.TensorValue, value).shape,
            )
        if value_is_clear_tensor_integer(value):
            dtype = cast(Integer, value.dtype)
            return self._get_tensor_type(
                dtype.bit_width,
                is_encrypted=False,
                is_signed=dtype.is_signed,
                shape=cast(values.TensorValue, value).shape,
            )
        raise TypeError(f"can't convert value of type {type(value)} to MLIR type")

    def convert(self, op_graph: OPGraph) -> str:
        """Convert the graph of IntermediateNodes to an MLIR textual representation.

        Args:
            op_graph (OPGraph): graph of IntermediateNodes to be converted

        Raises:
            NotImplementedError: raised if an unknown node type is encountered

        Returns:
            str: textual MLIR representation
        """
        additional_conversion_info = self._generate_additional_info_dict(op_graph)

        with self.context, Location.unknown():
            module = Module.create()
            # collect inputs
            with InsertionPoint(module.body):
                func_types = [
                    self.common_value_to_mlir_type(input_node.inputs[0])
                    for input_node in op_graph.get_ordered_inputs()
                ]

                @builtin.FuncOp.from_py_func(*func_types)
                def main(*arg):
                    ir_to_mlir_node = {}
                    for arg_num, node in op_graph.input_nodes.items():
                        ir_to_mlir_node[node] = arg[arg_num]
                    for node in nx.topological_sort(op_graph.graph):
                        if isinstance(node, Input):
                            continue
                        mlir_op = self.conversion_functions.get(type(node), None)
                        if mlir_op is None:  # pragma: no cover
                            raise NotImplementedError(
                                f"we don't yet support conversion to MLIR of computations using "
                                f"{type(node)}"
                            )
                        preds = op_graph.get_ordered_preds(node)
                        # convert to mlir
                        result = mlir_op(
                            node,
                            preds,
                            ir_to_mlir_node,
                            self.context,
                            additional_conversion_info,
                        )
                        ir_to_mlir_node[node] = result

                    results = (
                        ir_to_mlir_node[output_node]
                        for output_node in op_graph.get_ordered_outputs()
                    )
                    return results

        return module.__str__()


# pylint: enable=no-name-in-module,no-member
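In short, this deleted class is superseded by OPGraphConverter plus IntermediateNodeConverter; as the compile.py hunk further down shows, the call site changes from:

    converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)  # old: converter functions passed in
    mlir_result = converter.convert(op_graph)

to:

    converter = NPMLIRConverter()  # new: dispatch lives in IntermediateNodeConverter
    mlir_result = converter.convert(op_graph)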
concrete/common/mlir/node_converter.py (new file, 325 lines)
@@ -0,0 +1,325 @@
"""Module that provides IntermediateNode conversion functionality."""

# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints
# pylint: disable=no-name-in-module

from typing import Any, Dict, List, cast

import numpy
from mlir.dialects import arith
from mlir.ir import (
    Attribute,
    Context,
    DenseElementsAttr,
    IntegerAttr,
    IntegerType,
    OpResult,
    RankedTensorType,
)
from zamalang.dialects import hlfhe, hlfhelinalg

from ..data_types import Integer
from ..debugging import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import (
    Add,
    Constant,
    Dot,
    GenericFunction,
    IntermediateNode,
    Mul,
    Sub,
)
from ..values import TensorValue
from .conversion_helpers import value_to_mlir_type

# pylint: enable=no-name-in-module


class IntermediateNodeConverter:
    """Converter of IntermediateNode to MLIR."""

    ctx: Context
    opgraph: OPGraph
    node: IntermediateNode
    preds: List[OpResult]

    all_of_the_inputs_are_encrypted: bool
    all_of_the_inputs_are_tensors: bool
    one_of_the_inputs_is_a_tensor: bool

    def __init__(
        self, ctx: Context, opgraph: OPGraph, node: IntermediateNode, preds: List[OpResult]
    ):
        self.ctx = ctx
        self.opgraph = opgraph
        self.node = node
        self.preds = preds

        self.all_of_the_inputs_are_encrypted = True
        self.all_of_the_inputs_are_tensors = True
        self.one_of_the_inputs_is_a_tensor = False

        for inp in node.inputs:
            if inp.is_clear:
                self.all_of_the_inputs_are_encrypted = False

            if isinstance(inp, TensorValue):
                if inp.is_scalar:
                    self.all_of_the_inputs_are_tensors = False
                else:
                    self.one_of_the_inputs_is_a_tensor = True
            else:  # pragma: no cover
                # this branch is not covered as there are only TensorValues for now
                self.all_of_the_inputs_are_tensors = False

    def convert(self, additional_conversion_info: Dict[str, Any]) -> OpResult:
        """Convert an intermediate node to its corresponding MLIR representation.

        Args:
            additional_conversion_info (Dict[str, Any]):
                external info that the converted node might need

        Returns:
            OpResult: the MLIR value corresponding to self.node
        """

        if isinstance(self.node, Add):
            return self.convert_add()

        if isinstance(self.node, Constant):
            return self.convert_constant()

        if isinstance(self.node, Dot):
            return self.convert_dot()

        if isinstance(self.node, GenericFunction):
            return self.convert_generic_function(additional_conversion_info)

        if isinstance(self.node, Mul):
            return self.convert_mul()

        if isinstance(self.node, Sub):
            return self.convert_sub()

        # this statement is not covered as unsupported operations fail the MLIR
        # compatibility check earlier in compilation
        raise NotImplementedError(
            f"{type(self.node)} nodes cannot be converted to MLIR yet"
        )  # pragma: no cover

    def convert_add(self) -> OpResult:
        """Convert an Add node to its corresponding MLIR representation.

        Returns:
            OpResult: the MLIR value corresponding to self.node
        """

        assert_true(len(self.node.inputs) == 2)
        assert_true(len(self.node.outputs) == 1)

        resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
        preds = self.preds

        if self.all_of_the_inputs_are_encrypted:
            if self.one_of_the_inputs_is_a_tensor:
                result = hlfhelinalg.AddEintOp(resulting_type, *preds).result
            else:
                result = hlfhe.AddEintOp(resulting_type, *preds).result
        else:
            if self.node.inputs[0].is_clear:  # pragma: no cover
                # this branch is not covered as it's impossible to get into due to how tracing
                # works; however, it doesn't hurt to keep it as an extra measure
                preds = preds[::-1]

            if self.one_of_the_inputs_is_a_tensor:
                result = hlfhelinalg.AddEintIntOp(resulting_type, *preds).result
            else:
                result = hlfhe.AddEintIntOp(resulting_type, *preds).result

        return result

    def convert_constant(self) -> OpResult:
        """Convert a Constant node to its corresponding MLIR representation.

        Returns:
            OpResult: the MLIR value corresponding to self.node
        """

        assert_true(len(self.node.inputs) == 0)
        assert_true(len(self.node.outputs) == 1)

        value = self.node.outputs[0]
        if not isinstance(value, TensorValue):  # pragma: no cover
            # this branch is not covered as there are only TensorValues for now
            raise NotImplementedError(f"{value} constants cannot be converted to MLIR yet")

        resulting_type = value_to_mlir_type(self.ctx, value)
        data = cast(Constant, self.node).constant_data

        if value.is_scalar:
            attr = IntegerAttr.get(resulting_type, data)
        else:
            # usage of `Attribute.parse` is the result of some limitations in the MLIR module
            # provided by LLVM

            # `DenseElementsAttr` should have been used instead, but it's impossible to assign
            # custom bit widths using it (e.g., uint5)

            # since we couldn't create a `DenseElementsAttr` with a custom bit width using the
            # python api, we use `Attribute.parse` to let the underlying library do it by itself

            attr = Attribute.parse(f"dense<{str(data.tolist())}> : {resulting_type}")

        return arith.ConstantOp(resulting_type, attr).result

    def convert_dot(self) -> OpResult:
        """Convert a Dot node to its corresponding MLIR representation.

        Returns:
            OpResult: the MLIR value corresponding to self.node
        """

        assert_true(len(self.node.inputs) == 2)
        assert_true(len(self.node.outputs) == 1)

        if self.all_of_the_inputs_are_encrypted:
            lhs = self.node.inputs[0]
            rhs = self.node.inputs[1]
            raise NotImplementedError(
                f"Dot product between {lhs} and {rhs} cannot be converted to MLIR yet",
            )

        resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
        preds = self.preds

        if self.node.inputs[0].is_clear:
            preds = preds[::-1]

        if self.all_of_the_inputs_are_tensors:
            # numpy.dot(x, y) where x and y are both vectors = regular dot product
            result = hlfhelinalg.Dot(resulting_type, *preds).result

        elif not self.one_of_the_inputs_is_a_tensor:
            # numpy.dot(x, y) where x and y are both scalars = x * y
            result = hlfhe.MulEintIntOp(resulting_type, *preds).result

        else:
            # numpy.dot(x, y) where one of x or y is a scalar and the other one is a vector = x * y
            result = hlfhelinalg.MulEintIntOp(resulting_type, *preds).result

        return result

    def convert_generic_function(self, additional_conversion_info: Dict[str, Any]) -> OpResult:
        """Convert a GenericFunction node to its corresponding MLIR representation.

        Returns:
            OpResult: the MLIR value corresponding to self.node
        """

        variable_input_indices = [
            idx
            for idx, inp in enumerate(self.opgraph.get_ordered_preds(self.node))
            if not isinstance(inp, Constant)
        ]
        if len(variable_input_indices) != 1:  # pragma: no cover
            # this branch is not covered as it's impossible to get into due to how tracing works;
            # however, it doesn't hurt to keep it as an extra measure
            raise NotImplementedError(
                "Table lookups with more than one variable input cannot be converted to MLIR yet"
            )
        variable_input_index = variable_input_indices[0]

        assert_true(len(self.node.outputs) == 1)

        value = self.node.inputs[variable_input_index]
        assert_true(value.is_encrypted)

        if not isinstance(value.dtype, Integer) or value.dtype.is_signed:  # pragma: no cover
            # this branch is not covered as it's impossible to get into due to how compilation
            # works; however, it doesn't hurt to keep it as an extra measure
            raise NotImplementedError(f"Table lookup on {value} cannot be converted to MLIR yet")

        tables = additional_conversion_info["tables"][self.node]

        # TODO: #559 adapt the code to support multi TLUs
        # This cannot be reached today as compilation fails
        # if the intermediate values are not all scalars
        if len(tables) > 1:  # pragma: no cover
            raise NotImplementedError("Multi table lookups cannot be converted to MLIR yet")

        table = tables[0][0]

        lut_size = len(table)
        lut_type = RankedTensorType.get([lut_size], IntegerType.get_signless(64, context=self.ctx))
        lut_attr = DenseElementsAttr.get(numpy.array(table, dtype=numpy.uint64), context=self.ctx)
        lut = arith.ConstantOp(lut_type, lut_attr).result

        resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
        pred = self.preds[variable_input_index]

        if self.one_of_the_inputs_is_a_tensor:
            result = hlfhelinalg.ApplyLookupTableEintOp(resulting_type, pred, lut).result
        else:
            result = hlfhe.ApplyLookupTableEintOp(resulting_type, pred, lut).result

        return result

    def convert_mul(self) -> OpResult:
        """Convert a Mul node to its corresponding MLIR representation.

        Returns:
            OpResult: the MLIR value corresponding to self.node
        """

        assert_true(len(self.node.inputs) == 2)
        assert_true(len(self.node.outputs) == 1)

        if self.all_of_the_inputs_are_encrypted:
            lhs = self.node.inputs[0]
            rhs = self.node.inputs[1]
            raise NotImplementedError(
                f"Multiplication between {lhs} and {rhs} cannot be converted to MLIR yet",
            )

        resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
        preds = self.preds

        if self.node.inputs[0].is_clear:  # pragma: no cover
            # this branch is not covered as it's impossible to get into due to how tracing works;
            # however, it doesn't hurt to keep it as an extra measure
            preds = preds[::-1]

        if self.one_of_the_inputs_is_a_tensor:
            result = hlfhelinalg.MulEintIntOp(resulting_type, *preds).result
        else:
            result = hlfhe.MulEintIntOp(resulting_type, *preds).result

        return result

    def convert_sub(self) -> OpResult:
        """Convert a Sub node to its corresponding MLIR representation.

        Returns:
            OpResult: the MLIR value corresponding to self.node
        """

        assert_true(len(self.node.inputs) == 2)
        assert_true(len(self.node.outputs) == 1)

        lhs = self.node.inputs[0]
        rhs = self.node.inputs[1]
        if not (lhs.is_clear and rhs.is_encrypted):
            raise NotImplementedError(
                f"Subtraction of {rhs} from {lhs} cannot be converted to MLIR yet",
            )

        resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
        preds = self.preds

        if self.one_of_the_inputs_is_a_tensor:
            result = hlfhelinalg.SubIntEintOp(resulting_type, *preds).result
        else:
            result = hlfhe.SubIntEintOp(resulting_type, *preds).result

        return result
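For reference, this class is driven node by node from OPGraphConverter.convert above; the pattern reduces to:

    preds = [ir_to_mlir[pred] for pred in opgraph.get_ordered_preds(node)]
    node_converter = IntermediateNodeConverter(ctx, opgraph, node, preds)
    ir_to_mlir[node] = node_converter.convert(additional_conversion_info)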
@@ -475,19 +475,20 @@ class Dot(IntermediateNode):

        assert_true(
            all(
                isinstance(input_value, TensorValue) and input_value.ndim == 1
                isinstance(input_value, TensorValue) and input_value.ndim <= 1
                for input_value in self.inputs
            ),
            f"Dot only supports two vectors ({TensorValue.__name__} with ndim == 1)",
            f"Dot only supports two scalars or vectors ({TensorValue.__name__} with ndim up to 1)",
        )

        lhs = cast(TensorValue, self.inputs[0])
        rhs = cast(TensorValue, self.inputs[1])

        assert_true(
            lhs.shape[0] == rhs.shape[0],
            f"Dot between vectors of shapes {lhs.shape} and {rhs.shape} is not supported",
        )
        if lhs.ndim == 1 and rhs.ndim == 1:
            assert_true(
                lhs.shape[0] == rhs.shape[0],
                f"Dot between vectors of shapes {lhs.shape} and {rhs.shape} is not supported",
            )

        output_scalar_value = (
            EncryptedScalar if (lhs.is_encrypted or rhs.is_encrypted) else ClearScalar
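The relaxed assertion (ndim <= 1 instead of ndim == 1) tracks numpy's own semantics, which the new convert_dot above maps onto MLIR case by case:

    import numpy

    numpy.dot(3, 4)                    # scalar . scalar -> 12, i.e. x * y
    numpy.dot([1, 2], [3, 4])          # vector . vector -> 11, the usual dot product
    numpy.dot(2, numpy.array([3, 4]))  # scalar . vector -> array([6, 8]), elementwise scaling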
@@ -14,7 +14,6 @@ from ..common.data_types import Integer
from ..common.debugging import format_operation_graph
from ..common.debugging.custom_assert import assert_true
from ..common.fhe_circuit import FHECircuit
from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS
from ..common.mlir.utils import (
    check_graph_values_compatibility_with_mlir,
    extend_direct_lookup_tables,
@@ -312,7 +311,7 @@ def _compile_numpy_function_internal(
    prepare_op_graph_for_mlir(op_graph)

    # Convert graph to an MLIR representation
    converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
    converter = NPMLIRConverter()
    mlir_result = converter.convert(op_graph)

    # Show MLIR representation if requested
@@ -8,7 +8,7 @@ from typing import Any, DefaultDict, Dict, List, Tuple
import numpy

from ..common.debugging import assert_true
from ..common.mlir.mlir_converter import MLIRConverter
from ..common.mlir.graph_converter import OPGraphConverter
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import GenericFunction, IntermediateNode
@@ -67,27 +67,18 @@ def generate_deduplicated_tables(
    )


class NPMLIRConverter(MLIRConverter):
class NPMLIRConverter(OPGraphConverter):
    """Numpy-specific MLIR converter."""

    @staticmethod
    def _generate_additional_info_dict(op_graph: OPGraph) -> Dict[str, Any]:
        """Generate the additional_conversion_info dict for the MLIR converter.

        Args:
            op_graph (OPGraph): the OPGraph for which we need the conversion info.

        Returns:
            Dict[str, Any]: the dict with the additional conversion info.
        """
    def _generate_additional_info_dict(opgraph: OPGraph) -> Dict[str, Any]:
        additional_conversion_info = {}

        # Disable numpy warnings during conversion to avoid issues during TLU generation
        with numpy.errstate(all="ignore"):
            additional_conversion_info["tables"] = {
                node: generate_deduplicated_tables(node, op_graph.get_ordered_preds(node))
                for node in op_graph.graph.nodes()
                node: generate_deduplicated_tables(node, opgraph.get_ordered_preds(node))
                for node in opgraph.graph.nodes()
                if isinstance(node, GenericFunction)
            }
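For orientation, a hedged sketch of the dict this subclass produces (the outer shape follows convert_generic_function above, which reads tables[0][0]; the exact per-entry layout returned by generate_deduplicated_tables is an assumption):

    additional_conversion_info = {
        "tables": {
            # one entry per GenericFunction node; tables[0][0] is the flat
            # list of values used to build the 64-bit constant LUT tensor
            generic_function_node: ((table_values, associated_input_info),),
        }
    }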
tests/common/mlir/test_conversion_helpers.py (new file, 115 lines)
@@ -0,0 +1,115 @@
"""Test file for MLIR conversion helpers."""

# pylint cannot extract symbol information of 'mlir' module so we need to disable some lints
# pylint: disable=no-name-in-module

import pytest
import zamalang
from mlir.ir import Context, Location

from concrete.common.data_types import Float, SignedInteger, UnsignedInteger
from concrete.common.mlir.conversion_helpers import integer_to_mlir_type, value_to_mlir_type
from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor

# pylint: enable=no-name-in-module


@pytest.mark.parametrize(
    "integer,is_encrypted,expected_mlir_type_str",
    [
        pytest.param(SignedInteger(5), False, "i5"),
        pytest.param(UnsignedInteger(5), False, "i5"),
        pytest.param(SignedInteger(32), False, "i32"),
        pytest.param(UnsignedInteger(32), False, "i32"),
        pytest.param(SignedInteger(5), True, "!HLFHE.eint<5>"),
        pytest.param(UnsignedInteger(5), True, "!HLFHE.eint<5>"),
    ],
)
def test_integer_to_mlir_type(integer, is_encrypted, expected_mlir_type_str):
    """Test function for integer to MLIR type conversion."""

    with Context() as ctx, Location.unknown():
        zamalang.register_dialects(ctx)
        assert str(integer_to_mlir_type(ctx, integer, is_encrypted)) == expected_mlir_type_str


@pytest.mark.parametrize(
    "integer,is_encrypted,expected_error_message",
    [
        pytest.param(SignedInteger(32), True, "can't create eint with the given width"),
        pytest.param(UnsignedInteger(32), True, "can't create eint with the given width"),
    ],
)
def test_fail_integer_to_mlir_type(integer, is_encrypted, expected_error_message):
    """Test function for failed integer to MLIR type conversion."""

    with pytest.raises(ValueError) as excinfo:
        with Context() as ctx, Location.unknown():
            zamalang.register_dialects(ctx)
            integer_to_mlir_type(ctx, integer, is_encrypted)

    assert str(excinfo.value) == expected_error_message


@pytest.mark.parametrize(
    "value,expected_mlir_type_str",
    [
        pytest.param(ClearScalar(SignedInteger(5)), "i5"),
        pytest.param(ClearTensor(SignedInteger(5), shape=(2, 3)), "tensor<2x3xi5>"),
        pytest.param(EncryptedScalar(SignedInteger(5)), "!HLFHE.eint<5>"),
        pytest.param(EncryptedTensor(SignedInteger(5), shape=(2, 3)), "tensor<2x3x!HLFHE.eint<5>>"),
        pytest.param(ClearScalar(UnsignedInteger(5)), "i5"),
        pytest.param(ClearTensor(UnsignedInteger(5), shape=(2, 3)), "tensor<2x3xi5>"),
        pytest.param(EncryptedScalar(UnsignedInteger(5)), "!HLFHE.eint<5>"),
        pytest.param(
            EncryptedTensor(UnsignedInteger(5), shape=(2, 3)), "tensor<2x3x!HLFHE.eint<5>>"
        ),
    ],
)
def test_value_to_mlir_type(value, expected_mlir_type_str):
    """Test function for value to MLIR type conversion."""

    with Context() as ctx, Location.unknown():
        zamalang.register_dialects(ctx)
        assert str(value_to_mlir_type(ctx, value)) == expected_mlir_type_str


@pytest.mark.parametrize(
    "value,expected_error_message",
    [
        pytest.param(
            ClearScalar(Float(32)),
            "ClearScalar<float32> is not supported for MLIR conversion",
        ),
        pytest.param(
            ClearTensor(Float(32), shape=(2, 3)),
            "ClearTensor<float32, shape=(2, 3)> is not supported for MLIR conversion",
        ),
        pytest.param(
            EncryptedScalar(Float(32)),
            "EncryptedScalar<float32> is not supported for MLIR conversion",
        ),
        pytest.param(
            EncryptedTensor(Float(32), shape=(2, 3)),
            "EncryptedTensor<float32, shape=(2, 3)> is not supported for MLIR conversion",
        ),
        pytest.param(
            EncryptedScalar(UnsignedInteger(32)),
            "EncryptedScalar<uint32> is not supported for MLIR conversion",
        ),
        pytest.param(
            EncryptedTensor(UnsignedInteger(32), shape=(2, 3)),
            "EncryptedTensor<uint32, shape=(2, 3)> is not supported for MLIR conversion",
        ),
    ],
)
def test_fail_value_to_mlir_type(value, expected_error_message):
    """Test function for failed value to MLIR type conversion."""

    with pytest.raises(TypeError) as excinfo:
        with Context() as ctx, Location.unknown():
            zamalang.register_dialects(ctx)
            value_to_mlir_type(ctx, value)

    assert str(excinfo.value) == expected_error_message
@@ -1,81 +0,0 @@
"""Test converter functions"""
import pytest

from concrete.common.data_types.floats import Float
from concrete.common.data_types.integers import Integer
from concrete.common.mlir.converters import add, apply_lut, constant, dot, mul, sub
from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar


class MockNode:
    """Mocking an intermediate node"""

    def __init__(self, inputs_n=5, outputs_n=5, inputs=None, outputs=None):
        if inputs is None:
            self.inputs = [None for i in range(inputs_n)]
        else:
            self.inputs = inputs
        if outputs is None:
            self.outputs = [None for i in range(outputs_n)]
        else:
            self.outputs = outputs


@pytest.mark.parametrize("converter", [add, sub, mul, dot])
def test_failing_converter(converter):
    """Test failing converter"""
    with pytest.raises(TypeError, match=r"Don't support .* between .* and .*"):
        converter(MockNode(2, 1), None, None, None)


def test_fail_non_integer_const():
    """Test failing constant converter with non-integer"""
    with pytest.raises(TypeError, match=r"Don't support .* constants"):
        constant(MockNode(outputs=[ClearScalar(Float(32))]), None, None, None)

    with pytest.raises(TypeError, match=r"Don't support .* constants"):
        constant(MockNode(outputs=[ClearTensor(Float(32), shape=(2,))]), None, None, None)


@pytest.mark.parametrize(
    "input_node",
    [
        ClearScalar(Integer(8, True)),
        ClearScalar(Integer(8, False)),
        EncryptedScalar(Integer(8, True)),
    ],
)
def test_fail_tlu_input(input_node):
    """Test failing LUT converter with invalid input"""
    with pytest.raises(
        TypeError, match=r"Only support LUT with encrypted unsigned integers inputs"
    ):
        apply_lut(
            MockNode(inputs=[input_node], outputs=[EncryptedScalar(Integer(8, False))]),
            [None],
            None,
            None,
            None,
        )


@pytest.mark.parametrize(
    "input_node",
    [
        ClearScalar(Integer(8, True)),
        ClearScalar(Integer(8, False)),
        EncryptedScalar(Integer(8, True)),
    ],
)
def test_fail_tlu_output(input_node):
    """Test failing LUT converter with invalid output"""
    with pytest.raises(
        TypeError, match=r"Only support LUT with encrypted unsigned integers outputs"
    ):
        apply_lut(
            MockNode(inputs=[EncryptedScalar(Integer(8, False))], outputs=[input_node]),
            [None],
            None,
            None,
            None,
        )
@@ -1,392 +0,0 @@
|
||||
"""Test file for conversion to MLIR"""
|
||||
# pylint: disable=no-name-in-module,no-member
|
||||
import itertools
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
from mlir.ir import IntegerType, Location, RankedTensorType, UnrankedTensorType
|
||||
from zamalang import compiler
|
||||
from zamalang.dialects import hlfhe
|
||||
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.extensions.table import LookupTable
|
||||
from concrete.common.mlir import V0_OPSET_CONVERSION_FUNCTIONS
|
||||
from concrete.common.values import ClearScalar, EncryptedScalar
|
||||
from concrete.common.values.tensors import ClearTensor, EncryptedTensor
|
||||
from concrete.numpy.compile import compile_numpy_function_into_op_graph, prepare_op_graph_for_mlir
|
||||
from concrete.numpy.np_mlir_converter import NPMLIRConverter
|
||||
|
||||
|
||||
def add(x, y):
|
||||
"""Test simple add"""
|
||||
return x + y
|
||||
|
||||
|
||||
def constant_add(x):
|
||||
"""Test constant add"""
|
||||
return x + 5
|
||||
|
||||
|
||||
def sub(x, y):
|
||||
"""Test simple sub"""
|
||||
return x - y
|
||||
|
||||
|
||||
def constant_sub(x):
|
||||
"""Test constant sub"""
|
||||
return 12 - x
|
||||
|
||||
|
||||
def mul(x, y):
|
||||
"""Test simple mul"""
|
||||
return x * y
|
||||
|
||||
|
||||
def constant_mul(x):
|
||||
"""Test constant mul"""
|
||||
return x * 2
|
||||
|
||||
|
||||
def sub_add_mul(x, y, z):
|
||||
"""Test combination of ops"""
|
||||
return z - y + x * z
|
||||
|
||||
|
||||
def ret_multiple(x, y, z):
|
||||
"""Test return of multiple values"""
|
||||
return x, y, z
|
||||
|
||||
|
||||
def ret_multiple_different_order(x, y, z):
|
||||
"""Test return of multiple values in a different order from input"""
|
||||
return y, z, x
|
||||
|
||||
|
||||
def lut(x):
|
||||
"""Test lookup table"""
|
||||
table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7])
|
||||
return table[x]
|
||||
|
||||
|
||||
# TODO: remove workaround #359
|
||||
def lut_more_bits_than_table_length(x, y):
|
||||
"""Test lookup table when bit_width support longer LUT"""
|
||||
table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7])
|
||||
return table[x] + y
|
||||
|
||||
|
||||
# TODO: remove workaround #359
|
||||
def lut_less_bits_than_table_length(x):
|
||||
"""Test lookup table when bit_width support smaller LUT"""
|
||||
table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7, 3, 6, 0, 2, 1, 4, 5, 7])
|
||||
return table[x]
|
||||
|
||||
|
||||
def dot(x, y):
|
||||
"""Test dot"""
|
||||
return numpy.dot(x, y)
|
||||
|
||||
|
||||
def datagen(*args):
|
||||
"""Generate data from ranges"""
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "func, args_dict, args_ranges",
    [
        (
            add,
            {
                "x": EncryptedScalar(Integer(64, is_signed=False)),
                "y": ClearScalar(Integer(32, is_signed=False)),
            },
            (range(0, 8), range(1, 4)),
        ),
        (
            constant_add,
            {
                "x": EncryptedScalar(Integer(64, is_signed=False)),
            },
            (range(0, 10),),
        ),
        (
            add,
            {
                "x": ClearScalar(Integer(32, is_signed=False)),
                "y": EncryptedScalar(Integer(64, is_signed=False)),
            },
            (range(0, 8), range(1, 4)),
        ),
        (
            add,
            {
                "x": EncryptedScalar(Integer(7, is_signed=False)),
                "y": EncryptedScalar(Integer(7, is_signed=False)),
            },
            (range(7, 15), range(1, 5)),
        ),
        (
            sub,
            {
                "x": ClearScalar(Integer(8, is_signed=False)),
                "y": EncryptedScalar(Integer(7, is_signed=False)),
            },
            (range(5, 10), range(2, 6)),
        ),
        (
            constant_sub,
            {
                "x": EncryptedScalar(Integer(64, is_signed=False)),
            },
            (range(0, 10),),
        ),
        (
            mul,
            {
                "x": EncryptedScalar(Integer(7, is_signed=False)),
                "y": ClearScalar(Integer(8, is_signed=False)),
            },
            (range(1, 5), range(2, 8)),
        ),
        (
            constant_mul,
            {
                "x": EncryptedScalar(Integer(64, is_signed=False)),
            },
            (range(0, 10),),
        ),
        (
            mul,
            {
                "x": ClearScalar(Integer(8, is_signed=False)),
                "y": EncryptedScalar(Integer(7, is_signed=False)),
            },
            (range(1, 5), range(2, 8)),
        ),
        (
            sub_add_mul,
            {
                "x": EncryptedScalar(Integer(7, is_signed=False)),
                "y": EncryptedScalar(Integer(7, is_signed=False)),
                "z": ClearScalar(Integer(7, is_signed=False)),
            },
            (range(0, 8), range(1, 5), range(5, 12)),
        ),
        (
            ret_multiple,
            {
                "x": EncryptedScalar(Integer(7, is_signed=False)),
                "y": EncryptedScalar(Integer(7, is_signed=False)),
                "z": ClearScalar(Integer(7, is_signed=False)),
            },
            (range(1, 5), range(1, 5), range(1, 5)),
        ),
        (
            ret_multiple_different_order,
            {
                "x": EncryptedScalar(Integer(7, is_signed=False)),
                "y": EncryptedScalar(Integer(7, is_signed=False)),
                "z": ClearScalar(Integer(7, is_signed=False)),
            },
            (range(1, 5), range(1, 5), range(1, 5)),
        ),
        (
            lut,
            {
                "x": EncryptedScalar(Integer(3, is_signed=False)),
            },
            (range(0, 8),),
        ),
        (
            lut_more_bits_than_table_length,
            {
                "x": EncryptedScalar(Integer(64, is_signed=False)),
                "y": EncryptedScalar(Integer(64, is_signed=False)),
            },
            (range(0, 8), range(0, 16)),
        ),
        (
            lut_less_bits_than_table_length,
            {
                "x": EncryptedScalar(Integer(3, is_signed=False)),
            },
            (range(0, 8),),
        ),
    ],
)
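

# note: the Integer(64, is_signed=False) declarations above are generous upper
# bounds; the bit widths actually used for the MLIR conversion are tightened
# from the values observed in the inputset (e.g. range(0, 8) fits in 3 bits)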
def test_mlir_converter(func, args_dict, args_ranges, default_compilation_configuration):
    """Test the conversion to MLIR by calling the parser from the compiler"""
    inputset = datagen(*args_ranges)
    result_graph = compile_numpy_function_into_op_graph(
        func,
        args_dict,
        inputset,
        default_compilation_configuration,
    )
    prepare_op_graph_for_mlir(result_graph)
    converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
    mlir_result = converter.convert(result_graph)
    # testing that this doesn't raise an error
    compiler.round_trip(mlir_result)


@pytest.mark.parametrize(
    "func, args_dict, args_ranges",
    [
        (
            dot,
            {
                "x": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)),
                "y": ClearTensor(Integer(64, is_signed=False), shape=(4,)),
            },
            (range(0, 4), range(0, 4)),
        ),
        (
            dot,
            {
                "x": ClearTensor(Integer(64, is_signed=False), shape=(4,)),
                "y": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)),
            },
            (range(0, 4), range(0, 4)),
        ),
    ],
)
def test_mlir_converter_dot_between_vectors(
    func, args_dict, args_ranges, default_compilation_configuration
):
    """Test the conversion to MLIR by calling the parser from the compiler"""
    assert len(args_dict["x"].shape) == 1
    assert len(args_dict["y"].shape) == 1

    n = args_dict["x"].shape[0]

    result_graph = compile_numpy_function_into_op_graph(
        func,
        args_dict,
        (
            (numpy.array([data[0]] * n), numpy.array([data[1]] * n))
            for data in datagen(*args_ranges)
        ),
        default_compilation_configuration,
    )
    prepare_op_graph_for_mlir(result_graph)
    converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
    mlir_result = converter.convert(result_graph)
    # testing that this doesn't raise an error
    compiler.round_trip(mlir_result)


def test_mlir_converter_dot_vector_and_constant(default_compilation_configuration):
    """Test the conversion to MLIR by calling the parser from the compiler"""

    def left_dot_with_constant(x):
        return numpy.dot(x, numpy.array([1, 2]))

    def right_dot_with_constant(x):
        return numpy.dot(numpy.array([1, 2]), x)

    left_graph = compile_numpy_function_into_op_graph(
        left_dot_with_constant,
        {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2,))},
        [(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)],
        default_compilation_configuration,
    )
    prepare_op_graph_for_mlir(left_graph)
    left_converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
    left_mlir = left_converter.convert(left_graph)

    right_graph = compile_numpy_function_into_op_graph(
        right_dot_with_constant,
        {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2,))},
        [(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)],
        default_compilation_configuration,
    )
    prepare_op_graph_for_mlir(right_graph)
    right_converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
    right_mlir = right_converter.convert(right_graph)

    # testing that this doesn't raise an error
    compiler.round_trip(left_mlir)
    compiler.round_trip(right_mlir)


def test_concrete_encrypted_integer_to_mlir_type():
    """Test conversion of EncryptedScalar into MLIR"""
    value = EncryptedScalar(Integer(7, is_signed=False))
    converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
    eint = converter.common_value_to_mlir_type(value)
    assert eint == hlfhe.EncryptedIntegerType.get(converter.context, 7)
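

# in textual IR the encrypted type above prints roughly as !HLFHE.eint<7>
# (an illustrative note; the test itself only checks type equality)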
@pytest.mark.parametrize("is_signed", [True, False])
def test_concrete_clear_integer_to_mlir_type(is_signed):
    """Test conversion of ClearScalar into MLIR"""
    value = ClearScalar(Integer(5, is_signed=is_signed))
    converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
    with converter.context:
        int_mlir = converter.common_value_to_mlir_type(value)
        if is_signed:
            assert int_mlir == IntegerType.get_signed(5)
        else:
            assert int_mlir == IntegerType.get_signless(5)


@pytest.mark.parametrize("is_signed", [True, False])
@pytest.mark.parametrize(
    "shape",
    [
        (5,),
        (5, 8),
        (-1, 5),
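        # -1 stands for a dimension that is not statically known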
    ],
)
def test_concrete_clear_tensor_integer_to_mlir_type(is_signed, shape):
    """Test conversion of ClearTensor into MLIR"""
    value = ClearTensor(Integer(5, is_signed=is_signed), shape)
    converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
    with converter.context, Location.unknown():
        tensor_mlir = converter.common_value_to_mlir_type(value)
        if is_signed:
            element_type = IntegerType.get_signed(5)
        else:
            element_type = IntegerType.get_signless(5)
        if shape is None:
            expected_type = UnrankedTensorType.get(element_type)
        else:
            expected_type = RankedTensorType.get(shape, element_type)
        assert tensor_mlir == expected_type


@pytest.mark.parametrize(
    "shape",
    [
        (5,),
        (5, 8),
        (-1, 5),
    ],
)
def test_concrete_encrypted_tensor_integer_to_mlir_type(shape):
    """Test conversion of EncryptedTensor into MLIR"""
    value = EncryptedTensor(Integer(6, is_signed=False), shape)
    converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
    with converter.context, Location.unknown():
        tensor_mlir = converter.common_value_to_mlir_type(value)
        element_type = hlfhe.EncryptedIntegerType.get(converter.context, 6)
        if shape is None:
            expected_type = UnrankedTensorType.get(element_type)
        else:
            expected_type = RankedTensorType.get(shape, element_type)
        assert tensor_mlir == expected_type


def test_failing_concrete_to_mlir_type():
    """Test failing conversion of an unsupported type into MLIR"""
    value = "random"
    converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
    with pytest.raises(TypeError, match=r"can't convert value of type .* to MLIR type"):
        converter.common_value_to_mlir_type(value)


# pylint: enable=no-name-in-module,no-member
85
tests/common/mlir/test_node_converter.py
Normal file
@@ -0,0 +1,85 @@
"""Test file for intermediate node to MLIR converter."""
|
||||
|
||||
import random
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types import UnsignedInteger
|
||||
from concrete.common.values import EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import compile_numpy_function
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "function_to_compile,parameters,inputset,expected_error_type,expected_error_message",
    [
        pytest.param(
            lambda x, y: x * y,
            {
                "x": EncryptedScalar(UnsignedInteger(3)),
                "y": EncryptedScalar(UnsignedInteger(3)),
            },
            [(random.randint(0, 7), random.randint(0, 7)) for _ in range(10)] + [(7, 7)],
            NotImplementedError,
            "Multiplication between EncryptedScalar<uint6> and EncryptedScalar<uint6> "
            "cannot be converted to MLIR yet",
        ),
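        # note: uint6 above comes from the inputset rather than the declared uint3;
        # the (7, 7) sample drives the product up to 7 * 7 == 49, which needs 6 bits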
        pytest.param(
            lambda x, y: x - y,
            {
                "x": EncryptedScalar(UnsignedInteger(3)),
                "y": EncryptedScalar(UnsignedInteger(3)),
            },
            [(random.randint(5, 7), random.randint(0, 5)) for _ in range(10)],
            NotImplementedError,
            "Subtraction of EncryptedScalar<uint3> from EncryptedScalar<uint3> "
            "cannot be converted to MLIR yet",
        ),
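        # the dot case below shows the same inputset-driven effect as uint7:
        # 7 * 7 + 7 * 7 == 98, which needs 7 bits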
        pytest.param(
            lambda x, y: numpy.dot(x, y),
            {
                "x": EncryptedTensor(UnsignedInteger(3), shape=(2,)),
                "y": EncryptedTensor(UnsignedInteger(3), shape=(2,)),
            },
            [
                (
                    numpy.random.randint(0, 2 ** 3, size=(2,)),
                    numpy.random.randint(0, 2 ** 3, size=(2,)),
                )
                for _ in range(10)
            ]
            + [(numpy.array([7, 7]), numpy.array([7, 7]))],
            NotImplementedError,
            "Dot product between EncryptedTensor<uint7, shape=(2,)> "
            "and EncryptedTensor<uint7, shape=(2,)> "
            "cannot be converted to MLIR yet",
        ),
    ],
)
def test_fail_node_conversion(
    function_to_compile,
    parameters,
    inputset,
    expected_error_type,
    expected_error_message,
    default_compilation_configuration,
):
    """Test function for failed intermediate node conversion."""

    with pytest.raises(expected_error_type) as excinfo:
        compile_numpy_function(
            function_to_compile, parameters, inputset, default_compilation_configuration
        )

    assert str(excinfo.value) == expected_error_message
@@ -519,6 +519,8 @@ def test_compile_function_multiple_outputs(
        pytest.param(lambda x, y: 100 - y + x, ((0, 20), (0, 20)), ["x", "y"]),
        pytest.param(lambda x, y: 50 - y * 2 + x, ((0, 20), (0, 20)), ["x", "y"]),
        pytest.param(lambda x: -x + 50, ((0, 20),), ["x"]),
        pytest.param(lambda x: numpy.dot(x, 2), ((0, 20),), ["x"]),
        pytest.param(lambda x: numpy.dot(2, x), ((0, 20),), ["x"]),
    ],
)
def test_compile_and_run_correctness(
@@ -548,6 +550,11 @@ def test_compile_and_run_correctness(
@pytest.mark.parametrize(
    "function,parameters,inputset,test_input,expected_output",
    [
        # TODO: find a way to support this case
        # https://github.com/zama-ai/concretefhe-internal/issues/837
        #
        # the problem is that the compiler doesn't support combining scalars and tensors,
        # but it does support broadcasting, so scalars should be converted to (1,) shaped tensors
        pytest.param(
            lambda x: x + 1,
            {
@@ -566,6 +573,7 @@ def test_compile_and_run_correctness(
                [7, 2],
                [3, 6],
            ],
            marks=pytest.mark.xfail(strict=True),
        ),
        pytest.param(
            lambda x: x + numpy.array([[1, 0], [2, 0], [3, 1]], dtype=numpy.uint32),
@@ -590,9 +598,7 @@ def test_compile_and_run_correctness(
        # https://github.com/zama-ai/concretefhe-internal/issues/837
        #
        # the problem is that the compiler doesn't support combining scalars and tensors,
        # but it does support broadcasting, so scalars can be converted to (1,) shaped tensors
        # this is easy with known constants but weird with variable things such as another input
        # there is tensor.from_elements but I couldn't figure out how to use it in the python API
        # but it does support broadcasting, so scalars should be converted to (1,) shaped tensors
        pytest.param(
            lambda x, y: x + y,
            {
@@ -619,7 +625,7 @@ def test_compile_and_run_correctness(
                [8, 3],
                [4, 7],
            ],
            marks=pytest.mark.xfail(),
            marks=pytest.mark.xfail(strict=True),
        ),
        pytest.param(
            lambda x, y: x + y,
@@ -652,6 +658,11 @@ def test_compile_and_run_correctness(
                [5, 9],
            ],
        ),
        # TODO: find a way to support this case
        # https://github.com/zama-ai/concretefhe-internal/issues/837
        #
        # the problem is that the compiler doesn't support combining scalars and tensors,
        # but it does support broadcasting, so scalars should be converted to (1,) shaped tensors
        pytest.param(
            lambda x: 100 - x,
            {
@@ -670,6 +681,7 @@ def test_compile_and_run_correctness(
                [94, 99],
                [98, 95],
            ],
            marks=pytest.mark.xfail(strict=True),
        ),
        pytest.param(
            lambda x: numpy.array([[10, 15], [20, 15], [10, 30]], dtype=numpy.uint32) - x,
@@ -690,6 +702,11 @@ def test_compile_and_run_correctness(
                [8, 25],
            ],
        ),
        # TODO: find a way to support this case
        # https://github.com/zama-ai/concretefhe-internal/issues/837
        #
        # the problem is that the compiler doesn't support combining scalars and tensors,
        # but it does support broadcasting, so scalars should be converted to (1,) shaped tensors
        pytest.param(
            lambda x: x * 2,
            {
@@ -708,6 +725,7 @@ def test_compile_and_run_correctness(
                [12, 2],
                [4, 10],
            ],
            marks=pytest.mark.xfail(strict=True),
        ),
        pytest.param(
            lambda x: x * numpy.array([[1, 2], [2, 1], [3, 1]], dtype=numpy.uint32),
@@ -747,6 +765,36 @@ def test_compile_and_run_correctness(
                [0, 2],
            ],
        ),
        # TODO: find a way to support this case
        # https://github.com/zama-ai/concretefhe-internal/issues/837
        #
        # the problem is that the compiler doesn't support combining scalars and tensors,
        # but it does support broadcasting, so scalars should be converted to (1,) shaped tensors
        pytest.param(
            lambda x: numpy.dot(x, 2),
            {
                "x": EncryptedTensor(UnsignedInteger(3), shape=(3,)),
            },
            [(numpy.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)],
            ([2, 7, 1],),
            [4, 14, 2],
            marks=pytest.mark.xfail(strict=True),
        ),
        # TODO: find a way to support this case
        # https://github.com/zama-ai/concretefhe-internal/issues/837
        #
        # the problem is that the compiler doesn't support combining scalars and tensors,
        # but it does support broadcasting, so scalars should be converted to (1,) shaped tensors
        pytest.param(
            lambda x: numpy.dot(2, x),
            {
                "x": EncryptedTensor(UnsignedInteger(3), shape=(3,)),
            },
            [(numpy.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)],
            ([2, 7, 1],),
            [4, 14, 2],
            marks=pytest.mark.xfail(strict=True),
        ),
    ],
)
def test_compile_and_run_tensor_correctness(
@@ -874,7 +922,7 @@ def test_compile_and_run_constant_dot_correctness(
        default_compilation_configuration,
    )
    right_circuit = compile_numpy_function(
        left,
        right,
        {"x": EncryptedTensor(Integer(64, False), shape)},
        inputset,
        default_compilation_configuration,