mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
feat: adding a compilation api
also, showing data_types in get_printable_graph refs #86, #87
This commit is contained in:
committed by
Benoit Chevallier
parent
055298daf8
commit
6491e47178
@@ -219,11 +219,26 @@ def draw_graph(
|
||||
# pylint: enable=too-many-locals
|
||||
|
||||
|
||||
def get_printable_graph(opgraph: OPGraph) -> str:
|
||||
def data_type_to_string(node):
|
||||
"""Return the datatypes of the outputs of the node
|
||||
|
||||
Args:
|
||||
node: a graph node
|
||||
|
||||
Returns:
|
||||
str: a string representing the datatypes of the outputs of the node
|
||||
|
||||
"""
|
||||
return ", ".join([str(o.data_type) for o in node.outputs])
|
||||
|
||||
|
||||
def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
|
||||
"""Return a string representing a graph
|
||||
|
||||
Args:
|
||||
graph (OPGraph): The graph that we want to draw
|
||||
show_data_types (bool): Whether or not showing data_types of nodes, eg
|
||||
to see their width
|
||||
|
||||
Returns:
|
||||
str: a string to print or save in a file
|
||||
@@ -265,7 +280,14 @@ def get_printable_graph(opgraph: OPGraph) -> str:
|
||||
list_of_arg_name.sort()
|
||||
what_to_print += ", ".join([x[1] for x in list_of_arg_name]) + ")"
|
||||
|
||||
returned_str += f"\n%{i} = {what_to_print}"
|
||||
new_line = f"%{i} = {what_to_print}"
|
||||
|
||||
# Manage datatypes
|
||||
if show_data_types:
|
||||
new_line = f"{new_line: <40s} # {data_type_to_string(node)}"
|
||||
|
||||
returned_str += f"\n{new_line}"
|
||||
|
||||
map_table[node] = i
|
||||
i += 1
|
||||
|
||||
|
||||
42
hdk/hnumpy/compile.py
Normal file
42
hdk/hnumpy/compile.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""hnumpy compilation function"""
|
||||
|
||||
from typing import Any, Callable, Dict, Iterator, Tuple
|
||||
|
||||
from hdk.common.bounds_measurement.dataset_eval import eval_op_graph_bounds_on_dataset
|
||||
from hdk.hnumpy.tracing import trace_numpy_function
|
||||
|
||||
from ..common.data_types import BaseValue
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..hnumpy.tracing import trace_numpy_function
|
||||
|
||||
|
||||
def compile_numpy_function(
|
||||
function_to_trace: Callable,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
dataset: Iterator[Tuple[Any, ...]],
|
||||
) -> OPGraph:
|
||||
"""Main API of hnumpy, to be able to compile an homomorphic program
|
||||
|
||||
Args:
|
||||
function_to_trace (Callable): The function you want to trace
|
||||
function_parameters (Dict[str, BaseValue]): A dictionary indicating what each input of the
|
||||
function is e.g. an EncryptedValue holding a 7bits unsigned Integer
|
||||
dataset (Iterator[Tuple[Any, ...]]): The dataset over which op_graph is evaluated. It
|
||||
needs to be an iterator on tuples which are of the same length than the number of
|
||||
parameters in the function, and in the same order than these same parameters
|
||||
|
||||
Returns:
|
||||
OPGraph: currently returns a compilable graph, but later, it will return an MLIR compatible
|
||||
with the compiler, and even later, it will return the result of the compilation
|
||||
"""
|
||||
|
||||
# Trace
|
||||
op_graph = trace_numpy_function(function_to_trace, function_parameters)
|
||||
|
||||
# Find bounds with the dataset
|
||||
node_bounds = eval_op_graph_bounds_on_dataset(op_graph, dataset)
|
||||
|
||||
# Update the graph accordingly: after that, we have the compilable graph
|
||||
op_graph.update_values_with_bounds(node_bounds)
|
||||
|
||||
return op_graph
|
||||
47
tests/hnumpy/test_compile.py
Normal file
47
tests/hnumpy/test_compile.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Test file for hnumpy compilation functions"""
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import EncryptedValue
|
||||
from hdk.common.debugging import draw_graph, get_printable_graph
|
||||
from hdk.hnumpy.compile import compile_numpy_function
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
[
|
||||
pytest.param(lambda x: x + 42, ((-2, 2),), ["x"]),
|
||||
pytest.param(lambda x, y: x + y + 8, ((-10, 2), (-4, 6)), ["x", "y"]),
|
||||
pytest.param(lambda x, y: (x + 1, y + 10), ((-1, 1), (3, 4)), ["x", "y"]),
|
||||
pytest.param(
|
||||
lambda x, y, z: (x + y + 1 - z, x * y + 42, z, z + 99),
|
||||
((-1, 1), (3, 4), (10, 20)),
|
||||
["x", "y", "z"],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_names):
|
||||
"""Test function compile_numpy_function for a program with multiple outputs"""
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {
|
||||
arg_name: EncryptedValue(Integer(64, True)) for arg_name in list_of_arg_names
|
||||
}
|
||||
|
||||
op_graph = compile_numpy_function(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
)
|
||||
|
||||
# TODO: For the moment, we don't have really checks, but some printfs. Later,
|
||||
# when we have the converter, we can check the MLIR
|
||||
draw_graph(op_graph, block_until_user_closes_graph=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
|
||||
print(f"\n{str_of_the_graph}\n")
|
||||
@@ -96,3 +96,46 @@ def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
|
||||
print(f"\nExp {ref_graph_str}\n")
|
||||
|
||||
assert str_of_the_graph == ref_graph_str
|
||||
|
||||
|
||||
# Remark that the bitwidths are not particularly correct (eg, a MUL of a 17b times 23b
|
||||
# returning 23b), since they are replaced later by the real bitwidths computed on the
|
||||
# dataset
|
||||
@pytest.mark.parametrize(
|
||||
"lambda_f,x_y,ref_graph_str",
|
||||
[
|
||||
(
|
||||
lambda x, y: x + y,
|
||||
(
|
||||
EncryptedValue(Integer(64, is_signed=False)),
|
||||
EncryptedValue(Integer(32, is_signed=True)),
|
||||
),
|
||||
"\n%0 = x # Integer<unsigned, 64 bits>"
|
||||
"\n%1 = y # Integer<signed, 32 bits>"
|
||||
"\n%2 = Add(0, 1) # Integer<signed, 65 bits>"
|
||||
"\nreturn(%2)",
|
||||
),
|
||||
(
|
||||
lambda x, y: x * y,
|
||||
(
|
||||
EncryptedValue(Integer(17, is_signed=False)),
|
||||
EncryptedValue(Integer(23, is_signed=False)),
|
||||
),
|
||||
"\n%0 = x # Integer<unsigned, 17 bits>"
|
||||
"\n%1 = y # Integer<unsigned, 23 bits>"
|
||||
"\n%2 = Mul(0, 1) # Integer<unsigned, 23 bits>"
|
||||
"\nreturn(%2)",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_hnumpy_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
|
||||
"Test hnumpy get_printable_graph with show_data_types"
|
||||
x, y = x_y
|
||||
graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y})
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph, show_data_types=True)
|
||||
|
||||
print(f"\nGot {str_of_the_graph}\n")
|
||||
print(f"\nExp {ref_graph_str}\n")
|
||||
|
||||
assert str_of_the_graph == ref_graph_str
|
||||
|
||||
Reference in New Issue
Block a user