mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
- start tracing numpy.rint and manage dtypes - update BaseTracer to accept iterables as inputs, because NPTracer does not get a list given the way numpy sends arguments to the functions
282 lines
9.1 KiB
Python
282 lines
9.1 KiB
Python
"""Test file for hnumpy tracing"""
|
|
|
|
import networkx as nx
|
|
import numpy
|
|
import pytest
|
|
|
|
from hdk.common.data_types.floats import Float
|
|
from hdk.common.data_types.integers import Integer
|
|
from hdk.common.data_types.values import ClearValue, EncryptedValue
|
|
from hdk.common.representation import intermediate as ir
|
|
from hdk.hnumpy import tracing
|
|
|
|
|
|
@pytest.mark.parametrize(
    "operation",
    [ir.Add, ir.Sub, ir.Mul],
)
@pytest.mark.parametrize(
    "x",
    [
        pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="x: Encrypted uint"),
        pytest.param(EncryptedValue(Integer(64, is_signed=True)), id="x: Encrypted int"),
        pytest.param(ClearValue(Integer(64, is_signed=False)), id="x: Clear uint"),
        pytest.param(ClearValue(Integer(64, is_signed=True)), id="x: Clear int"),
    ],
)
@pytest.mark.parametrize(
    "y",
    [
        pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="y: Encrypted uint"),
        pytest.param(EncryptedValue(Integer(64, is_signed=True)), id="y: Encrypted int"),
        pytest.param(ClearValue(Integer(64, is_signed=False)), id="y: Clear uint"),
        pytest.param(ClearValue(Integer(64, is_signed=True)), id="y: Clear int"),
    ],
)
def test_hnumpy_tracing_binary_op(operation, x, y, test_helpers):
    """Test hnumpy tracing of one binary operation from the supported set.

    Every traced function computes ``(x + x) <op> y``. Because they share that shape,
    the expected graph can be built by hand the same way for every operation:
    an Add node fed twice by input x, then an ``operation`` node fed by that Add and y.
    """

    # The traced functions must keep the parameter names ``x`` and ``y``: they are
    # matched against the keys of the inputs dict passed to trace_numpy_function.
    def simple_add_function(x, y):
        doubled = x + x
        return doubled + y

    def simple_sub_function(x, y):
        doubled = x + x
        return doubled - y

    def simple_mul_function(x, y):
        doubled = x + x
        return doubled * y

    functions_by_operation = {
        ir.Add: simple_add_function,
        ir.Sub: simple_sub_function,
        ir.Mul: simple_mul_function,
    }
    function_to_compile = functions_by_operation.get(operation)
    assert function_to_compile is not None, f"unknown operation {operation}"

    op_graph = tracing.trace_numpy_function(function_to_compile, {"x": x, "y": y})

    # Hand-built reference graph mirroring (x + x) <op> y
    input_x = ir.Input(x, input_name="x", program_input_idx=0)
    input_y = ir.Input(y, input_name="y", program_input_idx=1)
    add_node_z = ir.Add((input_x.outputs[0], input_x.outputs[0]))
    returned_final_node = operation((add_node_z.outputs[0], input_y.outputs[0]))

    ref_graph = nx.MultiDiGraph()
    for node in (input_x, input_y, add_node_z, returned_final_node):
        ref_graph.add_node(node, content=node)

    # (source, destination, index of the destination's input fed by this edge)
    expected_edges = (
        (input_x, add_node_z, 0),
        (input_x, add_node_z, 1),
        (add_node_z, returned_final_node, 0),
        (input_y, returned_final_node, 1),
    )
    for source, destination, input_idx in expected_edges:
        ref_graph.add_edge(source, destination, input_idx=input_idx)

    assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "function_to_trace,op_graph_expected_output_type,input_and_expected_output_tuples",
    [
        (
            lambda x: x.astype(numpy.int32),
            Integer(32, is_signed=True),
            [
                (14, numpy.int32(14)),
                (1.5, numpy.int32(1)),
                (2.0, numpy.int32(2)),
                (-1.5, numpy.int32(-1)),
                (2 ** 31 - 1, numpy.int32(2 ** 31 - 1)),
                (-(2 ** 31), numpy.int32(-(2 ** 31))),
            ],
        ),
        (
            lambda x: x.astype(numpy.uint32),
            Integer(32, is_signed=False),
            [
                (14, numpy.uint32(14)),
                (1.5, numpy.uint32(1)),
                (2.0, numpy.uint32(2)),
                (2 ** 32 - 1, numpy.uint32(2 ** 32 - 1)),
            ],
        ),
        (
            lambda x: x.astype(numpy.int64),
            Integer(64, is_signed=True),
            [
                (14, numpy.int64(14)),
                (1.5, numpy.int64(1)),
                (2.0, numpy.int64(2)),
                (-1.5, numpy.int64(-1)),
                (2 ** 63 - 1, numpy.int64(2 ** 63 - 1)),
                (-(2 ** 63), numpy.int64(-(2 ** 63))),
            ],
        ),
        (
            lambda x: x.astype(numpy.uint64),
            Integer(64, is_signed=False),
            [
                (14, numpy.uint64(14)),
                (1.5, numpy.uint64(1)),
                (2.0, numpy.uint64(2)),
                (2 ** 64 - 1, numpy.uint64(2 ** 64 - 1)),
            ],
        ),
        (
            lambda x: x.astype(numpy.float64),
            Float(64),
            [
                (14, numpy.float64(14.0)),
                (1.5, numpy.float64(1.5)),
                (2.0, numpy.float64(2.0)),
                (-1.5, numpy.float64(-1.5)),
            ],
        ),
        (
            lambda x: x.astype(numpy.float32),
            Float(32),
            [
                (14, numpy.float32(14.0)),
                (1.5, numpy.float32(1.5)),
                (2.0, numpy.float32(2.0)),
                (-1.5, numpy.float32(-1.5)),
            ],
        ),
    ],
)
def test_tracing_astype(
    function_to_trace, op_graph_expected_output_type, input_and_expected_output_tuples
):
    """Test function for NPTracer.astype.

    For every (sample, expected) pair:
    - trace ``function_to_trace`` with an encrypted input declared as signed 64-bit
      integer when the sample is a Python ``int``, and as 64-bit float otherwise;
    - check the traced output node carries ``op_graph_expected_output_type``;
    - evaluate the graph on the sample and check both the numpy type and the value
      of the result.
    """
    for sample, expected in input_and_expected_output_tuples:
        if isinstance(sample, int):
            declared_input = EncryptedValue(Integer(64, is_signed=True))
        else:
            declared_input = EncryptedValue(Float(64))

        op_graph = tracing.trace_numpy_function(function_to_trace, {"x": declared_input})
        output_node = op_graph.output_nodes[0]
        traced_output_type = output_node.outputs[0].data_type
        assert op_graph_expected_output_type == traced_output_type

        # Inputs to evaluate are keyed by program input index
        results_by_node = op_graph.evaluate({0: numpy.array(sample)})
        evaluated_output = results_by_node[output_node]
        assert isinstance(evaluated_output, type(expected))
        assert expected == evaluated_output
|
|
|
|
|
|
@pytest.mark.parametrize(
    "function_to_trace,inputs,expected_output_node,expected_output_value",
    [
        # Some numpy functions cannot be handed to trace_numpy_function directly because
        # retrieving their signature fails, so they are wrapped in a lambda.
        # pylint: disable=unnecessary-lambda
        *[
            pytest.param(
                lambda x: numpy.rint(x),
                {"x": EncryptedValue(Integer(bit_width, is_signed=is_signed))},
                ir.ArbitraryFunction,
                EncryptedValue(Float(64)),
            )
            for bit_width, is_signed in ((7, False), (32, True), (64, True))
        ],
        # 128-bit integers have no numpy counterpart: tracing must raise
        pytest.param(
            lambda x: numpy.rint(x),
            {"x": EncryptedValue(Integer(128, is_signed=True))},
            ir.ArbitraryFunction,
            None,
            marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
        ),
        pytest.param(
            lambda x: numpy.rint(x),
            {"x": EncryptedValue(Float(64))},
            ir.ArbitraryFunction,
            EncryptedValue(Float(64)),
        ),
        # Coverage-only case: a ufunc method (reduce) is unsupported and must hit the
        # exception handling path.
        pytest.param(
            lambda x: numpy.add.reduce(x),
            {"x": EncryptedValue(Integer(32, is_signed=True))},
            None,
            None,
            marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
        ),
        # pylint: enable=unnecessary-lambda
    ],
)
def test_trace_hnumpy_supported_ufuncs(
    function_to_trace, inputs, expected_output_node, expected_output_value
):
    """Trace a supported numpy ufunc call and check the single resulting output node.

    The graph must have exactly one output node, of class ``expected_output_node``,
    with exactly one output equal to ``expected_output_value``.
    """
    op_graph = tracing.trace_numpy_function(function_to_trace, inputs)

    output_nodes = op_graph.output_nodes
    assert len(output_nodes) == 1

    sole_output_node = output_nodes[0]
    assert isinstance(sole_output_node, expected_output_node)

    node_outputs = sole_output_node.outputs
    assert len(node_outputs) == 1
    assert node_outputs[0] == expected_output_value
|
|
|
|
|
|
@pytest.mark.parametrize(
    "np_ufunc,expected_tracing_func",
    [
        pytest.param(numpy.rint, tracing.NPTracer.rint),
        # The failure path also needs a test. numpy.conjugate operates on complex
        # numbers, which are not supported for now, so it is a stable long-term example
        # of a ufunc the tracer cannot handle.
        pytest.param(
            numpy.conjugate,
            None,
            marks=pytest.mark.xfail(strict=True, raises=NotImplementedError),
        ),
    ],
)
def test_nptracer_get_tracing_func_for_np_ufunc(np_ufunc, expected_tracing_func):
    """Check the NPTracer tracing function returned for a given numpy ufunc.

    A supported ufunc maps to its NPTracer method. An unsupported one is expected to
    raise NotImplementedError, as encoded by the xfail parameter above.
    """
    tracing_func = tracing.NPTracer.get_tracing_func_for_np_ufunc(np_ufunc)
    assert tracing_func == expected_tracing_func
|