Mirror of https://github.com/zama-ai/concrete.git (synced 2026-02-09 12:15:09 -05:00) — 204 lines, 6.2 KiB, Python.
"""Test file for hnumpy tracing"""
|
|
|
|
import networkx as nx
|
|
import numpy
|
|
import pytest
|
|
|
|
from hdk.common.data_types.floats import Float
|
|
from hdk.common.data_types.integers import Integer
|
|
from hdk.common.data_types.values import ClearValue, EncryptedValue
|
|
from hdk.common.representation import intermediate as ir
|
|
from hdk.hnumpy import tracing
|
|
|
|
|
|
def _integer_value_params(arg_name):
    """Return the four 64-bit integer input variants used to parametrize ``arg_name``.

    The variants cover {encrypted, clear} x {unsigned, signed}. Each ``pytest.param``
    id has the form ``"<arg_name>: Encrypted uint"`` so generated test ids stay readable.
    """
    return [
        pytest.param(
            EncryptedValue(Integer(64, is_signed=False)),
            id=f"{arg_name}: Encrypted uint",
        ),
        pytest.param(
            EncryptedValue(Integer(64, is_signed=True)),
            id=f"{arg_name}: Encrypted int",
        ),
        pytest.param(
            ClearValue(Integer(64, is_signed=False)),
            id=f"{arg_name}: Clear uint",
        ),
        pytest.param(
            ClearValue(Integer(64, is_signed=True)),
            id=f"{arg_name}: Clear int",
        ),
    ]


@pytest.mark.parametrize(
    "operation",
    [ir.Add, ir.Sub, ir.Mul],
)
@pytest.mark.parametrize("x", _integer_value_params("x"))
@pytest.mark.parametrize("y", _integer_value_params("y"))
def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers):
    """Test hnumpy tracing of a binary operation (among the supported ops).

    Every traced function computes ``(x + x) <op> y``. Because they share that
    shape, the reference graph can be built by hand: one ``ir.Add`` node fed twice
    by ``x``, then ``operation`` applied to that sum and ``y``. The traced graph
    must be equivalent to this reference according to
    ``test_helpers.digraphs_are_equivalent``.
    """

    def simple_add_function(x, y):
        z = x + x
        return z + y

    def simple_sub_function(x, y):
        z = x + x
        return z - y

    def simple_mul_function(x, y):
        z = x + x
        return z * y

    functions_by_operation = {
        ir.Add: simple_add_function,
        ir.Sub: simple_sub_function,
        ir.Mul: simple_mul_function,
    }
    function_to_compile = functions_by_operation.get(operation)
    if function_to_compile is None:
        # pytest.fail is not stripped under ``python -O`` (unlike ``assert False``),
        # so an unexpected operation always reports a clear failure.
        pytest.fail(f"unknown operation {operation}")

    op_graph = tracing.trace_numpy_function(function_to_compile, {"x": x, "y": y})

    ref_graph = nx.MultiDiGraph()

    input_x = ir.Input(x, input_name="x", program_input_idx=0)
    input_y = ir.Input(y, input_name="y", program_input_idx=1)

    # z = x + x: both operands of the Add node come from the same input.
    add_node_z = ir.Add(
        (
            input_x.outputs[0],
            input_x.outputs[0],
        )
    )

    returned_final_node = operation(
        (
            add_node_z.outputs[0],
            input_y.outputs[0],
        )
    )

    ref_graph.add_node(input_x, content=input_x)
    ref_graph.add_node(input_y, content=input_y)
    ref_graph.add_node(add_node_z, content=add_node_z)
    ref_graph.add_node(returned_final_node, content=returned_final_node)

    # Two parallel edges x -> z; a MultiDiGraph keeps both, told apart by input_idx.
    ref_graph.add_edge(input_x, add_node_z, input_idx=0)
    ref_graph.add_edge(input_x, add_node_z, input_idx=1)

    ref_graph.add_edge(add_node_z, returned_final_node, input_idx=0)
    ref_graph.add_edge(input_y, returned_final_node, input_idx=1)

    assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "function_to_trace,op_graph_expected_output_type,input_and_expected_output_tuples",
    [
        (
            lambda x: x.astype(numpy.int32),
            Integer(32, is_signed=True),
            [
                (14, numpy.int32(14)),
                (1.5, numpy.int32(1)),
                (2.0, numpy.int32(2)),
                (-1.5, numpy.int32(-1)),
                (2 ** 31 - 1, numpy.int32(2 ** 31 - 1)),
                (-(2 ** 31), numpy.int32(-(2 ** 31))),
            ],
        ),
        (
            lambda x: x.astype(numpy.uint32),
            Integer(32, is_signed=False),
            [
                (14, numpy.uint32(14)),
                (1.5, numpy.uint32(1)),
                (2.0, numpy.uint32(2)),
                (2 ** 32 - 1, numpy.uint32(2 ** 32 - 1)),
            ],
        ),
        (
            lambda x: x.astype(numpy.int64),
            Integer(64, is_signed=True),
            [
                (14, numpy.int64(14)),
                (1.5, numpy.int64(1)),
                (2.0, numpy.int64(2)),
                (-1.5, numpy.int64(-1)),
                (2 ** 63 - 1, numpy.int64(2 ** 63 - 1)),
                (-(2 ** 63), numpy.int64(-(2 ** 63))),
            ],
        ),
        (
            lambda x: x.astype(numpy.uint64),
            Integer(64, is_signed=False),
            [
                (14, numpy.uint64(14)),
                (1.5, numpy.uint64(1)),
                (2.0, numpy.uint64(2)),
                (2 ** 64 - 1, numpy.uint64(2 ** 64 - 1)),
            ],
        ),
        (
            lambda x: x.astype(numpy.float64),
            Float(64),
            [
                (14, numpy.float64(14.0)),
                (1.5, numpy.float64(1.5)),
                (2.0, numpy.float64(2.0)),
                (-1.5, numpy.float64(-1.5)),
            ],
        ),
        (
            lambda x: x.astype(numpy.float32),
            Float(32),
            [
                (14, numpy.float32(14.0)),
                (1.5, numpy.float32(1.5)),
                (2.0, numpy.float32(2.0)),
                (-1.5, numpy.float32(-1.5)),
            ],
        ),
    ],
)
def test_tracing_astype(
    function_to_trace, op_graph_expected_output_type, input_and_expected_output_tuples
):
    """Test function for NPTracer.astype.

    For every ``(input, expected)`` pair, ``function_to_trace`` is traced with a
    single encrypted input named ``"x"``. The output node of the traced graph must
    carry ``op_graph_expected_output_type``. Evaluating the graph on
    ``numpy.array(input)`` must then produce a value whose type is
    ``type(expected)`` and which compares equal to ``expected``.
    """
    for raw_input, expected in input_and_expected_output_tuples:
        # Declared type of the traced input: signed 64-bit integer for Python ints,
        # 64-bit float for anything else.
        # NOTE(review): the uint64 case feeds 2 ** 64 - 1, which does not fit the
        # signed 64-bit type declared here; confirm this mismatch is intended.
        if isinstance(raw_input, int):
            declared_input = EncryptedValue(Integer(64, is_signed=True))
        else:
            declared_input = EncryptedValue(Float(64))

        traced_graph = tracing.trace_numpy_function(function_to_trace, {"x": declared_input})
        final_node = traced_graph.output_nodes[0]
        traced_output_type = final_node.outputs[0].data_type
        assert op_graph_expected_output_type == traced_output_type

        # Bind the raw value (wrapped as a numpy array) to program input index 0.
        results_per_node = traced_graph.evaluate({0: numpy.array(raw_input)})
        actual = results_per_node[final_node]
        assert isinstance(actual, type(expected))
        assert expected == actual
|