feat(mlir): MLIR Conversion (#103)

* feat(mlir): conversion from HDKIR to MLIR

* feat(mlir): support ir.Sub and ir.Mul

- better type conversion from HDK to MLIR
- Context management inside the converter class
- better handling of input type in conversion functions

* refactor(mlir): use input and output from OPGraph

Co-authored-by: Arthur Meyre <arthur.meyre@zama.ai>

* feat(mlir): eint-int subtractions

* feat(mlir): adhere to spec for supported ops

* feat(OPGraph): getters for ordered inputs/outputs

+ formatting

* tests(mlir): test conversion via compiler roundtrip

* fix(mlir): flip operands on int_eint sym ops

* feat(mlir): check that the outputs are unsigned

* feat(mlir): set bit_width of all nodes to the max

This is currently required as the compiler is already assuming this.
Could be removed from HDK when the compiler can do it on its own

* feat: value_is_integer + CRs

disable some linting errors

* tests: update compile tests + coverage

* refactor: reorganize mlir package + better doc

* doc: conformance with pydocstyle

Co-authored-by: Arthur Meyre <arthur.meyre@zama.ai>
This commit is contained in:
Ayoub Benaissa
2021-08-13 12:50:31 +01:00
committed by GitHub
parent 2c3c080923
commit f6c9618b5a
11 changed files with 576 additions and 5 deletions

View File

@@ -1,3 +1,4 @@
"""Module for data types code and data structures."""
from . import dtypes_helpers, integers, values
from .integers import Integer
from .values import BaseValue

View File

@@ -34,7 +34,7 @@ def value_is_encrypted_unsigned_integer(value_to_check: BaseValue) -> bool:
value_to_check (BaseValue): The value to check
Returns:
bool: True if the passed value_to_check is an encrypted value of type Integer
bool: True if the passed value_to_check is an encrypted value of type Integer and unsigned
"""
return (
value_is_encrypted_integer(value_to_check)
@@ -42,6 +42,32 @@ def value_is_encrypted_unsigned_integer(value_to_check: BaseValue) -> bool:
)
def value_is_clear_integer(value_to_check: BaseValue) -> bool:
    """Tell whether a value is a clear (non-encrypted) integer.

    Args:
        value_to_check (BaseValue): The value to check

    Returns:
        bool: True if value_to_check is a ClearValue whose data_type is one of INTEGER_TYPES
    """
    # Anything that is not a ClearValue is rejected before its data_type is looked at.
    if not isinstance(value_to_check, ClearValue):
        return False
    return isinstance(value_to_check.data_type, INTEGER_TYPES)
def value_is_integer(value_to_check: BaseValue) -> bool:
    """Tell whether a value carries an integer data type, be it clear or encrypted.

    Args:
        value_to_check (BaseValue): The value to check

    Returns:
        bool: True if the data_type of value_to_check is one of INTEGER_TYPES
    """
    data_type = value_to_check.data_type
    return isinstance(data_type, INTEGER_TYPES)
def find_type_to_hold_both_lossy(
dtype1: BaseDataType,
dtype2: BaseDataType,

View File

@@ -0,0 +1,5 @@
"""MLIR conversion submodule."""
from .converters import V0_OPSET_CONVERSION_FUNCTIONS
from .mlir_converter import MLIRConverter
__all__ = ["MLIRConverter", "V0_OPSET_CONVERSION_FUNCTIONS"]

View File

@@ -0,0 +1,118 @@
"""Converter functions from HDKIR to MLIR.
Converter functions all have the same signature `converter(node, preds, ir_to_mlir_node, ctx)`
- `node`: IntermediateNode to be converted
- `preds`: List of predecessors of `node` ordered as operands
- `ir_to_mlir_node`: Dict mapping intermediate nodes to MLIR nodes or values
- `ctx`: MLIR context
"""
# pylint: disable=no-name-in-module,no-member
from zamalang.dialects import hlfhe
from ..data_types.dtypes_helpers import (
value_is_clear_integer,
value_is_encrypted_unsigned_integer,
)
from ..representation import intermediate as ir
def add(node, preds, ir_to_mlir_node, ctx):
    """Convert an addition intermediate node, dispatching on the kind of its two inputs.

    Supported combinations are (eint, int), (int, eint) and (eint, eint), where eint is an
    encrypted unsigned integer and int a clear integer.

    Raises:
        TypeError: if the inputs are not one of the supported combinations
    """
    assert len(node.inputs) == 2, "addition should have two inputs"
    assert len(node.outputs) == 1, "addition should have a single output"

    left, right = node.inputs[0], node.inputs[1]

    if value_is_encrypted_unsigned_integer(left) and value_is_clear_integer(right):
        return _add_eint_int(node, preds, ir_to_mlir_node, ctx)

    if value_is_encrypted_unsigned_integer(right) and value_is_clear_integer(left):
        # The (eint, int) helper takes the encrypted operand first: swap the predecessors.
        swapped_preds = preds[::-1]
        return _add_eint_int(node, swapped_preds, ir_to_mlir_node, ctx)

    if value_is_encrypted_unsigned_integer(left) and value_is_encrypted_unsigned_integer(right):
        return _add_eint_eint(node, preds, ir_to_mlir_node, ctx)

    raise TypeError(f"Don't support addition between {type(left)} and {type(right)}")
def _add_eint_int(node, preds, ir_to_mlir_node, ctx):
    """Build the MLIR addition of an encrypted integer and a clear integer, in that order."""
    encrypted_pred, clear_pred = preds
    encrypted_operand = ir_to_mlir_node[encrypted_pred]
    clear_operand = ir_to_mlir_node[clear_pred]
    # The result is an encrypted integer with the bit width of the node's single output.
    result_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].data_type.bit_width)
    operation = hlfhe.AddEintIntOp(result_type, encrypted_operand, clear_operand)
    return operation.result
def _add_eint_eint(node, preds, ir_to_mlir_node, ctx):
    """Converter function for the addition intermediate node with operands (eint, eint).

    Args:
        node: addition node being converted; its first output gives the result bit width
        preds: the two predecessors of `node`, ordered as operands
        ir_to_mlir_node: dict mapping intermediate nodes to their MLIR values
        ctx: MLIR context

    Returns:
        the result value of the created hlfhe.AddEintOp
    """
    lhs_node, rhs_node = preds
    lhs, rhs = ir_to_mlir_node[lhs_node], ir_to_mlir_node[rhs_node]
    return hlfhe.AddEintOp(
        hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].data_type.bit_width),
        lhs,
        rhs,
    ).result
def sub(node, preds, ir_to_mlir_node, ctx):
    """Convert a subtraction intermediate node.

    Only the (int, eint) combination is handled: a clear integer minus an encrypted unsigned
    integer.

    Raises:
        TypeError: if the inputs are not a clear integer followed by an encrypted unsigned integer
    """
    assert len(node.inputs) == 2, "subtraction should have two inputs"
    assert len(node.outputs) == 1, "subtraction should have a single output"

    left, right = node.inputs[0], node.inputs[1]

    is_int_minus_eint = value_is_clear_integer(left) and value_is_encrypted_unsigned_integer(right)
    if not is_int_minus_eint:
        raise TypeError(f"Don't support subtraction between {type(left)} and {type(right)}")

    return _sub_int_eint(node, preds, ir_to_mlir_node, ctx)
def _sub_int_eint(node, preds, ir_to_mlir_node, ctx):
    """Build the MLIR subtraction of an encrypted integer from a clear integer."""
    clear_pred, encrypted_pred = preds
    clear_operand = ir_to_mlir_node[clear_pred]
    encrypted_operand = ir_to_mlir_node[encrypted_pred]
    # The result is an encrypted integer with the bit width of the node's single output.
    result_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].data_type.bit_width)
    operation = hlfhe.SubIntEintOp(result_type, clear_operand, encrypted_operand)
    return operation.result
def mul(node, preds, ir_to_mlir_node, ctx):
    """Convert a multiplication intermediate node, dispatching on the kind of its two inputs.

    Supported combinations are (eint, int) and (int, eint), where eint is an encrypted unsigned
    integer and int a clear integer.

    Raises:
        TypeError: if the inputs are not one of the supported combinations
    """
    assert len(node.inputs) == 2, "multiplication should have two inputs"
    assert len(node.outputs) == 1, "multiplication should have a single output"

    left, right = node.inputs[0], node.inputs[1]

    if value_is_encrypted_unsigned_integer(left) and value_is_clear_integer(right):
        return _mul_eint_int(node, preds, ir_to_mlir_node, ctx)

    if value_is_encrypted_unsigned_integer(right) and value_is_clear_integer(left):
        # The (eint, int) helper takes the encrypted operand first: swap the predecessors.
        swapped_preds = preds[::-1]
        return _mul_eint_int(node, swapped_preds, ir_to_mlir_node, ctx)

    raise TypeError(f"Don't support multiplication between {type(left)} and {type(right)}")
def _mul_eint_int(node, preds, ir_to_mlir_node, ctx):
    """Build the MLIR multiplication of an encrypted integer by a clear integer, in that order."""
    encrypted_pred, clear_pred = preds
    encrypted_operand = ir_to_mlir_node[encrypted_pred]
    clear_operand = ir_to_mlir_node[clear_pred]
    # The result is an encrypted integer with the bit width of the node's single output.
    result_type = hlfhe.EncryptedIntegerType.get(ctx, node.outputs[0].data_type.bit_width)
    operation = hlfhe.MulEintIntOp(result_type, encrypted_operand, clear_operand)
    return operation.result
# Dispatch table: intermediate node class -> converter function producing its MLIR value.
V0_OPSET_CONVERSION_FUNCTIONS = {
    ir.Add: add,
    ir.Sub: sub,
    ir.Mul: mul,
}
# pylint: enable=no-name-in-module,no-member

View File

@@ -0,0 +1,117 @@
"""File containing code to convert a DAG containing ir nodes to the compiler opset."""
# pylint: disable=no-name-in-module,no-member
from typing import cast
import networkx as nx
import zamalang
from mlir.dialects import builtin
from mlir.ir import Context, InsertionPoint, IntegerType, Location, Module
from mlir.ir import Type as MLIRType
from zamalang.dialects import hlfhe
from .. import data_types
from ..data_types import Integer
from ..data_types.dtypes_helpers import (
value_is_clear_integer,
value_is_encrypted_unsigned_integer,
)
from ..operator_graph import OPGraph
from ..representation import intermediate as ir
class MLIRConverter:
    """Converter of the HDKIR to MLIR."""

    def __init__(self, conversion_functions: dict) -> None:
        """Instantiate a converter with a given set of converters.

        Args:
            conversion_functions (dict): mapping HDKIR nodes to functions that generate MLIR.
                every function should have 4 arguments:
                    - node (IntermediateNode): the node itself to be converted
                    - operands (IntermediateNode): predecessors of node ordered as operands
                    - ir_to_mlir_node (dict): mapping between IntermediateNode and their equivalent
                        MLIR values
                    - context (mlir.Context): the MLIR context being used for the conversion
        """
        self.conversion_functions = conversion_functions
        self._init_context()

    def _init_context(self):
        """Create the MLIR context and register the zamalang dialects in it."""
        self.context = Context()
        zamalang.register_dialects(self.context)

    def hdk_value_to_mlir_type(self, value: data_types.BaseValue) -> MLIRType:
        """Convert an HDK value to its corresponding MLIR Type.

        Args:
            value: value to convert

        Returns:
            corresponding MLIR type

        Raises:
            TypeError: if value is neither an encrypted unsigned integer nor a clear integer
        """
        if value_is_encrypted_unsigned_integer(value):
            return hlfhe.EncryptedIntegerType.get(
                self.context, cast(Integer, value.data_type).bit_width
            )
        if value_is_clear_integer(value):
            dtype = cast(Integer, value.data_type)
            if dtype.is_signed:
                return IntegerType.get_signed(dtype.bit_width, context=self.context)
            return IntegerType.get_unsigned(dtype.bit_width, context=self.context)
        raise TypeError(f"can't convert value of type {type(value)} to MLIR type")

    @staticmethod
    def _get_ordered_preds(op_graph: OPGraph, node) -> list:
        """Return the predecessors of `node` ordered by the `input_idx` stored on their edges.

        The same predecessor appears several times when it feeds several inputs of `node`
        (e.g. lambda x: x + x).
        """
        idx_to_pred = {}
        for pred in op_graph.graph.pred[node]:
            edge_data = op_graph.graph.get_edge_data(pred, node)
            for data in edge_data.values():
                idx_to_pred[data["input_idx"]] = pred
        return [idx_to_pred[i] for i in range(len(idx_to_pred))]

    def convert(self, op_graph: OPGraph) -> str:
        """Convert the graph of IntermediateNode to an MLIR textual representation.

        Args:
            op_graph: graph of IntermediateNode to be converted

        Returns:
            textual MLIR representation

        Raises:
            NotImplementedError: if a node type has no entry in the conversion functions
        """
        with self.context, Location.unknown():
            module = Module.create()
            # collect inputs
            with InsertionPoint(module.body):
                func_types = [
                    self.hdk_value_to_mlir_type(input_node.inputs[0])
                    for input_node in op_graph.get_ordered_inputs()
                ]

                @builtin.FuncOp.from_py_func(*func_types)
                def fhe_circuit(*arg):
                    # Input nodes map directly to the function arguments of the same index.
                    ir_to_mlir_node = {}
                    for arg_num, node in op_graph.input_nodes.items():
                        ir_to_mlir_node[node] = arg[arg_num]

                    # Topological order guarantees operands are converted before their users.
                    for node in nx.topological_sort(op_graph.graph):
                        if isinstance(node, ir.Input):
                            continue
                        mlir_op = self.conversion_functions.get(type(node))
                        if mlir_op is None:  # pragma: no cover
                            raise NotImplementedError(
                                "we don't yet support conversion to MLIR of computations using "
                                f"{type(node)}"
                            )
                        preds = self._get_ordered_preds(op_graph, node)
                        # convert to mlir
                        ir_to_mlir_node[node] = mlir_op(node, preds, ir_to_mlir_node, self.context)

                    results = (
                        ir_to_mlir_node[output_node]
                        for output_node in op_graph.get_ordered_outputs()
                    )
                    return results

        return str(module)
# pylint: enable=no-name-in-module,no-member

59
hdk/common/mlir/utils.py Normal file
View File

@@ -0,0 +1,59 @@
"""Utilities for MLIR conversion."""
from typing import cast
from ..data_types import Integer
from ..data_types.dtypes_helpers import (
value_is_clear_integer,
value_is_encrypted_integer,
value_is_integer,
)
from ..operator_graph import OPGraph
def is_graph_values_compatible_with_mlir(op_graph: OPGraph) -> bool:
    """Check that every output value of the graph's output nodes is an unsigned integer.

    Args:
        op_graph: computation graph to check

    Returns:
        bool: True if all outputs are unsigned integers, False as soon as one is not
    """
    for out_node in op_graph.output_nodes.values():
        for out in out_node.outputs:
            if not value_is_integer(out):
                return False
            if cast(Integer, out.data_type).is_signed:
                return False
    return True
def _set_all_bit_width(op_graph: OPGraph, p: int):
    """Set the bit_width of every integer value in the graph from a common base `p`.

    Encrypted integers get a bit_width of `p` and clear integers a bit_width of `p + 1`.
    Values matching neither check are left untouched.

    Args:
        op_graph: graph to set bit_width for
        p: base bit_width, used as is for encrypted values and incremented by one for clear ones
    """
    clear_bit_width = p + 1
    for node in op_graph.graph.nodes:
        # Both the outputs and the inputs of each node carry their own value description.
        for value in node.outputs + node.inputs:
            if value_is_clear_integer(value):
                value.data_type.bit_width = clear_bit_width
            elif value_is_encrypted_integer(value):
                value.data_type.bit_width = p
def update_bit_width_for_mlir(op_graph: OPGraph):
    """Align the bit_width of all nodes on the largest one found among the node outputs.

    A clear integer output counts for its bit_width minus one, an encrypted integer output for
    its bit_width; the largest of these (at least 0) is then applied to the whole graph.

    Args:
        op_graph: graph to update bit_width for
    """
    widest = 0
    for node in op_graph.graph.nodes:
        for output in node.outputs:
            if value_is_clear_integer(output):
                candidate = output.data_type.bit_width - 1
            elif value_is_encrypted_integer(output):
                candidate = output.data_type.bit_width
            else:
                continue
            widest = max(widest, candidate)

    _set_all_bit_width(op_graph, widest)

View File

@@ -1,7 +1,7 @@
"""Code to wrap and make manipulating networkx graphs easier."""
from copy import deepcopy
from typing import Any, Dict, Iterable, Mapping
from typing import Any, Dict, Iterable, List, Mapping
import networkx as nx
@@ -31,6 +31,22 @@ class OPGraph:
if len(self.graph.pred[node]) == 0 and isinstance(node, ir.Input)
}
def get_ordered_inputs(self) -> List[ir.Input]:
    """Return the input nodes as a list sorted by their index.

    Returns:
        List[ir.Input]: input nodes, the node of index 0 first
    """
    nodes_by_index = self.input_nodes
    # Indices are looked up one by one from 0, so they are expected to be contiguous.
    return [nodes_by_index[position] for position in range(len(nodes_by_index))]
def get_ordered_outputs(self) -> List[ir.IntermediateNode]:
    """Get the output nodes of the graph, ordered by their index.

    Returns:
        List[ir.IntermediateNode]: ordered output nodes
    """
    nodes_by_index = self.output_nodes
    # Indices are looked up one by one from 0, so they are expected to be contiguous.
    return [nodes_by_index[position] for position in range(len(nodes_by_index))]
def evaluate(self, inputs: Mapping[int, Any]) -> Dict[ir.IntermediateNode, Any]:
"""Function to evaluate a graph and get intermediate values for all nodes.
@@ -69,7 +85,10 @@ class OPGraph:
for node in self.graph.nodes():
current_node_bounds = node_bounds[node]
min_bound, max_bound = current_node_bounds["min"], current_node_bounds["max"]
min_bound, max_bound = (
current_node_bounds["min"],
current_node_bounds["max"],
)
if not isinstance(node, ir.Input):
for output_value in node.outputs:

View File

@@ -7,6 +7,10 @@ from hdk.hnumpy.tracing import trace_numpy_function
from ..common.compilation import CompilationArtifacts
from ..common.data_types import BaseValue
from ..common.mlir.utils import (
is_graph_values_compatible_with_mlir,
update_bit_width_for_mlir,
)
from ..common.operator_graph import OPGraph
from ..hnumpy.tracing import trace_numpy_function
@@ -42,6 +46,13 @@ def compile_numpy_function(
# Update the graph accordingly: after that, we have the compilable graph
op_graph.update_values_with_bounds(node_bounds)
# Make sure the graph can be lowered to MLIR
if not is_graph_values_compatible_with_mlir(op_graph):
raise TypeError("signed integers aren't supported for MLIR lowering")
# Update bit_width for MLIR
update_bit_width_for_mlir(op_graph)
# Fill compilation artifacts
if compilation_artifacts is not None:
compilation_artifacts.operation_graph = op_graph

View File

@@ -0,0 +1,19 @@
"""Test converter functions"""
import pytest
from hdk.common.mlir.converters import add, mul, sub
class MockNode:
    """Minimal stand-in for an intermediate node, exposing only inputs and outputs."""

    def __init__(self, inputs=5, outputs=5):
        """Create a node holding `inputs` and `outputs` placeholder (None) values."""
        self.inputs = [None] * inputs
        self.outputs = [None] * outputs
@pytest.mark.parametrize("converter", [add, sub, mul])
def test_failing_converter(converter):
    """Check that a converter raises TypeError when given operands it does not support"""
    unsupported_node = MockNode(2, 1)
    expected_message = r"Don't support .* between .* and .*"
    with pytest.raises(TypeError, match=expected_message):
        converter(unsupported_node, None, None, None)

View File

@@ -0,0 +1,171 @@
"""Test file for conversion to MLIR"""
# pylint: disable=no-name-in-module,no-member
import itertools
import pytest
from mlir.ir import IntegerType
from zamalang import compiler
from zamalang.dialects import hlfhe
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import ClearValue, EncryptedValue
from hdk.common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
from hdk.hnumpy.compile import compile_numpy_function
def add(x, y):
    """Return the sum of x and y (function under test for addition)"""
    total = x + y
    return total
def sub(x, y):
    """Return x minus y (function under test for subtraction)"""
    difference = x - y
    return difference
def mul(x, y):
    """Return the product of x and y (function under test for multiplication)"""
    product = x * y
    return product
def sub_add_mul(x, y, z):
    """Return (z - y) + (x * z), chaining a subtraction, a multiplication and an addition"""
    # Operations are kept in the order the single expression would evaluate them.
    difference = z - y
    product = x * z
    return difference + product
def ret_multiple(x, y, z):
    """Return the three arguments unchanged, as a tuple in the order received"""
    outputs = (x, y, z)
    return outputs
def ret_multiple_different_order(x, y, z):
    """Return the three arguments as a tuple, reordered as (y, z, x)"""
    outputs = (y, z, x)
    return outputs
def datagen(*args):
    """Yield every combination (cartesian product) of the given ranges, one tuple at a time"""
    yield from itertools.product(*args)
@pytest.mark.parametrize(
    "func, args_dict, args_ranges",
    [
        (
            add,
            {
                "x": EncryptedValue(Integer(64, is_signed=False)),
                "y": ClearValue(Integer(32, is_signed=False)),
            },
            (range(0, 8), range(1, 4)),
        ),
        (
            add,
            {
                "x": ClearValue(Integer(32, is_signed=False)),
                "y": EncryptedValue(Integer(64, is_signed=False)),
            },
            (range(0, 8), range(1, 4)),
        ),
        (
            add,
            {
                "x": EncryptedValue(Integer(7, is_signed=False)),
                "y": EncryptedValue(Integer(7, is_signed=False)),
            },
            (range(7, 15), range(1, 5)),
        ),
        (
            sub,
            {
                "x": ClearValue(Integer(8, is_signed=False)),
                "y": EncryptedValue(Integer(7, is_signed=False)),
            },
            (range(5, 10), range(2, 6)),
        ),
        (
            mul,
            {
                "x": EncryptedValue(Integer(7, is_signed=False)),
                "y": ClearValue(Integer(8, is_signed=False)),
            },
            (range(1, 5), range(2, 8)),
        ),
        (
            mul,
            {
                "x": ClearValue(Integer(8, is_signed=False)),
                "y": EncryptedValue(Integer(7, is_signed=False)),
            },
            (range(1, 5), range(2, 8)),
        ),
        (
            sub_add_mul,
            {
                "x": EncryptedValue(Integer(7, is_signed=False)),
                "y": EncryptedValue(Integer(7, is_signed=False)),
                "z": ClearValue(Integer(7, is_signed=False)),
            },
            (range(0, 8), range(1, 5), range(5, 12)),
        ),
        (
            ret_multiple,
            {
                "x": EncryptedValue(Integer(7, is_signed=False)),
                "y": EncryptedValue(Integer(7, is_signed=False)),
                "z": ClearValue(Integer(7, is_signed=False)),
            },
            (range(1, 5), range(1, 5), range(1, 5)),
        ),
        (
            ret_multiple_different_order,
            {
                "x": EncryptedValue(Integer(7, is_signed=False)),
                "y": EncryptedValue(Integer(7, is_signed=False)),
                "z": ClearValue(Integer(7, is_signed=False)),
            },
            (range(1, 5), range(1, 5), range(1, 5)),
        ),
    ],
)
def test_mlir_converter(func, args_dict, args_ranges):
    """Compile each function, convert its graph to MLIR and feed the text to the compiler"""
    compiled_graph = compile_numpy_function(func, args_dict, datagen(*args_ranges))
    mlir_text = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS).convert(compiled_graph)
    # The round trip is only expected not to raise; its return value is not inspected.
    compiler.round_trip(mlir_text)
def test_hdk_encrypted_integer_to_mlir_type():
    """Check that an unsigned 7-bit EncryptedValue maps to the 7-bit encrypted MLIR type"""
    converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
    encrypted_value = EncryptedValue(Integer(7, is_signed=False))
    actual_type = converter.hdk_value_to_mlir_type(encrypted_value)
    expected_type = hlfhe.EncryptedIntegerType.get(converter.context, 7)
    assert actual_type == expected_type
@pytest.mark.parametrize("is_signed", [True, False])
def test_hdk_clear_integer_to_mlir_type(is_signed):
    """Check that a 5-bit ClearValue maps to the MLIR integer type of matching signedness"""
    clear_value = ClearValue(Integer(5, is_signed=is_signed))
    converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
    actual_type = converter.hdk_value_to_mlir_type(clear_value)
    # The expected type is built with the converter's context active.
    with converter.context:
        build_expected = IntegerType.get_signed if is_signed else IntegerType.get_unsigned
        assert actual_type == build_expected(5)
def test_failing_hdk_to_mlir_type():
    """Check that converting something that is not a supported value raises TypeError"""
    converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
    unsupported_value = "random"
    expected_message = r"can't convert value of type .* to MLIR type"
    with pytest.raises(TypeError, match=expected_message):
        converter.hdk_value_to_mlir_type(unsupported_value)
# pylint: enable=no-name-in-module,no-member

View File

@@ -14,11 +14,11 @@ from hdk.hnumpy.compile import compile_numpy_function
"function,input_ranges,list_of_arg_names",
[
pytest.param(lambda x: x + 42, ((-2, 2),), ["x"]),
pytest.param(lambda x, y: x + y + 8, ((-10, 2), (-4, 6)), ["x", "y"]),
pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]),
pytest.param(lambda x, y: (x + 1, y + 10), ((-1, 1), (3, 4)), ["x", "y"]),
pytest.param(
lambda x, y, z: (x + y + 1 - z, x * y + 42, z, z + 99),
((-1, 1), (3, 4), (10, 20)),
((4, 8), (3, 4), (0, 4)),
["x", "y", "z"],
),
],
@@ -80,3 +80,28 @@ def test_compile_function_with_direct_tlu_overflow():
{"x": EncryptedValue(Integer(3, is_signed=False))},
iter([(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,)]),
)
@pytest.mark.parametrize(
    "function,input_ranges,list_of_arg_names",
    [
        pytest.param(lambda x: x - 10, ((-2, 2),), ["x"]),
    ],
)
def test_fail_compile(function, input_ranges, list_of_arg_names):
    """Check that compile_numpy_function raises TypeError for a program with signed values"""

    def data_gen(args):
        yield from itertools.product(*args)

    signed_parameters = {name: EncryptedValue(Integer(64, True)) for name in list_of_arg_names}
    expected_message = r"signed integers aren't supported for MLIR lowering"

    with pytest.raises(TypeError, match=expected_message):
        # input_ranges holds inclusive (low, high) bounds, hence the + 1 on the upper one.
        dataset = data_gen(tuple(range(bounds[0], bounds[1] + 1) for bounds in input_ranges))
        compile_numpy_function(function, signed_parameters, dataset)