# Mirror of https://github.com/zama-ai/concrete.git
# (synced 2026-02-09 03:55:04 -05:00; 264 lines, 11 KiB, Python)
"""Test file for user-friendly numpy compilation functions"""
|
|
|
|
import numpy
|
|
import pytest
|
|
|
|
from concrete.common.debugging import format_operation_graph
|
|
from concrete.numpy.np_fhe_compiler import NPFHECompiler
|
|
|
|
|
|
def complicated_topology(x, y):
    """Combine ``x`` and ``y`` through a small branching graph with four outputs.

    ``x + y`` feeds two branches (``+ 1`` and ``+ 2``) that are summed back
    together; the returned values reuse those intermediates, and the summed
    value is converted with ``astype`` twice, so a traced graph contains shared
    nodes and repeated conversions.

    Returns:
        A 4-tuple ``(s.astype(int32), b2.astype(int32), (b2 + 3).astype(int32),
        s.astype(int32) + 67)`` where ``b2 = x + y + 2`` and
        ``s = (x + y + 1) + (x + y + 2)``.
    """
    # NOTE: evaluation order and operand order are kept as-is on purpose; the
    # expected OPGraph dumps in the subtests may depend on how this is traced.
    base = x + y
    plus_one = base + 1
    plus_two = base + 2
    total = plus_one + plus_two

    first = total.astype(numpy.int32)
    second = plus_two.astype(numpy.int32)
    third = (plus_two + 3).astype(numpy.int32)
    fourth = total.astype(numpy.int32) + 67
    return first, second, third, fourth
|
|
|
|
|
|
@pytest.mark.parametrize("input_shape", [(), (3, 1, 2)])
def test_np_fhe_compiler_op_graph(
    input_shape, default_compilation_configuration, check_array_equality
):
    """Run the one-input and then the two-input NPFHECompiler OPGraph subtests."""
    subtests = (
        subtest_np_fhe_compiler_1_input_op_graph,
        subtest_np_fhe_compiler_2_inputs_op_graph,
    )
    for run_subtest in subtests:
        run_subtest(input_shape, default_compilation_configuration, check_array_equality)
|
|
|
|
|
|
def subtest_np_fhe_compiler_1_input_op_graph(
    input_shape, default_compilation_configuration, check_array_equality
):
    """Check NPFHECompiler on a one-input function built from ``complicated_topology``.

    ``x -> complicated_topology(x, 0)`` is evaluated through the compiler for
    inputs filled with 0..4, the OPGraph is queried (which the test expects to
    flush the buffered inputset), evaluation continues with 5..9, and the final
    OPGraph dump is compared against a golden string.

    Args:
        input_shape: shape of every input array (``()`` or ``(3, 1, 2)`` as
            parametrized by ``test_np_fhe_compiler_op_graph``).
        default_compilation_configuration: configuration forwarded to
            ``NPFHECompiler``.
        check_array_equality: callable used to compare the compiler's result
            with the plain numpy result.
    """

    def function_to_compile(x):
        # ``y`` is fixed to the constant 0, so the traced graph has a single input.
        return complicated_topology(x, 0)

    compiler = NPFHECompiler(
        function_to_compile,
        {"x": "encrypted"},
        default_compilation_configuration,
    )

    # For coverage when the OPGraph is not yet traced
    compiler._patch_op_graph_input_to_accept_any_integer_input()  # pylint: disable=protected-access

    # The compiler must hold an equal, but distinct, configuration object.
    assert compiler.compilation_configuration == default_compilation_configuration
    assert compiler.compilation_configuration is not default_compilation_configuration

    for i in numpy.arange(5):
        # Rebind ``i`` to an int64 array of ``input_shape`` filled with the loop value.
        i = numpy.ones(input_shape, dtype=numpy.int64) * i
        check_array_equality(compiler(i), function_to_compile(i))

    # For coverage, check that we flush the inputset when we query the OPGraph
    current_op_graph = compiler.op_graph
    # A second query is expected to yield a different OPGraph object.
    assert current_op_graph is not compiler.op_graph
    assert len(compiler._current_inputset) == 0  # pylint: disable=protected-access
    # For coverage, cover case where the current inputset is empty
    compiler._eval_on_current_inputset()  # pylint: disable=protected-access

    # Continue a bit more
    for i in numpy.arange(5, 10):
        i = numpy.ones(input_shape, dtype=numpy.int64) * i
        check_array_equality(compiler(i), function_to_compile(i))

    # Compare the final OPGraph dump with the golden text; on failure the actual
    # dump (``got``) is used as the assertion message.
    # NOTE(review): in this copy the golden strings seem to have lost their
    # column-alignment whitespace (the first line keeps a leading pad before
    # "%0", the others do not) — regenerate them from format_operation_graph
    # output if these assertions fail.
    if input_shape == ():
        assert (
            (got := format_operation_graph(compiler.op_graph))
            == """ %0 = 67 # ClearScalar<uint7>
%1 = 2 # ClearScalar<uint2>
%2 = 3 # ClearScalar<uint2>
%3 = 1 # ClearScalar<uint1>
%4 = x # EncryptedScalar<uint4>
%5 = 0 # ClearScalar<uint1>
%6 = add(%4, %5) # EncryptedScalar<uint4>
%7 = add(%6, %1) # EncryptedScalar<uint4>
%8 = add(%6, %3) # EncryptedScalar<uint4>
%9 = astype(%7, dtype=int32) # EncryptedScalar<uint4>
%10 = add(%7, %2) # EncryptedScalar<uint4>
%11 = add(%8, %7) # EncryptedScalar<uint5>
%12 = astype(%10, dtype=int32) # EncryptedScalar<uint4>
%13 = astype(%11, dtype=int32) # EncryptedScalar<uint5>
%14 = astype(%11, dtype=int32) # EncryptedScalar<uint5>
%15 = add(%14, %0) # EncryptedScalar<uint7>
(%13, %9, %12, %15)"""
        ), got
    else:
        assert (
            (got := format_operation_graph(compiler.op_graph))
            == """ %0 = 67 # ClearScalar<uint7>
%1 = 2 # ClearScalar<uint2>
%2 = 3 # ClearScalar<uint2>
%3 = 1 # ClearScalar<uint1>
%4 = x # EncryptedTensor<uint4, shape=(3, 1, 2)>
%5 = 0 # ClearScalar<uint1>
%6 = add(%4, %5) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%7 = add(%6, %1) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%8 = add(%6, %3) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%9 = astype(%7, dtype=int32) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%10 = add(%7, %2) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%11 = add(%8, %7) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%12 = astype(%10, dtype=int32) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%13 = astype(%11, dtype=int32) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%14 = astype(%11, dtype=int32) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%15 = add(%14, %0) # EncryptedTensor<uint7, shape=(3, 1, 2)>
(%13, %9, %12, %15)"""
        ), got
|
|
|
|
|
|
def subtest_np_fhe_compiler_2_inputs_op_graph(
    input_shape, default_compilation_configuration, check_array_equality
):
    """Check NPFHECompiler on ``complicated_topology`` with an encrypted and a clear input.

    ``complicated_topology`` is compiled with ``x`` encrypted and ``y`` clear,
    evaluated on (x, y) pairs (0..4, 5..9), the OPGraph is queried (which the
    test expects to flush the buffered inputset), evaluation continues with the
    ranges swapped, and the final OPGraph dump is compared against a golden string.

    Args:
        input_shape: shape of every input array (``()`` or ``(3, 1, 2)`` as
            parametrized by ``test_np_fhe_compiler_op_graph``).
        default_compilation_configuration: configuration forwarded to
            ``NPFHECompiler``.
        check_array_equality: callable used to compare the compiler's result
            with the plain numpy result.
    """

    compiler = NPFHECompiler(
        complicated_topology,
        {"x": "encrypted", "y": "clear"},
        default_compilation_configuration,
    )

    # For coverage when the OPGraph is not yet traced
    compiler._patch_op_graph_input_to_accept_any_integer_input()  # pylint: disable=protected-access

    # The compiler must hold an equal, but distinct, configuration object.
    assert compiler.compilation_configuration == default_compilation_configuration
    assert compiler.compilation_configuration is not default_compilation_configuration

    # First batch: x takes 0..4 while y takes 5..9.
    for i, j in zip(numpy.arange(5), numpy.arange(5, 10)):
        # Rebind ``i``/``j`` to int64 arrays of ``input_shape`` filled with the loop values.
        i = numpy.ones(input_shape, dtype=numpy.int64) * i
        j = numpy.ones(input_shape, dtype=numpy.int64) * j
        check_array_equality(compiler(i, j), complicated_topology(i, j))

    # For coverage, check that we flush the inputset when we query the OPGraph
    current_op_graph = compiler.op_graph
    # A second query is expected to yield a different OPGraph object.
    assert current_op_graph is not compiler.op_graph
    assert len(compiler._current_inputset) == 0  # pylint: disable=protected-access
    # For coverage, cover case where the current inputset is empty
    compiler._eval_on_current_inputset()  # pylint: disable=protected-access

    # Continue a bit more, with the ranges swapped so that both x and y
    # end up covering 0..9 over the two batches.
    for i, j in zip(numpy.arange(5, 10), numpy.arange(5)):
        i = numpy.ones(input_shape, dtype=numpy.int64) * i
        j = numpy.ones(input_shape, dtype=numpy.int64) * j
        check_array_equality(compiler(i, j), complicated_topology(i, j))

    # Compare the final OPGraph dump with the golden text; on failure the actual
    # dump (``got``) is used as the assertion message.
    # NOTE(review): in this copy the golden strings seem to have lost their
    # column-alignment whitespace (the first line keeps a leading pad before
    # "%0", the others do not) — regenerate them from format_operation_graph
    # output if these assertions fail.
    if input_shape == ():
        assert (
            (got := format_operation_graph(compiler.op_graph))
            == """ %0 = 67 # ClearScalar<uint7>
%1 = 2 # ClearScalar<uint2>
%2 = 3 # ClearScalar<uint2>
%3 = 1 # ClearScalar<uint1>
%4 = x # EncryptedScalar<uint4>
%5 = y # ClearScalar<uint4>
%6 = add(%4, %5) # EncryptedScalar<uint4>
%7 = add(%6, %1) # EncryptedScalar<uint4>
%8 = add(%6, %3) # EncryptedScalar<uint4>
%9 = astype(%7, dtype=int32) # EncryptedScalar<uint4>
%10 = add(%7, %2) # EncryptedScalar<uint5>
%11 = add(%8, %7) # EncryptedScalar<uint5>
%12 = astype(%10, dtype=int32) # EncryptedScalar<uint5>
%13 = astype(%11, dtype=int32) # EncryptedScalar<uint5>
%14 = astype(%11, dtype=int32) # EncryptedScalar<uint5>
%15 = add(%14, %0) # EncryptedScalar<uint7>
(%13, %9, %12, %15)"""
        ), got
    else:
        assert (
            (got := format_operation_graph(compiler.op_graph))
            == """ %0 = 67 # ClearScalar<uint7>
%1 = 2 # ClearScalar<uint2>
%2 = 3 # ClearScalar<uint2>
%3 = 1 # ClearScalar<uint1>
%4 = x # EncryptedTensor<uint4, shape=(3, 1, 2)>
%5 = y # ClearTensor<uint4, shape=(3, 1, 2)>
%6 = add(%4, %5) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%7 = add(%6, %1) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%8 = add(%6, %3) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%9 = astype(%7, dtype=int32) # EncryptedTensor<uint4, shape=(3, 1, 2)>
%10 = add(%7, %2) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%11 = add(%8, %7) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%12 = astype(%10, dtype=int32) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%13 = astype(%11, dtype=int32) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%14 = astype(%11, dtype=int32) # EncryptedTensor<uint5, shape=(3, 1, 2)>
%15 = add(%14, %0) # EncryptedTensor<uint7, shape=(3, 1, 2)>
(%13, %9, %12, %15)"""
        ), got
|
|
|
|
|
|
def remaining_inputset_size(inputset_len):
    """Return the expected buffered-inputset length after ``inputset_len`` evaluations.

    This is the remainder of ``inputset_len`` divided by
    ``NPFHECompiler.INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE``, i.e. the number of
    samples left over once every complete batch of that size has been consumed.
    Used to build the parametrization of ``test_np_fhe_compiler_auto_flush``.
    """
    batch_size = NPFHECompiler.INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE
    _, leftover = divmod(inputset_len, batch_size)
    return leftover
|
|
|
|
|
|
@pytest.mark.parametrize(
    "inputset_len, expected_remaining_inputset_len",
    [(length, remaining_inputset_size(length)) for length in (42, 128, 234)],
)
def test_np_fhe_compiler_auto_flush(
    inputset_len,
    expected_remaining_inputset_len,
    default_compilation_configuration,
    check_array_equality,
):
    """Check that NPFHECompiler flushes its buffered inputset automatically.

    After ``inputset_len`` evaluations, only the samples that did not complete a
    batch of ``NPFHECompiler.INPUTSET_SIZE_BEFORE_AUTO_BOUND_UPDATE`` samples
    should still be buffered.
    """

    def function_to_compile(x):
        return x // 2

    compiler = NPFHECompiler(
        function_to_compile,
        {"x": "encrypted"},
        default_compilation_configuration,
    )

    for sample in numpy.arange(inputset_len):
        check_array_equality(compiler(sample), function_to_compile(sample))

    # Check the inputset was properly flushed: only the incomplete batch remains.
    buffered = len(compiler._current_inputset)  # pylint: disable=protected-access
    assert buffered == expected_remaining_inputset_len
|
|
|
|
|
|
def test_np_fhe_compiler_full_compilation(default_compilation_configuration, check_array_equality):
    """Build an FHE circuit through NPFHECompiler and check it against plain numpy.

    Requesting the circuit before any evaluation must raise ``RuntimeError``
    with a specific message; after evaluating 0..63 through the compiler, the
    compiled circuit must reproduce ``x + 42`` on the same range.
    """

    def function_to_compile(x):
        return x + 42

    compiler = NPFHECompiler(
        function_to_compile,
        {"x": "encrypted"},
        default_compilation_configuration,
    )

    # For coverage: nothing has been evaluated yet, so no circuit can be built.
    expected_message = (
        "Requested FHECircuit but no OPGraph was compiled. "
        "Did you forget to evaluate NPFHECompiler over an inputset?"
    )
    with pytest.raises(RuntimeError) as excinfo:
        compiler.get_compiled_fhe_circuit()
    assert str(excinfo.value) == expected_message

    # Evaluate through the compiler (numpy integers) to provide an inputset.
    for value in numpy.arange(64):
        check_array_equality(compiler(value), function_to_compile(value))

    fhe_circuit = compiler.get_compiled_fhe_circuit()

    # Run the compiled circuit on plain Python ints over the same range.
    for value in range(64):
        assert fhe_circuit.run(value) == function_to_compile(value)
|