mirror of
https://github.com/zama-ai/concrete.git
synced 2026-01-14 07:17:57 -05:00
513 lines
17 KiB
Python
513 lines
17 KiB
Python
"""
|
|
Tests of `Node` class.
|
|
"""
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from concrete.fhe import tfhers
|
|
from concrete.fhe.dtypes import UnsignedInteger
|
|
from concrete.fhe.representation import Node
|
|
from concrete.fhe.values import (
|
|
ClearScalar,
|
|
ClearTensor,
|
|
EncryptedScalar,
|
|
EncryptedTensor,
|
|
ValueDescription,
|
|
)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "constant,expected_error,expected_message",
    [
        # a string handed to `Node.constant`
        ("abc", ValueError, "Constant 'abc' is not supported"),
    ],
)
def test_node_bad_constant(constant, expected_error, expected_message):
    """
    Check that `Node.constant` rejects bad values with the expected error.

    Each case provides the value passed to `Node.constant`, the exception type
    that must be raised and the exact text that exception must carry.
    """

    with pytest.raises(expected_error) as raised:
        Node.constant(constant)

    # the comparison is exact, not a substring or regex match
    actual_message = str(raised.value)
    assert actual_message == expected_message
|
|
|
|
|
|
def default_tfhers_dtype() -> tfhers.TFHERSIntegerType:
    """Build the tfhers integer type used by the tests in this module.

    Returns:
        tfhers.TFHERSIntegerType: result of `tfhers.int8_2_2` applied to a
            fixed set of crypto parameters
    """
    # The values are passed positionally, in this exact order. The trailing
    # labels mirror the repr expected by the TFHERSInteger case of
    # `test_node_bad_call`; presumably they are also the parameter names of
    # `tfhers.CryptoParams` -- TODO confirm against its signature.
    crypto_args = (
        909,  # lwe_dim
        1,  # glwe_dim
        4096,  # poly_size
        15,  # pbs_base_log
        2,  # pbs_level
        0,  # lwe_noise_distribution
        2.168404344971009e-19,  # glwe_noise_distribution
        tfhers.EncryptionKeyChoice.BIG,  # encryption_key_choice
    )
    return tfhers.int8_2_2(tfhers.CryptoParams(*crypto_args))
|
|
|
|
|
|
@pytest.mark.parametrize(
    "node,args,expected_error,expected_message",
    [
        # constant node called with an argument
        (
            Node.constant(1),
            ["abc"],
            ValueError,
            "Evaluation of constant '1' node using 'abc' failed "
            "because of invalid number of arguments",
        ),
        # node declared with two inputs called with a single argument
        (
            Node.generic(
                name="add",
                inputs=[ValueDescription.of(4), ValueDescription.of(10, is_encrypted=True)],
                output=ValueDescription.of(14),
                operation=lambda x, y: x + y,
            ),
            ["abc"],
            ValueError,
            "Evaluation of generic 'add' node using 'abc' failed "
            "because of invalid number of arguments",
        ),
        # right number of arguments, but they are strings
        (
            Node.generic(
                name="add",
                inputs=[ValueDescription.of(4), ValueDescription.of(10, is_encrypted=True)],
                output=ValueDescription.of(14),
                operation=lambda x, y: x + y,
            ),
            ["abc", "def"],
            ValueError,
            "Evaluation of generic 'add' node using 'abc', 'def' failed "
            "because argument 'abc' is not valid",
        ),
        # first argument has four elements while the input was described from [3, 4]
        (
            Node.generic(
                name="add",
                inputs=[
                    ValueDescription.of([3, 4]),
                    ValueDescription.of(10, is_encrypted=True),
                ],
                output=ValueDescription.of([13, 14]),
                operation=lambda x, y: x + y,
            ),
            [[1, 2, 3, 4], 10],
            ValueError,
            "Evaluation of generic 'add' node using [1, 2, 3, 4], 10 failed "
            "because argument [1, 2, 3, 4] does not have the expected shape of (2,)",
        ),
        # operation returns a str
        (
            Node.generic(
                name="unknown",
                inputs=[],
                output=ValueDescription.of(10),
                operation=lambda: "abc",
            ),
            [],
            ValueError,
            "Evaluation of generic 'unknown' node resulted in 'abc' of type str "
            "which is not acceptable either because of the type or because of overflow",
        ),
        # operation returns a numpy array of strings
        (
            Node.generic(
                name="unknown",
                inputs=[],
                output=ValueDescription.of(10),
                operation=lambda: np.array(["abc", "def"]),
            ),
            [],
            ValueError,
            "Evaluation of generic 'unknown' node resulted in array(['abc', 'def'], dtype='<U3') "
            f"of type np.ndarray and of underlying type '{type(np.array(['abc', 'def']).dtype).__name__}' "  # noqa: E501
            "which is not acceptable because of the underlying type",
        ),
        # operation returns a list that contains an empty tuple
        (
            Node.generic(
                name="unknown",
                inputs=[],
                output=ValueDescription.of(10),
                operation=lambda: [1, (), 3],
            ),
            [],
            ValueError,
            "Evaluation of generic 'unknown' node resulted in [1, (), 3] of type list "
            "which is not acceptable either because of the type or because of overflow",
        ),
        # operation returns three values while the output was described from a scalar
        (
            Node.generic(
                name="unknown",
                inputs=[],
                output=ValueDescription.of(10),
                operation=lambda: [1, 2, 3],
            ),
            [],
            ValueError,
            "Evaluation of generic 'unknown' node resulted in array([1, 2, 3]) "
            "which does not have the expected shape of ()",
        ),
        # operation returns a TFHERSInteger; this case keeps its explicit test id
        pytest.param(
            Node.generic(
                name="unknown",
                inputs=[],
                output=ValueDescription(default_tfhers_dtype(), (3,), True),
                operation=lambda: tfhers.TFHERSInteger(default_tfhers_dtype(), [1, 2, 3]),
            ),
            [],
            ValueError,
            "Evaluation of generic 'unknown' node resulted in TFHEInteger(dtype=tfhers<int8, 2, 2"
            ", params=crypto_params<lwe_dim=909, glwe_dim=1, poly_size=4096, pbs_base_log=15, "
            "pbs_level=2, lwe_noise_distribution=0, glwe_noise_distribution=2.168404344971009e-19,"
            " encryption_key_choice=BIG>>, shape=(3,), value=[1 2 3]) of type TFHERSInteger which "
            "is not acceptable either because of the type or because of overflow",
            id="TFHERSInteger in a non-input node",
        ),
    ],
)
def test_node_bad_call(node, args, expected_error, expected_message):
    """
    Check the errors raised when a `Node` is called with bad inputs or results.

    Each case provides the node, the positional arguments it is called with,
    the exception type that must be raised and the exact text of that exception.
    """

    with pytest.raises(expected_error) as raised:
        node(*args)

    # the comparison is exact, not a substring or regex match
    actual_message = str(raised.value)
    assert actual_message == expected_message
|
|
|
|
|
|
@pytest.mark.parametrize(
    "node,predecessors,expected_result",
    [
        # constant node, no predecessors
        (Node.constant(1), [], "1"),
        # input node, no predecessors
        (Node.input("x", EncryptedScalar(UnsignedInteger(3))), [], "x"),
        # generic node with a keyword argument
        (
            Node.generic(
                name="tlu",
                inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3, 2))],
                output=EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
                operation=lambda x, table: table[x],
                kwargs={"table": np.array([4, 1, 3, 2])},
            ),
            ["%0"],
            "tlu(%0, table=[4 1 3 2])",
        ),
        # static indexing with a reversed slice
        (
            Node.generic(
                name="index_static",
                inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3,))],
                output=EncryptedTensor(UnsignedInteger(3), shape=(3,)),
                operation=lambda x: x[slice(None, None, -1)],
                kwargs={"index": (slice(None, None, -1),)},
            ),
            ["%0"],
            "%0[::-1]",
        ),
        # concatenation of three predecessors
        (
            Node.generic(
                name="concatenate",
                inputs=[
                    EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
                    EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
                    EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
                ],
                output=EncryptedTensor(UnsignedInteger(3), shape=(3, 6)),
                operation=lambda *args, **kwargs: np.concatenate(tuple(args), **kwargs),
                kwargs={"axis": 1},
            ),
            ["%0", "%1", "%2"],
            "concatenate((%0, %1, %2), axis=1)",
        ),
        # array built from four scalars with a (2, 2) output
        (
            Node.generic(
                name="array",
                inputs=[
                    EncryptedScalar(UnsignedInteger(3)),
                    ClearScalar(UnsignedInteger(3)),
                    ClearScalar(UnsignedInteger(3)),
                    EncryptedScalar(UnsignedInteger(3)),
                ],
                output=EncryptedTensor(UnsignedInteger(3), shape=(2, 2)),
                operation=lambda *args: np.array(args).reshape((2, 2)),
            ),
            ["%0", "%1", "%2", "%3"],
            "array([[%0, %1], [%2, %3]])",
        ),
        # static assignment
        (
            Node.generic(
                name="assign_static",
                inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3, 4))],
                output=EncryptedTensor(UnsignedInteger(3), shape=(3, 4)),
                operation=lambda *args: args,
                kwargs={"index": (1, 2)},
            ),
            ["%0", "%1"],
            "(%0[1, 2] = %1)",
        ),
        # dynamic indexing, static_indices=(None,)
        (
            Node.generic(
                name="index_dynamic",
                inputs=[
                    EncryptedTensor(UnsignedInteger(5), shape=(8,)),
                    ClearTensor(UnsignedInteger(3), shape=(3,)),
                ],
                output=EncryptedTensor(UnsignedInteger(5), shape=(3,)),
                operation=lambda *args: args,
                kwargs={"static_indices": (None,)},
            ),
            ["%0", "%1"],
            "%0[%1]",
        ),
        # dynamic indexing, static_indices=(None, 0)
        (
            Node.generic(
                name="index_dynamic",
                inputs=[
                    EncryptedTensor(UnsignedInteger(5), shape=(8, 4)),
                    ClearTensor(UnsignedInteger(3), shape=(3,)),
                ],
                output=EncryptedTensor(UnsignedInteger(5), shape=(3,)),
                operation=lambda *args: args,
                kwargs={"static_indices": (None, 0)},
            ),
            ["%0", "%1"],
            "%0[%1, 0]",
        ),
        # dynamic indexing, static_indices=(0, None)
        (
            Node.generic(
                name="index_dynamic",
                inputs=[
                    EncryptedTensor(UnsignedInteger(5), shape=(8, 4)),
                    ClearTensor(UnsignedInteger(2), shape=(3,)),
                ],
                output=EncryptedTensor(UnsignedInteger(5), shape=(3,)),
                operation=lambda *args: args,
                kwargs={"static_indices": (0, None)},
            ),
            ["%0", "%1"],
            "%0[0, %1]",
        ),
        # dynamic indexing, static_indices=(None, None)
        (
            Node.generic(
                name="index_dynamic",
                inputs=[
                    EncryptedTensor(UnsignedInteger(5), shape=(8, 4)),
                    ClearTensor(UnsignedInteger(3), shape=(3,)),
                    ClearTensor(UnsignedInteger(2), shape=(3,)),
                ],
                output=EncryptedTensor(UnsignedInteger(5), shape=(3,)),
                operation=lambda *args: args,
                kwargs={"static_indices": (None, None)},
            ),
            ["%0", "%1", "%2"],
            "%0[%1, %2]",
        ),
    ],
)
def test_node_format(node, predecessors, expected_result):
    """
    Check the text produced by `Node.format` for several kinds of nodes.

    Each case provides the node, the predecessor names passed to `format`
    and the exact string the call must return.
    """

    formatted = node.format(predecessors)
    assert formatted == expected_result
|
|
|
|
|
|
@pytest.mark.parametrize(
    "node,expected_result",
    [
        # constant node
        (Node.constant(1), "1"),
        # input node
        (Node.input("x", EncryptedScalar(UnsignedInteger(3))), "x"),
        # generic node with a keyword argument
        (
            Node.generic(
                name="tlu",
                inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3, 2))],
                output=EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
                operation=lambda x, table: table[x],
                kwargs={"table": np.array([4, 1, 3, 2])},
            ),
            "tlu",
        ),
        # concatenation of three inputs
        (
            Node.generic(
                name="concatenate",
                inputs=[
                    EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
                    EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
                    EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
                ],
                output=EncryptedTensor(UnsignedInteger(3), shape=(3, 6)),
                operation=lambda *args, **kwargs: np.concatenate(tuple(args), **kwargs),
                kwargs={"axis": -1},
            ),
            "concatenate",
        ),
        # static indexing
        (
            Node.generic(
                name="index_static",
                inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3, 4))],
                output=EncryptedTensor(UnsignedInteger(3), shape=()),
                operation=lambda *args: args,
                kwargs={"index": (1, 2)},
            ),
            "□[1, 2]",
        ),
        # static assignment
        (
            Node.generic(
                name="assign_static",
                inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3, 4))],
                output=EncryptedTensor(UnsignedInteger(3), shape=(3, 4)),
                operation=lambda *args: args,
                kwargs={"index": (1, 2)},
            ),
            "□[1, 2] = □",
        ),
        # dynamic indexing, static_indices=(None,)
        (
            Node.generic(
                name="index_dynamic",
                inputs=[
                    EncryptedTensor(UnsignedInteger(5), shape=(8,)),
                    ClearTensor(UnsignedInteger(3), shape=(3,)),
                ],
                output=EncryptedTensor(UnsignedInteger(5), shape=(3,)),
                operation=lambda *args: args,
                kwargs={"static_indices": (None,)},
            ),
            "□[□]",
        ),
        # dynamic indexing, static_indices=(None, 0)
        (
            Node.generic(
                name="index_dynamic",
                inputs=[
                    EncryptedTensor(UnsignedInteger(5), shape=(8, 4)),
                    ClearTensor(UnsignedInteger(3), shape=(3,)),
                ],
                output=EncryptedTensor(UnsignedInteger(5), shape=(3,)),
                operation=lambda *args: args,
                kwargs={"static_indices": (None, 0)},
            ),
            "□[□, 0]",
        ),
        # dynamic indexing, static_indices=(0, None)
        (
            Node.generic(
                name="index_dynamic",
                inputs=[
                    EncryptedTensor(UnsignedInteger(5), shape=(8, 4)),
                    ClearTensor(UnsignedInteger(2), shape=(3,)),
                ],
                output=EncryptedTensor(UnsignedInteger(5), shape=(3,)),
                operation=lambda *args: args,
                kwargs={"static_indices": (0, None)},
            ),
            "□[0, □]",
        ),
        # dynamic indexing, static_indices=(None, None)
        (
            Node.generic(
                name="index_dynamic",
                inputs=[
                    EncryptedTensor(UnsignedInteger(5), shape=(8, 4)),
                    ClearTensor(UnsignedInteger(3), shape=(3,)),
                    ClearTensor(UnsignedInteger(2), shape=(3,)),
                ],
                output=EncryptedTensor(UnsignedInteger(5), shape=(3,)),
                operation=lambda *args: args,
                kwargs={"static_indices": (None, None)},
            ),
            "□[□, □]",
        ),
        # dynamic assignment, static_indices=(None,)
        (
            Node.generic(
                name="assign_dynamic",
                inputs=[
                    EncryptedTensor(UnsignedInteger(5), shape=(8,)),
                    ClearScalar(UnsignedInteger(3)),
                    ClearScalar(UnsignedInteger(5)),
                ],
                output=EncryptedTensor(UnsignedInteger(5), shape=(8,)),
                operation=lambda *args: args,
                kwargs={"static_indices": (None,)},
            ),
            "□[□] = □",
        ),
        # dynamic assignment, static_indices=(None, 0)
        (
            Node.generic(
                name="assign_dynamic",
                inputs=[
                    EncryptedTensor(UnsignedInteger(5), shape=(8, 4)),
                    ClearScalar(UnsignedInteger(3)),
                    ClearScalar(UnsignedInteger(5)),
                ],
                output=EncryptedTensor(UnsignedInteger(5), shape=(8, 4)),
                operation=lambda *args: args,
                kwargs={"static_indices": (None, 0)},
            ),
            "□[□, 0] = □",
        ),
        # dynamic assignment, static_indices=(0, None)
        (
            Node.generic(
                name="assign_dynamic",
                inputs=[
                    EncryptedTensor(UnsignedInteger(5), shape=(8, 4)),
                    ClearScalar(UnsignedInteger(3)),
                    ClearScalar(UnsignedInteger(5)),
                ],
                output=EncryptedTensor(UnsignedInteger(5), shape=(8, 4)),
                operation=lambda *args: args,
                kwargs={"static_indices": (0, None)},
            ),
            "□[0, □] = □",
        ),
        # dynamic assignment, static_indices=(None, None)
        (
            Node.generic(
                name="assign_dynamic",
                inputs=[
                    EncryptedTensor(UnsignedInteger(5), shape=(8, 4)),
                    ClearScalar(UnsignedInteger(3)),
                    ClearScalar(UnsignedInteger(2)),
                    ClearScalar(UnsignedInteger(5)),
                ],
                output=EncryptedTensor(UnsignedInteger(5), shape=(8, 4)),
                operation=lambda *args: args,
                kwargs={"static_indices": (None, None)},
            ),
            "□[□, □] = □",
        ),
    ],
)
def test_node_label(node, expected_result):
    """
    Check the text produced by `Node.label` for several kinds of nodes.

    Each case provides the node and the exact string `label()` must return.
    """

    label = node.label()
    assert label == expected_result
|