Files
concrete/tests/hnumpy/test_tracing.py
Arthur Meyre a060aaae99 feat(tracing): add tracing facilities
- add BaseTracer, which will hold most of the boilerplate code
- add hnumpy with a bare NPTracer and tracing function
- update IR to be compatible with tracing helpers
- update test helper to properly check that graphs are equivalent
- add test tracing a simple addition
- rename common/data_types/helpers.py to common/data_types/dtypes_helpers.py to avoid having too many files with the same name
- ignore missing type stubs in the default mypy command
- add a convenience Makefile target that reports errors about missing mypy stubs
2021-07-26 17:05:53 +02:00

88 lines
2.3 KiB
Python

"""Test file for HDK's hnumpy tracing"""
import networkx as nx
import pytest
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import ClearValue, EncryptedValue
from hdk.common.representation import intermediate as ir
from hdk.hnumpy import tracing
def _integer_value_params():
    """Return fresh pytest params for every 64-bit integer input kind under test.

    Covers encrypted/clear values crossed with unsigned/signed integers. This is
    a factory rather than a shared list so that each ``parametrize`` decorator
    receives its own value instances, exactly as when the two lists were spelled
    out separately.
    """
    return [
        pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="Encrypted uint"),
        pytest.param(EncryptedValue(Integer(64, is_signed=True)), id="Encrypted int"),
        pytest.param(ClearValue(Integer(64, is_signed=False)), id="Clear uint"),
        pytest.param(ClearValue(Integer(64, is_signed=True)), id="Clear int"),
    ]


# Stacking two parametrize decorators runs the test on the full cross product
# of input kinds for ``x`` and ``y`` (4 x 4 = 16 cases).
@pytest.mark.parametrize("x", _integer_value_params())
@pytest.mark.parametrize("y", _integer_value_params())
def test_hnumpy_tracing_add(x, y, test_helpers):
    """Test hnumpy tracing __add__.

    Traces ``(x + x) + y`` and checks that the resulting graph is equivalent to
    a hand-built reference: two ``Input`` nodes feeding two chained ``Add`` nodes.

    Args:
        x: encrypted or clear 64-bit integer value bound to the traced
            function's ``x`` argument.
        y: encrypted or clear 64-bit integer value bound to the traced
            function's ``y`` argument.
        test_helpers: fixture providing ``digraphs_are_equivalent`` (defined
            outside this file -- presumably in a conftest; verify).
    """

    def simple_add_function(x, y):
        z = x + x
        return z + y

    graph = tracing.trace_numpy_function(simple_add_function, {"x": x, "y": y})

    # Build the expected graph by hand. A MultiDiGraph is required because
    # ``x + x`` produces two parallel edges from input_x to add_node_z, which a
    # plain DiGraph would collapse into a single edge.
    ref_graph = nx.MultiDiGraph()

    input_x = ir.Input((x,))
    input_y = ir.Input((y,))
    add_node_z = ir.Add((input_x.outputs[0], input_x.outputs[0]))
    return_add_node = ir.Add((add_node_z.outputs[0], input_y.outputs[0]))

    # Each node also carries itself as its ``content`` attribute (presumably
    # what digraphs_are_equivalent compares; confirm against the helper).
    for node in (input_x, input_y, add_node_z, return_add_node):
        ref_graph.add_node(node, content=node)

    # ``input_idx`` is the operand position on the destination node, matching
    # the order of the tuples passed to ``ir.Add`` above.
    ref_graph.add_edge(input_x, add_node_z, input_idx=0)
    ref_graph.add_edge(input_x, add_node_z, input_idx=1)
    ref_graph.add_edge(add_node_z, return_add_node, input_idx=0)
    ref_graph.add_edge(input_y, return_add_node, input_idx=1)

    assert test_helpers.digraphs_are_equivalent(ref_graph, graph)