Files
concrete/tests/hnumpy/test_compile.py
Ayoub Benaissa f6c9618b5a feat(mlir): MLIR Conversion (#103)
* feat(mlir): conversion from HDKIR to MLIR

* feat(mlir): support ir.Sub and ir.Mul

- better type conversion from HDK to MLIR
- Context management inside the converter class
- better handling of input type in conversion functions

* refactor(mlir): use input and output from OPGraph

Co-authored-by: Arthur Meyre <arthur.meyre@zama.ai>

* feat(mlir): eint-int subtractions

* feat(mlir): adhere to spec for supported ops

* feat(OPGraph): getters for ordered inputs/outputs

+ formatting

* tests(mlir): test conversion via compiler roundtrip

* fix(mlir): flip operands on int_eint sym ops

* feat(mlir): check that the outputs are unsigned

* feat(mlir): set bit_width of all nodes to the max

This is currently required because the compiler already assumes it.
It can be removed from HDK once the compiler handles this on its own.

* feat: value_is_integer + CRs

disable some linting errors

* tests: update compile tests + coverage

* refactor: reorganize mlir package + better doc

* doc: conformance with pydocstyle

Co-authored-by: Arthur Meyre <arthur.meyre@zama.ai>
2021-08-13 12:50:31 +01:00

108 lines
3.3 KiB
Python

"""Test file for hnumpy compilation functions"""
import itertools
import pytest
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import EncryptedValue
from hdk.common.debugging import draw_graph, get_printable_graph
from hdk.common.extensions.table import LookupTable
from hdk.hnumpy.compile import compile_numpy_function
@pytest.mark.parametrize(
    "function,input_ranges,list_of_arg_names",
    [
        pytest.param(lambda x: x + 42, ((-2, 2),), ["x"]),
        pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]),
        pytest.param(lambda x, y: (x + 1, y + 10), ((-1, 1), (3, 4)), ["x", "y"]),
        pytest.param(
            lambda x, y, z: (x + y + 1 - z, x * y + 42, z, z + 99),
            ((4, 8), (3, 4), (0, 4)),
            ["x", "y", "z"],
        ),
    ],
)
def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_names):
    """Test compile_numpy_function on programs returning one or several outputs.

    Every argument named in ``list_of_arg_names`` is declared as
    ``EncryptedValue(Integer(64, True))``. ``input_ranges`` holds one inclusive
    ``(low, high)`` pair per argument, and the inputset handed to
    compile_numpy_function is every combination of those ranges.
    """
    function_parameters = {
        arg_name: EncryptedValue(Integer(64, True)) for arg_name in list_of_arg_names
    }

    # Cartesian product of the inclusive per-argument ranges. itertools.product
    # already yields the tuples lazily, so no wrapper generator is needed.
    inputset = itertools.product(*(range(low, high + 1) for low, high in input_ranges))

    op_graph = compile_numpy_function(function, function_parameters, inputset)

    # TODO: there are no real assertions yet; the graph is only drawn and printed.
    # Once MLIR conversion can be checked from here, assert on the generated MLIR.
    draw_graph(op_graph, block_until_user_closes_graph=False)
    str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
    print(f"\n{str_of_the_graph}\n")
def test_compile_function_with_direct_tlu():
    """Compile a program that indexes a LookupTable directly with its input."""
    lookup = LookupTable([9, 2, 4, 11])

    def add_table_entry(x):
        return x + lookup[x]

    # The input is declared as Integer(2, is_signed=False); feed the values 0..3.
    samples = [(value,) for value in range(4)]
    op_graph = compile_numpy_function(
        add_table_entry,
        {"x": EncryptedValue(Integer(2, is_signed=False))},
        iter(samples),
    )

    printable = get_printable_graph(op_graph, show_data_types=True)
    print(f"\n{printable}\n")
def test_compile_function_with_direct_tlu_overflow():
    """Expect ValueError when the inputset indexes past the end of a LookupTable."""
    lookup = LookupTable([9, 2, 4, 11])

    def read_table(x):
        return lookup[x]

    # Inputs 0..7 are fed, but the table only has 4 entries (valid indices 0..3).
    samples = [(value,) for value in range(8)]
    with pytest.raises(ValueError):
        compile_numpy_function(
            read_table,
            {"x": EncryptedValue(Integer(3, is_signed=False))},
            iter(samples),
        )
@pytest.mark.parametrize(
    "function,input_ranges,list_of_arg_names",
    [
        pytest.param(lambda x: x - 10, ((-2, 2),), ["x"]),
    ],
)
def test_fail_compile(function, input_ranges, list_of_arg_names):
    """Test that compile_numpy_function rejects programs involving signed integers.

    For the parametrized case, ``x - 10`` over ``x`` in [-2, 2] yields only
    negative results. Compilation is expected to raise a TypeError stating that
    signed integers aren't supported for MLIR lowering.
    """
    function_parameters = {
        arg_name: EncryptedValue(Integer(64, True)) for arg_name in list_of_arg_names
    }

    # Cartesian product of the inclusive per-argument (low, high) ranges;
    # itertools.product yields the tuples lazily, so no wrapper generator is needed.
    inputset = itertools.product(*(range(low, high + 1) for low, high in input_ranges))

    with pytest.raises(
        TypeError, match=r"signed integers aren't supported for MLIR lowering"
    ):
        compile_numpy_function(function, function_parameters, inputset)