mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
54 lines
1.5 KiB
Python
54 lines
1.5 KiB
Python
"""Test file for hnumpy debugging functions"""
|
|
|
|
import pytest
|
|
|
|
from hdk.common.data_types.integers import Integer
|
|
from hdk.common.data_types.values import ClearValue, EncryptedValue
|
|
from hdk.common.debugging import draw_graph, get_printable_graph
|
|
from hdk.hnumpy import tracing
|
|
|
|
|
|
# Each case pairs a traced function with the exact text get_printable_graph is
# expected to emit for it: one "%<id> = ..." line per node, where operand
# numbers refer to earlier node ids.
@pytest.mark.parametrize(
    "lambda_f,ref_graph_str",
    [
        (lambda x, y: x + y, "\n%0 = x\n%1 = y\n%2 = Add(0, 1)"),
        (lambda x, y: x - y, "\n%0 = x\n%1 = y\n%2 = Sub(0, 1)"),
        (
            lambda x, y: x + x - y * y * y + x,
            # Split across two literals only for line length; implicit
            # concatenation yields a single expected string.
            "\n%0 = x\n%1 = y\n%2 = Add(0, 0)\n%3 = Mul(1, 1)"
            "\n%4 = Mul(3, 1)\n%5 = Sub(2, 4)\n%6 = Add(5, 0)",
        ),
    ],
)
# Input descriptors for (x, y). "Encrypted uint" makes both operands encrypted;
# "Clear uint" keeps x encrypted and makes y a clear value. Every case uses an
# unsigned 64-bit Integer.
@pytest.mark.parametrize(
    "x_y",
    [
        pytest.param(
            (
                EncryptedValue(Integer(64, is_signed=False)),
                EncryptedValue(Integer(64, is_signed=False)),
            ),
            id="Encrypted uint",
        ),
        pytest.param(
            (
                EncryptedValue(Integer(64, is_signed=False)),
                ClearValue(Integer(64, is_signed=False)),
            ),
            id="Clear uint",
        ),
    ],
)
def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
    """Trace ``lambda_f`` and exercise the hnumpy graph debugging helpers.

    The function is traced over the ``x``/``y`` descriptors from ``x_y``. The
    resulting graph is passed to ``draw_graph`` with
    ``block_until_user_closes_graph=False``. Its ``get_printable_graph`` text
    is printed and must equal ``ref_graph_str`` exactly.
    """
    first_input, second_input = x_y

    # Argument names map to the lambda's parameter names.
    inputs = {"x": first_input, "y": second_input}
    op_graph = tracing.trace_numpy_function(lambda_f, inputs)

    draw_graph(op_graph, block_until_user_closes_graph=False)

    printed = get_printable_graph(op_graph)

    # Echo the rendering so it is visible in the test output when captured
    # output is shown.
    print(f"\n{printed}\n")

    assert printed == ref_graph_str
|