Files
concrete/tests/hnumpy/test_tracing.py

112 lines
3.1 KiB
Python

"""Test file for hnumpy tracing"""
import networkx as nx
import pytest
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import ClearValue, EncryptedValue
from hdk.common.representation import intermediate as ir
from hdk.hnumpy import tracing
def _integer_value_params(arg_name):
    """Return the 64-bit integer test values used for one traced argument.

    The four values cover every combination of encrypted/clear and
    unsigned/signed. Each call builds fresh value objects, and every pytest
    id is prefixed with ``arg_name`` (e.g. ``"x: Clear int"``).
    """
    return [
        pytest.param(
            EncryptedValue(Integer(64, is_signed=False)),
            id=f"{arg_name}: Encrypted uint",
        ),
        pytest.param(
            EncryptedValue(Integer(64, is_signed=True)),
            id=f"{arg_name}: Encrypted int",
        ),
        pytest.param(
            ClearValue(Integer(64, is_signed=False)),
            id=f"{arg_name}: Clear uint",
        ),
        pytest.param(
            ClearValue(Integer(64, is_signed=True)),
            id=f"{arg_name}: Clear int",
        ),
    ]


@pytest.mark.parametrize(
    "operation",
    [ir.Add, ir.Sub, ir.Mul],
)
@pytest.mark.parametrize("x", _integer_value_params("x"))
@pytest.mark.parametrize("y", _integer_value_params("y"))
def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers):
    """Test hnumpy tracing of a binary operation (in the supported ops).

    Traces ``(x + x) <op> y`` with ``tracing.trace_numpy_function``. It then
    checks the resulting graph against a hand-built reference graph, using
    ``test_helpers.digraphs_are_equivalent``.
    """

    # All three functions share the shape ``(x + x) <op> y``. As a result, one
    # hand-built reference graph covers every parametrized operation; only the
    # type of the final node differs.
    def simple_add_function(x, y):
        z = x + x
        return z + y

    def simple_sub_function(x, y):
        z = x + x
        return z - y

    def simple_mul_function(x, y):
        z = x + x
        return z * y

    functions_by_operation = {
        ir.Add: simple_add_function,
        ir.Sub: simple_sub_function,
        ir.Mul: simple_mul_function,
    }
    function_to_compile = functions_by_operation.get(operation)
    if function_to_compile is None:
        # pytest.fail (unlike a bare ``assert False``) is not stripped under
        # ``python -O``, so an unexpected operation always fails loudly.
        pytest.fail(f"unknown operation {operation}")

    op_graph = tracing.trace_numpy_function(function_to_compile, {"x": x, "y": y})

    # Reference graph. It contains the two inputs, the intermediate
    # ``z = x + x`` node, and the final ``operation(z, y)`` node.
    ref_graph = nx.MultiDiGraph()

    input_x = ir.Input(x, input_name="x", program_input_idx=0)
    input_y = ir.Input(y, input_name="y", program_input_idx=1)

    add_node_z = ir.Add(
        (
            input_x.outputs[0],
            input_x.outputs[0],
        )
    )

    returned_final_node = operation(
        (
            add_node_z.outputs[0],
            input_y.outputs[0],
        )
    )

    ref_graph.add_node(input_x, content=input_x)
    ref_graph.add_node(input_y, content=input_y)
    ref_graph.add_node(add_node_z, content=add_node_z)
    ref_graph.add_node(returned_final_node, content=returned_final_node)

    # ``x + x`` consumes input_x twice, so there are two parallel edges into
    # add_node_z, one per operand position. Holding them requires a
    # MultiDiGraph.
    ref_graph.add_edge(input_x, add_node_z, input_idx=0)
    ref_graph.add_edge(input_x, add_node_z, input_idx=1)

    ref_graph.add_edge(add_node_z, returned_final_node, input_idx=0)
    ref_graph.add_edge(input_y, returned_final_node, input_idx=1)

    assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)