Files
concrete/tests/hnumpy/test_compile.py
Benoit Chevallier-Mames 6491e47178 feat: adding a compilation api
also, showing data_types in get_printable_graph
refs #86, #87
2021-08-06 17:04:30 +02:00

48 lines
1.6 KiB
Python

"""Test file for hnumpy compilation functions"""
import itertools
import pytest
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import EncryptedValue
from hdk.common.debugging import draw_graph, get_printable_graph
from hdk.hnumpy.compile import compile_numpy_function
@pytest.mark.parametrize(
    "function,input_ranges,list_of_arg_names",
    [
        pytest.param(lambda x: x + 42, ((-2, 2),), ["x"]),
        pytest.param(lambda x, y: x + y + 8, ((-10, 2), (-4, 6)), ["x", "y"]),
        pytest.param(lambda x, y: (x + 1, y + 10), ((-1, 1), (3, 4)), ["x", "y"]),
        pytest.param(
            lambda x, y, z: (x + y + 1 - z, x * y + 42, z, z + 99),
            ((-1, 1), (3, 4), (10, 20)),
            ["x", "y", "z"],
        ),
    ],
)
def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_names):
    """Compile a numpy-style function through compile_numpy_function.

    The cases cover single-output functions and tuple-returning (multi-output)
    functions. Every argument named in ``list_of_arg_names`` is declared as
    ``EncryptedValue(Integer(64, True))``. The input set is the cartesian
    product of the per-argument ``(low, high)`` ranges in ``input_ranges``,
    where both ends are inclusive.

    There are no assertions yet: the resulting graph is only drawn
    (non-blocking) and printed.
    """

    def iter_input_combinations(per_arg_values):
        # Generator over every combination of one value per argument.
        yield from itertools.product(*per_arg_values)

    params_by_name = {
        arg_name: EncryptedValue(Integer(64, True)) for arg_name in list_of_arg_names
    }

    # Ranges are inclusive on both ends, hence the ``+ 1`` on the upper bound.
    per_arg_values = tuple(range(low, high + 1) for low, high in input_ranges)

    graph = compile_numpy_function(
        function,
        params_by_name,
        iter_input_combinations(per_arg_values),
    )

    # TODO: no real checks yet, only printouts. Once the converter exists,
    # check the generated MLIR instead.
    draw_graph(graph, block_until_user_closes_graph=False)
    printable = get_printable_graph(graph, show_data_types=True)
    print(f"\n{printable}\n")