mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(mlir): MLIR Conversion (#103)
* feat(mlir): conversion from HDKIR to MLIR * feat(mlir): support ir.Sub and ir.Mul - better type conversion from HDK to MLIR - Context management inside the converter class - better handling of input type in conversion functions * refactor(mlir): use input and output from OPGraph Co-authored-by: Arthur Meyre <arthur.meyre@zama.ai> * feat(mlir): eint-int subtractions * feat(mlir): adhere to spec for supported ops * feat(OPGraph): getters for ordered inputs/outputs + formatting * tests(mlir): test converion via compiler roundtrip * fix(mlir): flip operands on int_eint sym ops * feat(mlir): check that the outputs are unsigned * feat(mlir): set bit_width of all nodes to the max This is currently required as the compiler is already assuming this. Could be removed from HDK when the compiler can do it on its own * feat: value_is_integer + CRs disable some linting errors * tests: update compile tests + coverage * refactor: reorganize mlir package + better doc * doc: conformance with pydocstyle Co-authored-by: Arthur Meyre <arthur.meyre@zama.ai>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
"""Module for data types code and data structures."""
|
||||
from . import dtypes_helpers, integers, values
|
||||
from .integers import Integer
|
||||
from .values import BaseValue
|
||||
|
||||
@@ -34,7 +34,7 @@ def value_is_encrypted_unsigned_integer(value_to_check: BaseValue) -> bool:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is an encrypted value of type Integer
|
||||
bool: True if the passed value_to_check is an encrypted value of type Integer and unsigned
|
||||
"""
|
||||
return (
|
||||
value_is_encrypted_integer(value_to_check)
|
||||
@@ -42,6 +42,32 @@ def value_is_encrypted_unsigned_integer(value_to_check: BaseValue) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def value_is_clear_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Helper function to check that a value is a clear integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is a clear value of type Integer
|
||||
"""
|
||||
return isinstance(value_to_check, ClearValue) and isinstance(
|
||||
value_to_check.data_type, INTEGER_TYPES
|
||||
)
|
||||
|
||||
|
||||
def value_is_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Helper function to check that a value is of type integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is a value of type Integer
|
||||
"""
|
||||
return isinstance(value_to_check.data_type, INTEGER_TYPES)
|
||||
|
||||
|
||||
def find_type_to_hold_both_lossy(
|
||||
dtype1: BaseDataType,
|
||||
dtype2: BaseDataType,
|
||||
|
||||
5
hdk/common/mlir/__init__.py
Normal file
5
hdk/common/mlir/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""MLIR conversion submodule."""
|
||||
from .converters import V0_OPSET_CONVERSION_FUNCTIONS
|
||||
from .mlir_converter import MLIRConverter
|
||||
|
||||
__all__ = ["MLIRConverter", "V0_OPSET_CONVERSION_FUNCTIONS"]
|
||||
118
hdk/common/mlir/converters.py
Normal file
118
hdk/common/mlir/converters.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Converter functions from HDKIR to MLIR.
|
||||
|
||||
Converter functions all have the same signature `converter(node, preds, ir_to_mlir_node, ctx)`
|
||||
- `node`: IntermediateNode to be converted
|
||||
- `preds`: List of predecessors of `node` ordered as operands
|
||||
- `ir_to_mlir_node`: Dict mapping intermediate nodes to MLIR nodes or values
|
||||
- `ctx`: MLIR context
|
||||
"""
|
||||
# pylint: disable=no-name-in-module,no-member
|
||||
from zamalang.dialects import hlfhe
|
||||
|
||||
from ..data_types.dtypes_helpers import (
|
||||
value_is_clear_integer,
|
||||
value_is_encrypted_unsigned_integer,
|
||||
)
|
||||
from ..representation import intermediate as ir
|
||||
|
||||
|
||||
def add(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Converter function for the addition intermediate node."""
|
||||
assert len(node.inputs) == 2, "addition should have two inputs"
|
||||
assert len(node.outputs) == 1, "addition should have a single output"
|
||||
if value_is_encrypted_unsigned_integer(node.inputs[0]) and value_is_clear_integer(
|
||||
node.inputs[1]
|
||||
):
|
||||
return _add_eint_int(node, preds, ir_to_mlir_node, ctx)
|
||||
if value_is_encrypted_unsigned_integer(node.inputs[1]) and value_is_clear_integer(
|
||||
node.inputs[0]
|
||||
):
|
||||
# flip lhs and rhs
|
||||
return _add_eint_int(node, preds[::-1], ir_to_mlir_node, ctx)
|
||||
if value_is_encrypted_unsigned_integer(node.inputs[0]) and value_is_encrypted_unsigned_integer(
|
||||
node.inputs[1]
|
||||
):
|
||||
return _add_eint_eint(node, preds, ir_to_mlir_node, ctx)
|
||||
raise TypeError(
|
||||
f"Don't support addition between {type(node.inputs[0])} and {type(node.inputs[1])}"
|
||||
)
|
||||
|
||||
|
||||
def _add_eint_int(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Converter function for the addition intermediate node with operands (eint, int)."""
|
||||
lhs_node, rhs_node = preds
|
||||
lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
|
||||
return hlfhe.AddEintIntOp(
|
||||
hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].data_type.bit_width),
|
||||
lhs,
|
||||
rhs,
|
||||
).result
|
||||
|
||||
|
||||
def _add_eint_eint(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Converter function for the addition intermediate node with operands (eint, int)."""
|
||||
lhs_node, rhs_node = preds
|
||||
lhs, rhs = lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
|
||||
return hlfhe.AddEintOp(
|
||||
hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].data_type.bit_width),
|
||||
lhs,
|
||||
rhs,
|
||||
).result
|
||||
|
||||
|
||||
def sub(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Converter function for the subtraction intermediate node."""
|
||||
assert len(node.inputs) == 2, "subtraction should have two inputs"
|
||||
assert len(node.outputs) == 1, "subtraction should have a single output"
|
||||
if value_is_clear_integer(node.inputs[0]) and value_is_encrypted_unsigned_integer(
|
||||
node.inputs[1]
|
||||
):
|
||||
return _sub_int_eint(node, preds, ir_to_mlir_node, ctx)
|
||||
raise TypeError(
|
||||
f"Don't support subtraction between {type(node.inputs[0])} and {type(node.inputs[1])}"
|
||||
)
|
||||
|
||||
|
||||
def _sub_int_eint(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Converter function for the subtraction intermediate node with operands (int, eint)."""
|
||||
lhs_node, rhs_node = preds
|
||||
lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
|
||||
return hlfhe.SubIntEintOp(
|
||||
hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].data_type.bit_width),
|
||||
lhs,
|
||||
rhs,
|
||||
).result
|
||||
|
||||
|
||||
def mul(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Converter function for the multiplication intermediate node."""
|
||||
assert len(node.inputs) == 2, "multiplication should have two inputs"
|
||||
assert len(node.outputs) == 1, "multiplication should have a single output"
|
||||
if value_is_encrypted_unsigned_integer(node.inputs[0]) and value_is_clear_integer(
|
||||
node.inputs[1]
|
||||
):
|
||||
return _mul_eint_int(node, preds, ir_to_mlir_node, ctx)
|
||||
if value_is_encrypted_unsigned_integer(node.inputs[1]) and value_is_clear_integer(
|
||||
node.inputs[0]
|
||||
):
|
||||
# flip lhs and rhs
|
||||
return _mul_eint_int(node, preds[::-1], ir_to_mlir_node, ctx)
|
||||
raise TypeError(
|
||||
f"Don't support multiplication between {type(node.inputs[0])} and {type(node.inputs[1])}"
|
||||
)
|
||||
|
||||
|
||||
def _mul_eint_int(node, preds, ir_to_mlir_node, ctx):
|
||||
"""Converter function for the multiplication intermediate node with operands (eint, int)."""
|
||||
lhs_node, rhs_node = preds
|
||||
lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
|
||||
return hlfhe.MulEintIntOp(
|
||||
hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].data_type.bit_width),
|
||||
lhs,
|
||||
rhs,
|
||||
).result
|
||||
|
||||
|
||||
V0_OPSET_CONVERSION_FUNCTIONS = {ir.Add: add, ir.Sub: sub, ir.Mul: mul}
|
||||
|
||||
# pylint: enable=no-name-in-module,no-member
|
||||
117
hdk/common/mlir/mlir_converter.py
Normal file
117
hdk/common/mlir/mlir_converter.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""File containing code to convert a DAG containing ir nodes to the compiler opset."""
|
||||
# pylint: disable=no-name-in-module,no-member
|
||||
from typing import cast
|
||||
|
||||
import networkx as nx
|
||||
import zamalang
|
||||
from mlir.dialects import builtin
|
||||
from mlir.ir import Context, InsertionPoint, IntegerType, Location, Module
|
||||
from mlir.ir import Type as MLIRType
|
||||
from zamalang.dialects import hlfhe
|
||||
|
||||
from .. import data_types
|
||||
from ..data_types import Integer
|
||||
from ..data_types.dtypes_helpers import (
|
||||
value_is_clear_integer,
|
||||
value_is_encrypted_unsigned_integer,
|
||||
)
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation import intermediate as ir
|
||||
|
||||
|
||||
class MLIRConverter:
|
||||
"""Converter of the HDKIR to MLIR."""
|
||||
|
||||
def __init__(self, conversion_functions: dict) -> None:
|
||||
"""Instantiate a converter with a given set of converters.
|
||||
|
||||
Args:
|
||||
conversion_functions (dict): mapping HDKIR nodes to functions that generate MLIR.
|
||||
every function should have 4 arguments:
|
||||
- node (IntermediateNode): the node itself to be converted
|
||||
- operands (IntermediateNode): predecessors of node ordered as operands
|
||||
- ir_to_mlir_node (dict): mapping between IntermediateNode and their equivalent
|
||||
MLIR values
|
||||
- context (mlir.Context): the MLIR context being used for the conversion
|
||||
"""
|
||||
self.conversion_functions = conversion_functions
|
||||
self._init_context()
|
||||
|
||||
def _init_context(self):
|
||||
self.context = Context()
|
||||
zamalang.register_dialects(self.context)
|
||||
|
||||
def hdk_value_to_mlir_type(self, value: data_types.BaseValue) -> MLIRType:
|
||||
"""Convert an HDK value to its corresponding MLIR Type.
|
||||
|
||||
Args:
|
||||
value: value to convert
|
||||
|
||||
Returns:
|
||||
corresponding MLIR type
|
||||
"""
|
||||
if value_is_encrypted_unsigned_integer(value):
|
||||
return hlfhe.EncryptedIntegerType.get(
|
||||
self.context, cast(Integer, value.data_type).bit_width
|
||||
)
|
||||
if value_is_clear_integer(value):
|
||||
dtype = cast(Integer, value.data_type)
|
||||
if dtype.is_signed:
|
||||
return IntegerType.get_signed(dtype.bit_width, context=self.context)
|
||||
return IntegerType.get_unsigned(dtype.bit_width, context=self.context)
|
||||
raise TypeError(f"can't convert value of type {type(value)} to MLIR type")
|
||||
|
||||
def convert(self, op_graph: OPGraph) -> str:
|
||||
"""Convert the graph of IntermediateNode to an MLIR textual representation.
|
||||
|
||||
Args:
|
||||
graph: graph of IntermediateNode to be converted
|
||||
|
||||
Returns:
|
||||
textual MLIR representation
|
||||
"""
|
||||
with self.context, Location.unknown():
|
||||
module = Module.create()
|
||||
# collect inputs
|
||||
with InsertionPoint(module.body):
|
||||
func_types = [
|
||||
self.hdk_value_to_mlir_type(input_node.inputs[0])
|
||||
for input_node in op_graph.get_ordered_inputs()
|
||||
]
|
||||
|
||||
@builtin.FuncOp.from_py_func(*func_types)
|
||||
def fhe_circuit(*arg):
|
||||
ir_to_mlir_node = {}
|
||||
for arg_num, node in op_graph.input_nodes.items():
|
||||
ir_to_mlir_node[node] = arg[arg_num]
|
||||
for node in nx.topological_sort(op_graph.graph):
|
||||
if isinstance(node, ir.Input):
|
||||
continue
|
||||
mlir_op = self.conversion_functions.get(type(node), None)
|
||||
if mlir_op is None: # pragma: no cover
|
||||
raise NotImplementedError(
|
||||
f"we don't yet support conversion to MLIR of computations using"
|
||||
f"{type(node)}"
|
||||
)
|
||||
# get sorted preds: sorted by their input index
|
||||
# replication of pred is possible (e.g lambda x: x + x)
|
||||
idx_to_pred = {}
|
||||
for pred in op_graph.graph.pred[node]:
|
||||
edge_data = op_graph.graph.get_edge_data(pred, node)
|
||||
for data in edge_data.values():
|
||||
idx_to_pred[data["input_idx"]] = pred
|
||||
preds = [idx_to_pred[i] for i in range(len(idx_to_pred))]
|
||||
# convert to mlir
|
||||
result = mlir_op(node, preds, ir_to_mlir_node, self.context)
|
||||
ir_to_mlir_node[node] = result
|
||||
|
||||
results = (
|
||||
ir_to_mlir_node[output_node]
|
||||
for output_node in op_graph.get_ordered_outputs()
|
||||
)
|
||||
return results
|
||||
|
||||
return module.__str__()
|
||||
|
||||
|
||||
# pylint: enable=no-name-in-module,no-member
|
||||
59
hdk/common/mlir/utils.py
Normal file
59
hdk/common/mlir/utils.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Utilities for MLIR conversion."""
|
||||
from typing import cast
|
||||
|
||||
from ..data_types import Integer
|
||||
from ..data_types.dtypes_helpers import (
|
||||
value_is_clear_integer,
|
||||
value_is_encrypted_integer,
|
||||
value_is_integer,
|
||||
)
|
||||
from ..operator_graph import OPGraph
|
||||
|
||||
|
||||
def is_graph_values_compatible_with_mlir(op_graph: OPGraph) -> bool:
|
||||
"""Make sure the graph outputs are unsigned integers, which is what the compiler supports.
|
||||
|
||||
Args:
|
||||
op_graph: computation graph to check
|
||||
|
||||
Returns:
|
||||
bool: is the graph compatible with the expected MLIR representation
|
||||
"""
|
||||
return all(
|
||||
all(
|
||||
value_is_integer(out) and not cast(Integer, out.data_type).is_signed
|
||||
for out in out_node.outputs
|
||||
)
|
||||
for out_node in op_graph.output_nodes.values()
|
||||
)
|
||||
|
||||
|
||||
def _set_all_bit_width(op_graph: OPGraph, p: int):
|
||||
"""Set all bit_width in the graph to `p` and `p+1` for clear and encrypted values respectively.
|
||||
|
||||
Args:
|
||||
op_graph: graph to set bit_width for
|
||||
p: bit_width to set everywhere
|
||||
"""
|
||||
for node in op_graph.graph.nodes:
|
||||
for value in node.outputs + node.inputs:
|
||||
if value_is_clear_integer(value):
|
||||
value.data_type.bit_width = p + 1
|
||||
elif value_is_encrypted_integer(value):
|
||||
value.data_type.bit_width = p
|
||||
|
||||
|
||||
def update_bit_width_for_mlir(op_graph: OPGraph):
|
||||
"""Prepare bit_width of all nodes to be the same, set to the maximum value in the graph.
|
||||
|
||||
Args:
|
||||
op_graph: graph to update bit_width for
|
||||
"""
|
||||
max_bit_width = 0
|
||||
for node in op_graph.graph.nodes:
|
||||
for value_out in node.outputs:
|
||||
if value_is_clear_integer(value_out):
|
||||
max_bit_width = max(max_bit_width, value_out.data_type.bit_width - 1)
|
||||
elif value_is_encrypted_integer(value_out):
|
||||
max_bit_width = max(max_bit_width, value_out.data_type.bit_width)
|
||||
_set_all_bit_width(op_graph, max_bit_width)
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Code to wrap and make manipulating networkx graphs easier."""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Iterable, Mapping
|
||||
from typing import Any, Dict, Iterable, List, Mapping
|
||||
|
||||
import networkx as nx
|
||||
|
||||
@@ -31,6 +31,22 @@ class OPGraph:
|
||||
if len(self.graph.pred[node]) == 0 and isinstance(node, ir.Input)
|
||||
}
|
||||
|
||||
def get_ordered_inputs(self) -> List[ir.Input]:
|
||||
"""Get the input nodes of the graph, ordered by their index.
|
||||
|
||||
Returns:
|
||||
List[ir.Input]: ordered input nodes
|
||||
"""
|
||||
return [self.input_nodes[idx] for idx in range(len(self.input_nodes))]
|
||||
|
||||
def get_ordered_outputs(self) -> List[ir.IntermediateNode]:
|
||||
"""Get the output nodes of the graph, ordered by their index.
|
||||
|
||||
Returns:
|
||||
List[ir.IntermediateNode]: ordered input nodes
|
||||
"""
|
||||
return [self.output_nodes[idx] for idx in range(len(self.output_nodes))]
|
||||
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Dict[ir.IntermediateNode, Any]:
|
||||
"""Function to evaluate a graph and get intermediate values for all nodes.
|
||||
|
||||
@@ -69,7 +85,10 @@ class OPGraph:
|
||||
|
||||
for node in self.graph.nodes():
|
||||
current_node_bounds = node_bounds[node]
|
||||
min_bound, max_bound = current_node_bounds["min"], current_node_bounds["max"]
|
||||
min_bound, max_bound = (
|
||||
current_node_bounds["min"],
|
||||
current_node_bounds["max"],
|
||||
)
|
||||
|
||||
if not isinstance(node, ir.Input):
|
||||
for output_value in node.outputs:
|
||||
|
||||
@@ -7,6 +7,10 @@ from hdk.hnumpy.tracing import trace_numpy_function
|
||||
|
||||
from ..common.compilation import CompilationArtifacts
|
||||
from ..common.data_types import BaseValue
|
||||
from ..common.mlir.utils import (
|
||||
is_graph_values_compatible_with_mlir,
|
||||
update_bit_width_for_mlir,
|
||||
)
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..hnumpy.tracing import trace_numpy_function
|
||||
|
||||
@@ -42,6 +46,13 @@ def compile_numpy_function(
|
||||
# Update the graph accordingly: after that, we have the compilable graph
|
||||
op_graph.update_values_with_bounds(node_bounds)
|
||||
|
||||
# Make sure the graph can be lowered to MLIR
|
||||
if not is_graph_values_compatible_with_mlir(op_graph):
|
||||
raise TypeError("signed integers aren't supported for MLIR lowering")
|
||||
|
||||
# Update bit_width for MLIR
|
||||
update_bit_width_for_mlir(op_graph)
|
||||
|
||||
# Fill compilation artifacts
|
||||
if compilation_artifacts is not None:
|
||||
compilation_artifacts.operation_graph = op_graph
|
||||
|
||||
19
tests/common/mlir/test_converters.py
Normal file
19
tests/common/mlir/test_converters.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Test converter functions"""
|
||||
import pytest
|
||||
|
||||
from hdk.common.mlir.converters import add, mul, sub
|
||||
|
||||
|
||||
class MockNode:
|
||||
"""Mocking an intermediate node"""
|
||||
|
||||
def __init__(self, inputs=5, outputs=5):
|
||||
self.inputs = [None for i in range(inputs)]
|
||||
self.outputs = [None for i in range(outputs)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("converter", [add, sub, mul])
|
||||
def test_failing_converter(converter):
|
||||
"""Test failing converter"""
|
||||
with pytest.raises(TypeError, match=r"Don't support .* between .* and .*"):
|
||||
converter(MockNode(2, 1), None, None, None)
|
||||
171
tests/common/mlir/test_mlir_converter.py
Normal file
171
tests/common/mlir/test_mlir_converter.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""Test file for conversion to MLIR"""
|
||||
# pylint: disable=no-name-in-module,no-member
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
from mlir.ir import IntegerType
|
||||
from zamalang import compiler
|
||||
from zamalang.dialects import hlfhe
|
||||
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import ClearValue, EncryptedValue
|
||||
from hdk.common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
|
||||
from hdk.hnumpy.compile import compile_numpy_function
|
||||
|
||||
|
||||
def add(x, y):
|
||||
"""Test simple add"""
|
||||
return x + y
|
||||
|
||||
|
||||
def sub(x, y):
|
||||
"""Test simple sub"""
|
||||
return x - y
|
||||
|
||||
|
||||
def mul(x, y):
|
||||
"""Test simple mul"""
|
||||
return x * y
|
||||
|
||||
|
||||
def sub_add_mul(x, y, z):
|
||||
"""Test combination of ops"""
|
||||
return z - y + x * z
|
||||
|
||||
|
||||
def ret_multiple(x, y, z):
|
||||
"""Test return of multiple values"""
|
||||
return x, y, z
|
||||
|
||||
|
||||
def ret_multiple_different_order(x, y, z):
|
||||
"""Test return of multiple values in a different order from input"""
|
||||
return y, z, x
|
||||
|
||||
|
||||
def datagen(*args):
|
||||
"""Generate data from ranges"""
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"func, args_dict, args_ranges",
|
||||
[
|
||||
(
|
||||
add,
|
||||
{
|
||||
"x": EncryptedValue(Integer(64, is_signed=False)),
|
||||
"y": ClearValue(Integer(32, is_signed=False)),
|
||||
},
|
||||
(range(0, 8), range(1, 4)),
|
||||
),
|
||||
(
|
||||
add,
|
||||
{
|
||||
"x": ClearValue(Integer(32, is_signed=False)),
|
||||
"y": EncryptedValue(Integer(64, is_signed=False)),
|
||||
},
|
||||
(range(0, 8), range(1, 4)),
|
||||
),
|
||||
(
|
||||
add,
|
||||
{
|
||||
"x": EncryptedValue(Integer(7, is_signed=False)),
|
||||
"y": EncryptedValue(Integer(7, is_signed=False)),
|
||||
},
|
||||
(range(7, 15), range(1, 5)),
|
||||
),
|
||||
(
|
||||
sub,
|
||||
{
|
||||
"x": ClearValue(Integer(8, is_signed=False)),
|
||||
"y": EncryptedValue(Integer(7, is_signed=False)),
|
||||
},
|
||||
(range(5, 10), range(2, 6)),
|
||||
),
|
||||
(
|
||||
mul,
|
||||
{
|
||||
"x": EncryptedValue(Integer(7, is_signed=False)),
|
||||
"y": ClearValue(Integer(8, is_signed=False)),
|
||||
},
|
||||
(range(1, 5), range(2, 8)),
|
||||
),
|
||||
(
|
||||
mul,
|
||||
{
|
||||
"x": ClearValue(Integer(8, is_signed=False)),
|
||||
"y": EncryptedValue(Integer(7, is_signed=False)),
|
||||
},
|
||||
(range(1, 5), range(2, 8)),
|
||||
),
|
||||
(
|
||||
sub_add_mul,
|
||||
{
|
||||
"x": EncryptedValue(Integer(7, is_signed=False)),
|
||||
"y": EncryptedValue(Integer(7, is_signed=False)),
|
||||
"z": ClearValue(Integer(7, is_signed=False)),
|
||||
},
|
||||
(range(0, 8), range(1, 5), range(5, 12)),
|
||||
),
|
||||
(
|
||||
ret_multiple,
|
||||
{
|
||||
"x": EncryptedValue(Integer(7, is_signed=False)),
|
||||
"y": EncryptedValue(Integer(7, is_signed=False)),
|
||||
"z": ClearValue(Integer(7, is_signed=False)),
|
||||
},
|
||||
(range(1, 5), range(1, 5), range(1, 5)),
|
||||
),
|
||||
(
|
||||
ret_multiple_different_order,
|
||||
{
|
||||
"x": EncryptedValue(Integer(7, is_signed=False)),
|
||||
"y": EncryptedValue(Integer(7, is_signed=False)),
|
||||
"z": ClearValue(Integer(7, is_signed=False)),
|
||||
},
|
||||
(range(1, 5), range(1, 5), range(1, 5)),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_mlir_converter(func, args_dict, args_ranges):
|
||||
"""Test the conversion to MLIR by calling the parser from the compiler"""
|
||||
dataset = datagen(*args_ranges)
|
||||
result_graph = compile_numpy_function(func, args_dict, dataset)
|
||||
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
mlir_result = converter.convert(result_graph)
|
||||
# testing that this doesn't raise an error
|
||||
compiler.round_trip(mlir_result)
|
||||
|
||||
|
||||
def test_hdk_encrypted_integer_to_mlir_type():
|
||||
"""Test conversion of EncryptedValue into MLIR"""
|
||||
value = EncryptedValue(Integer(7, is_signed=False))
|
||||
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
eint = converter.hdk_value_to_mlir_type(value)
|
||||
assert eint == hlfhe.EncryptedIntegerType.get(converter.context, 7)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_signed", [True, False])
|
||||
def test_hdk_clear_integer_to_mlir_type(is_signed):
|
||||
"""Test conversion of ClearValue into MLIR"""
|
||||
value = ClearValue(Integer(5, is_signed=is_signed))
|
||||
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
int_mlir = converter.hdk_value_to_mlir_type(value)
|
||||
with converter.context:
|
||||
if is_signed:
|
||||
assert int_mlir == IntegerType.get_signed(5)
|
||||
else:
|
||||
assert int_mlir == IntegerType.get_unsigned(5)
|
||||
|
||||
|
||||
def test_failing_hdk_to_mlir_type():
|
||||
"""Test failing conversion of an unsupported type into MLIR"""
|
||||
value = "random"
|
||||
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
with pytest.raises(TypeError, match=r"can't convert value of type .* to MLIR type"):
|
||||
converter.hdk_value_to_mlir_type(value)
|
||||
|
||||
|
||||
# pylint: enable=no-name-in-module,no-member
|
||||
@@ -14,11 +14,11 @@ from hdk.hnumpy.compile import compile_numpy_function
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
[
|
||||
pytest.param(lambda x: x + 42, ((-2, 2),), ["x"]),
|
||||
pytest.param(lambda x, y: x + y + 8, ((-10, 2), (-4, 6)), ["x", "y"]),
|
||||
pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]),
|
||||
pytest.param(lambda x, y: (x + 1, y + 10), ((-1, 1), (3, 4)), ["x", "y"]),
|
||||
pytest.param(
|
||||
lambda x, y, z: (x + y + 1 - z, x * y + 42, z, z + 99),
|
||||
((-1, 1), (3, 4), (10, 20)),
|
||||
((4, 8), (3, 4), (0, 4)),
|
||||
["x", "y", "z"],
|
||||
),
|
||||
],
|
||||
@@ -80,3 +80,28 @@ def test_compile_function_with_direct_tlu_overflow():
|
||||
{"x": EncryptedValue(Integer(3, is_signed=False))},
|
||||
iter([(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,)]),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
[
|
||||
pytest.param(lambda x: x - 10, ((-2, 2),), ["x"]),
|
||||
],
|
||||
)
|
||||
def test_fail_compile(function, input_ranges, list_of_arg_names):
|
||||
"""Test function compile_numpy_function for a program with signed values"""
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {
|
||||
arg_name: EncryptedValue(Integer(64, True)) for arg_name in list_of_arg_names
|
||||
}
|
||||
|
||||
with pytest.raises(TypeError, match=r"signed integers aren't supported for MLIR lowering"):
|
||||
compile_numpy_function(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user