diff --git a/hdk/common/debugging/draw_graph.py b/hdk/common/debugging/draw_graph.py index 9d404c94c..ec6a8038c 100644 --- a/hdk/common/debugging/draw_graph.py +++ b/hdk/common/debugging/draw_graph.py @@ -219,11 +219,26 @@ def draw_graph( # pylint: enable=too-many-locals -def get_printable_graph(opgraph: OPGraph) -> str: +def data_type_to_string(node): + """Return the datatypes of the outputs of the node + + Args: + node: a graph node + + Returns: + str: a string representing the datatypes of the outputs of the node + + """ + return ", ".join([str(o.data_type) for o in node.outputs]) + + +def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: """Return a string representing a graph Args: graph (OPGraph): The graph that we want to draw + show_data_types (bool): Whether or not showing data_types of nodes, eg + to see their width Returns: str: a string to print or save in a file @@ -265,7 +280,14 @@ def get_printable_graph(opgraph: OPGraph) -> str: list_of_arg_name.sort() what_to_print += ", ".join([x[1] for x in list_of_arg_name]) + ")" - returned_str += f"\n%{i} = {what_to_print}" + new_line = f"%{i} = {what_to_print}" + + # Manage datatypes + if show_data_types: + new_line = f"{new_line: <40s} # {data_type_to_string(node)}" + + returned_str += f"\n{new_line}" + map_table[node] = i i += 1 diff --git a/hdk/hnumpy/compile.py b/hdk/hnumpy/compile.py new file mode 100644 index 000000000..0882aa6bd --- /dev/null +++ b/hdk/hnumpy/compile.py @@ -0,0 +1,42 @@ +"""hnumpy compilation function""" + +from typing import Any, Callable, Dict, Iterator, Tuple + +from hdk.common.bounds_measurement.dataset_eval import eval_op_graph_bounds_on_dataset +from hdk.hnumpy.tracing import trace_numpy_function + +from ..common.data_types import BaseValue +from ..common.operator_graph import OPGraph +from ..hnumpy.tracing import trace_numpy_function + + +def compile_numpy_function( + function_to_trace: Callable, + function_parameters: Dict[str, BaseValue], + dataset: Iterator[Tuple[Any, ...]], +) -> OPGraph: + """Main API of hnumpy, to be able to compile an homomorphic program + + Args: + function_to_trace (Callable): The function you want to trace + function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the + function is e.g. an EncryptedValue holding a 7bits unsigned Integer + dataset (Iterator[Tuple[Any, ...]]): The dataset over which op_graph is evaluated. It + needs to be an iterator on tuples which are of the same length than the number of + parameters in the function, and in the same order than these same parameters + + Returns: + OPGraph: currently returns a compilable graph, but later, it will return an MLIR compatible + with the compiler, and even later, it will return the result of the compilation + """ + + # Trace + op_graph = trace_numpy_function(function_to_trace, function_parameters) + + # Find bounds with the dataset + node_bounds = eval_op_graph_bounds_on_dataset(op_graph, dataset) + + # Update the graph accordingly: after that, we have the compilable graph + op_graph.update_values_with_bounds(node_bounds) + + return op_graph diff --git a/tests/hnumpy/test_compile.py b/tests/hnumpy/test_compile.py new file mode 100644 index 000000000..fb7a76147 --- /dev/null +++ b/tests/hnumpy/test_compile.py @@ -0,0 +1,47 @@ +"""Test file for hnumpy compilation functions""" +import itertools + +import pytest + +from hdk.common.data_types.integers import Integer +from hdk.common.data_types.values import EncryptedValue +from hdk.common.debugging import draw_graph, get_printable_graph +from hdk.hnumpy.compile import compile_numpy_function + + +@pytest.mark.parametrize( + "function,input_ranges,list_of_arg_names", + [ + pytest.param(lambda x: x + 42, ((-2, 2),), ["x"]), + pytest.param(lambda x, y: x + y + 8, ((-10, 2), (-4, 6)), ["x", "y"]), + pytest.param(lambda x, y: (x + 1, y + 10), ((-1, 1), (3, 4)), ["x", "y"]), + pytest.param( + lambda x, y, z: (x + y + 1 - z, x * y + 42, z, z + 99), + ((-1, 1), (3, 4), (10, 20)), + ["x", "y", "z"], + ), + ], +) +def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_names): + """Test function compile_numpy_function for a program with multiple outputs""" + + def data_gen(args): + for prod in itertools.product(*args): + yield prod + + function_parameters = { + arg_name: EncryptedValue(Integer(64, True)) for arg_name in list_of_arg_names + } + + op_graph = compile_numpy_function( + function, + function_parameters, + data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), + ) + + # TODO: For the moment, we don't have really checks, but some printfs. Later, + # when we have the converter, we can check the MLIR + draw_graph(op_graph, block_until_user_closes_graph=False) + + str_of_the_graph = get_printable_graph(op_graph, show_data_types=True) + print(f"\n{str_of_the_graph}\n") diff --git a/tests/hnumpy/test_debugging.py b/tests/hnumpy/test_debugging.py index ea611a791..750f56bbd 100644 --- a/tests/hnumpy/test_debugging.py +++ b/tests/hnumpy/test_debugging.py @@ -96,3 +96,46 @@ def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y): print(f"\nExp {ref_graph_str}\n") assert str_of_the_graph == ref_graph_str + + +# Remark that the bitwidths are not particularly correct (eg, a MUL of a 17b times 23b +# returning 23b), since they are replaced later by the real bitwidths computed on the +# dataset +@pytest.mark.parametrize( + "lambda_f,x_y,ref_graph_str", + [ + ( + lambda x, y: x + y, + ( + EncryptedValue(Integer(64, is_signed=False)), + EncryptedValue(Integer(32, is_signed=True)), + ), + "\n%0 = x # Integer" + "\n%1 = y # Integer" + "\n%2 = Add(0, 1) # Integer" + "\nreturn(%2)", + ), + ( + lambda x, y: x * y, + ( + EncryptedValue(Integer(17, is_signed=False)), + EncryptedValue(Integer(23, is_signed=False)), + ), + "\n%0 = x # Integer" + "\n%1 = y # Integer" + "\n%2 = Mul(0, 1) # Integer" + "\nreturn(%2)", + ), + ], +) +def test_hnumpy_print_with_show_data_types(lambda_f, x_y, ref_graph_str): + "Test hnumpy get_printable_graph with show_data_types" + x, y = x_y + graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y}) + + str_of_the_graph = get_printable_graph(graph, show_data_types=True) + + print(f"\nGot {str_of_the_graph}\n") + print(f"\nExp {ref_graph_str}\n") + + assert str_of_the_graph == ref_graph_str