feat: adding a compilation api

also, showing data_types in get_printable_graph
refs #86, #87
This commit is contained in:
Benoit Chevallier-Mames
2021-08-05 17:36:25 +02:00
committed by Benoit Chevallier
parent 055298daf8
commit 6491e47178
4 changed files with 156 additions and 2 deletions

View File

@@ -219,11 +219,26 @@ def draw_graph(
# pylint: enable=too-many-locals
def get_printable_graph(opgraph: OPGraph) -> str:
def data_type_to_string(node):
    """Describe the data types of a node's outputs as a single string

    Args:
        node: a graph node; its `outputs` are iterated and the `data_type` of each
            one is read

    Returns:
        str: the `str()` form of every output data type, separated by ", "
    """
    # Lazily stringify each output's data type; join consumes them in order
    output_type_names = (str(output.data_type) for output in node.outputs)
    return ", ".join(output_type_names)
def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
"""Return a string representing a graph
Args:
graph (OPGraph): The graph that we want to draw
show_data_types (bool): Whether or not showing data_types of nodes, eg
to see their width
Returns:
str: a string to print or save in a file
@@ -265,7 +280,14 @@ def get_printable_graph(opgraph: OPGraph) -> str:
list_of_arg_name.sort()
what_to_print += ", ".join([x[1] for x in list_of_arg_name]) + ")"
returned_str += f"\n%{i} = {what_to_print}"
new_line = f"%{i} = {what_to_print}"
# Manage datatypes
if show_data_types:
new_line = f"{new_line: <40s} # {data_type_to_string(node)}"
returned_str += f"\n{new_line}"
map_table[node] = i
i += 1

42
hdk/hnumpy/compile.py Normal file
View File

@@ -0,0 +1,42 @@
"""hnumpy compilation function"""
from typing import Any, Callable, Dict, Iterator, Tuple
from hdk.common.bounds_measurement.dataset_eval import eval_op_graph_bounds_on_dataset
from hdk.hnumpy.tracing import trace_numpy_function
from ..common.data_types import BaseValue
from ..common.operator_graph import OPGraph
from ..hnumpy.tracing import trace_numpy_function
def compile_numpy_function(
    function_to_trace: Callable,
    function_parameters: Dict[str, BaseValue],
    dataset: Iterator[Tuple[Any, ...]],
) -> OPGraph:
    """Main hnumpy entry point used to compile a homomorphic program

    The function is traced into an operation graph, the graph is evaluated over the
    dataset to measure bounds for its nodes, and the graph values are then updated
    with those bounds.

    Args:
        function_to_trace (Callable): The function to trace
        function_parameters (Dict[str, BaseValue]): A dictionary describing each input of the
            function, e.g. an EncryptedValue holding a 7-bit unsigned Integer
        dataset (Iterator[Tuple[Any, ...]]): The dataset over which the traced graph is
            evaluated. It must be an iterator of tuples that have as many elements as the
            function has parameters, given in the same order as those parameters

    Returns:
        OPGraph: for now, the compilable graph; later this is meant to become an MLIR
            compatible with the compiler, and eventually the result of the compilation
    """
    # Step 1: turn the Python function into an operation graph
    traced_graph = trace_numpy_function(function_to_trace, function_parameters)

    # Step 2: run the graph over the dataset to measure bounds
    measured_bounds = eval_op_graph_bounds_on_dataset(traced_graph, dataset)

    # Step 3: update the graph values with the measured bounds; the graph is then compilable
    traced_graph.update_values_with_bounds(measured_bounds)

    return traced_graph

View File

@@ -0,0 +1,47 @@
"""Test file for hnumpy compilation functions"""
import itertools
import pytest
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import EncryptedValue
from hdk.common.debugging import draw_graph, get_printable_graph
from hdk.hnumpy.compile import compile_numpy_function
@pytest.mark.parametrize(
    "function,input_ranges,list_of_arg_names",
    [
        pytest.param(lambda x: x + 42, ((-2, 2),), ["x"]),
        pytest.param(lambda x, y: x + y + 8, ((-10, 2), (-4, 6)), ["x", "y"]),
        pytest.param(lambda x, y: (x + 1, y + 10), ((-1, 1), (3, 4)), ["x", "y"]),
        pytest.param(
            lambda x, y, z: (x + y + 1 - z, x * y + 42, z, z + 99),
            ((-1, 1), (3, 4), (10, 20)),
            ["x", "y", "z"],
        ),
    ],
)
def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_names):
    """Test compile_numpy_function on programs that may have multiple outputs"""

    def data_gen(args):
        # Yield every combination of input values (cartesian product of the ranges)
        yield from itertools.product(*args)

    # Every argument is declared as an encrypted Integer(64, True)
    function_parameters = {
        name: EncryptedValue(Integer(64, True)) for name in list_of_arg_names
    }

    # Each (low, high) pair is inclusive on both ends, hence the `+ 1`
    value_ranges = tuple(range(low, high + 1) for low, high in input_ranges)

    op_graph = compile_numpy_function(
        function,
        function_parameters,
        data_gen(value_ranges),
    )

    # TODO: For the moment there are no real checks, only prints. Later, once the
    # converter exists, the MLIR can be checked
    draw_graph(op_graph, block_until_user_closes_graph=False)

    str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
    print(f"\n{str_of_the_graph}\n")

View File

@@ -96,3 +96,46 @@ def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
print(f"\nExp {ref_graph_str}\n")
assert str_of_the_graph == ref_graph_str
# Note that the bitwidths are not necessarily accurate (e.g. a MUL of a 17b by a 23b
# returning 23b), since they are replaced later by the real bitwidths computed on the
# dataset
@pytest.mark.parametrize(
    "lambda_f,x_y,ref_graph_str",
    [
        (
            lambda x, y: x + y,
            (
                EncryptedValue(Integer(64, is_signed=False)),
                EncryptedValue(Integer(32, is_signed=True)),
            ),
            # get_printable_graph left-aligns each statement in a 40-column field before
            # appending " # <data types>", so the references use the same format spec
            f"\n{'%0 = x': <40s} # Integer<unsigned, 64 bits>"
            f"\n{'%1 = y': <40s} # Integer<signed, 32 bits>"
            f"\n{'%2 = Add(0, 1)': <40s} # Integer<signed, 65 bits>"
            "\nreturn(%2)",
        ),
        (
            lambda x, y: x * y,
            (
                EncryptedValue(Integer(17, is_signed=False)),
                EncryptedValue(Integer(23, is_signed=False)),
            ),
            f"\n{'%0 = x': <40s} # Integer<unsigned, 17 bits>"
            f"\n{'%1 = y': <40s} # Integer<unsigned, 23 bits>"
            f"\n{'%2 = Mul(0, 1)': <40s} # Integer<unsigned, 23 bits>"
            "\nreturn(%2)",
        ),
    ],
)
def test_hnumpy_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
    """Test hnumpy get_printable_graph with show_data_types

    Args:
        lambda_f: the two-argument function to trace
        x_y: pair of values describing the inputs "x" and "y" of the function
        ref_graph_str: the exact string get_printable_graph is expected to return
    """
    x, y = x_y
    graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y})

    str_of_the_graph = get_printable_graph(graph, show_data_types=True)

    # Print both strings so that a mismatch is easy to inspect in the pytest output
    print(f"\nGot {str_of_the_graph}\n")
    print(f"\nExp {ref_graph_str}\n")

    assert str_of_the_graph == ref_graph_str