mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
committed by
Benoit Chevallier
parent
1196b00c6b
commit
c91ee858c5
@@ -1,2 +1,2 @@
|
||||
"""HDK's module for debugging"""
|
||||
from .draw_graph import draw_graph
|
||||
from .draw_graph import draw_graph, get_printable_graph
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""functions to draw the different graphs we can generate in the package, eg to debug"""
|
||||
from typing import Dict, List
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
@@ -194,3 +194,52 @@ def draw_graph(
|
||||
plt.show(block=block_until_user_closes_graph)
|
||||
|
||||
# pylint: enable=too-many-locals
|
||||
|
||||
|
||||
def get_printable_graph(graph: nx.DiGraph) -> str:
|
||||
"""
|
||||
Return a string representing a graph
|
||||
|
||||
Args:
|
||||
graph (nx.DiGraph): The graph that we want to draw
|
||||
|
||||
Returns:
|
||||
a string to print or save in a file
|
||||
|
||||
"""
|
||||
returned_str = ""
|
||||
|
||||
i = 0
|
||||
map_table: Dict[Any, int] = {}
|
||||
|
||||
for node in nx.topological_sort(graph):
|
||||
|
||||
if not isinstance(node, ir.Input):
|
||||
what_to_print = node.__class__.__name__ + "("
|
||||
|
||||
# Find all the names of the current predecessors of the node
|
||||
list_of_arg_name = []
|
||||
|
||||
for pred, index_list in graph.pred[node].items():
|
||||
for index in index_list.values():
|
||||
# Remark that we keep the index of the predecessor and its
|
||||
# name, to print sources in the right order, which is
|
||||
# important for eg non commutative operations
|
||||
list_of_arg_name += [(index["input_idx"], str(map_table[pred]))]
|
||||
|
||||
# Some checks, because the previous algorithm is not clear
|
||||
assert len(list_of_arg_name) == len({x[0] for x in list_of_arg_name})
|
||||
assert [x[0] for x in list_of_arg_name] == list(range(len(list_of_arg_name)))
|
||||
|
||||
# Then, just print the predecessors in the right order
|
||||
list_of_arg_name.sort()
|
||||
what_to_print += ", ".join([x[1] for x in list_of_arg_name]) + ")"
|
||||
|
||||
else:
|
||||
what_to_print = node.input_name
|
||||
|
||||
returned_str += f"\n%{i} = {what_to_print}"
|
||||
map_table[node] = i
|
||||
i += 1
|
||||
|
||||
return returned_str
|
||||
|
||||
@@ -4,55 +4,50 @@ import pytest
|
||||
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import ClearValue, EncryptedValue
|
||||
from hdk.common.debugging import draw_graph
|
||||
from hdk.common.debugging import draw_graph, get_printable_graph
|
||||
from hdk.hnumpy import tracing
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lambda_f",
|
||||
"lambda_f,ref_graph_str",
|
||||
[
|
||||
lambda x, y: x + y,
|
||||
lambda x, y: x + x - y * y * y + x,
|
||||
(lambda x, y: x + y, "\n%0 = x\n%1 = y\n%2 = Add(0, 1)"),
|
||||
(lambda x, y: x - y, "\n%0 = x\n%1 = y\n%2 = Sub(0, 1)"),
|
||||
(
|
||||
lambda x, y: x + x - y * y * y + x,
|
||||
"\n%0 = x\n%1 = y\n%2 = Add(0, 0)\n%3 = Mul(1, 1)"
|
||||
"\n%4 = Mul(3, 1)\n%5 = Sub(2, 4)\n%6 = Add(5, 0)",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"x",
|
||||
"x_y",
|
||||
[
|
||||
pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="Encrypted uint"),
|
||||
# pytest.param(
|
||||
# EncryptedValue(Integer(64, is_signed=True)),
|
||||
# id="Encrypted int",
|
||||
# ),
|
||||
# pytest.param(
|
||||
# ClearValue(Integer(64, is_signed=False)),
|
||||
# id="Clear uint",
|
||||
# ),
|
||||
# pytest.param(
|
||||
# ClearValue(Integer(64, is_signed=True)),
|
||||
# id="Clear int",
|
||||
# ),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"y",
|
||||
[
|
||||
pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="Encrypted uint"),
|
||||
# pytest.param(
|
||||
# EncryptedValue(Integer(64, is_signed=True)),
|
||||
# id="Encrypted int",
|
||||
# ),
|
||||
pytest.param(
|
||||
ClearValue(Integer(64, is_signed=False)),
|
||||
(
|
||||
EncryptedValue(Integer(64, is_signed=False)),
|
||||
EncryptedValue(Integer(64, is_signed=False)),
|
||||
),
|
||||
id="Encrypted uint",
|
||||
),
|
||||
pytest.param(
|
||||
(
|
||||
EncryptedValue(Integer(64, is_signed=False)),
|
||||
ClearValue(Integer(64, is_signed=False)),
|
||||
),
|
||||
id="Clear uint",
|
||||
),
|
||||
# pytest.param(
|
||||
# ClearValue(Integer(64, is_signed=True)),
|
||||
# id="Clear int",
|
||||
# ),
|
||||
],
|
||||
)
|
||||
def test_hnumpy_draw_graph(lambda_f, x, y):
|
||||
"Test hnumpy draw_graph"
|
||||
def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
|
||||
"Test hnumpy get_printable_graph and draw_graph"
|
||||
x, y = x_y
|
||||
graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y})
|
||||
|
||||
draw_graph(graph, block_until_user_closes_graph=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph)
|
||||
|
||||
print(f"\n{str_of_the_graph}\n")
|
||||
|
||||
assert str_of_the_graph == ref_graph_str
|
||||
|
||||
Reference in New Issue
Block a user