feat(compiler): add more detailed statistics

This commit is contained in:
Umut
2023-07-26 14:27:32 +02:00
parent dae31f0f26
commit ade83d5335
14 changed files with 1407 additions and 272 deletions

View File

@@ -523,6 +523,20 @@ def test_circuit_compile_sim_only(helpers):
assert f(*inputset[0]) == circuit.simulate(*inputset[0])
def tagged_function(x, y, z):
"""
A tagged function to test statistics.
"""
with fhe.tag("a"):
x = fhe.univariate(lambda v: v)(x)
with fhe.tag("b"):
y = fhe.univariate(lambda v: v)(y)
with fhe.tag("c"):
z = fhe.univariate(lambda v: v)(z)
return x + y + z
@pytest.mark.parametrize(
"function,parameters,expected_statistics",
[
@@ -532,12 +546,11 @@ def test_circuit_compile_sim_only(helpers):
"x": {"status": "encrypted", "range": [0, 10], "shape": ()},
},
{
"total_pbs_count": 1,
"total_ks_count": 1,
"total_clear_addition_count": 0,
"total_encrypted_addition_count": 0,
"total_clear_multiplication_count": 0,
"total_encrypted_negation_count": 0,
"programmable_bootstrap_count": 1,
"clear_addition_count": 0,
"encrypted_addition_count": 0,
"clear_multiplication_count": 0,
"encrypted_negation_count": 0,
},
id="x**2 | x.is_encrypted | x.shape == ()",
),
@@ -547,11 +560,11 @@ def test_circuit_compile_sim_only(helpers):
"x": {"status": "encrypted", "range": [0, 10], "shape": (3,)},
},
{
"total_pbs_count": 3,
"total_clear_addition_count": 0,
"total_encrypted_addition_count": 0,
"total_clear_multiplication_count": 0,
"total_encrypted_negation_count": 0,
"programmable_bootstrap_count": 3,
"clear_addition_count": 0,
"encrypted_addition_count": 0,
"clear_multiplication_count": 0,
"encrypted_negation_count": 0,
},
id="x**2 | x.is_encrypted | x.shape == (3,)",
),
@@ -561,11 +574,11 @@ def test_circuit_compile_sim_only(helpers):
"x": {"status": "encrypted", "range": [0, 10], "shape": (3, 2)},
},
{
"total_pbs_count": 3 * 2,
"total_clear_addition_count": 0,
"total_encrypted_addition_count": 0,
"total_clear_multiplication_count": 0,
"total_encrypted_negation_count": 0,
"programmable_bootstrap_count": 3 * 2,
"clear_addition_count": 0,
"encrypted_addition_count": 0,
"clear_multiplication_count": 0,
"encrypted_negation_count": 0,
},
id="x**2 | x.is_encrypted | x.shape == (3, 2)",
),
@@ -576,11 +589,11 @@ def test_circuit_compile_sim_only(helpers):
"y": {"status": "encrypted", "range": [0, 10], "shape": ()},
},
{
"total_pbs_count": 2,
"total_clear_addition_count": 1,
"total_encrypted_addition_count": 3,
"total_clear_multiplication_count": 0,
"total_encrypted_negation_count": 2,
"programmable_bootstrap_count": 2,
"clear_addition_count": 1,
"encrypted_addition_count": 3,
"clear_multiplication_count": 0,
"encrypted_negation_count": 2,
},
id="x * y | x.is_encrypted | x.shape == () | y.is_encrypted | y.shape == ()",
),
@@ -591,11 +604,11 @@ def test_circuit_compile_sim_only(helpers):
"y": {"status": "encrypted", "range": [0, 10], "shape": (3,)},
},
{
"total_pbs_count": 3 * 2,
"total_clear_addition_count": 3 * 1,
"total_encrypted_addition_count": 3 * 3,
"total_clear_multiplication_count": 0,
"total_encrypted_negation_count": 3 * 2,
"programmable_bootstrap_count": 3 * 2,
"clear_addition_count": 3 * 1,
"encrypted_addition_count": 3 * 3,
"clear_multiplication_count": 0,
"encrypted_negation_count": 3 * 2,
},
id="x * y | x.is_encrypted | x.shape == (3,) | y.is_encrypted | y.shape == (3,)",
),
@@ -606,14 +619,30 @@ def test_circuit_compile_sim_only(helpers):
"y": {"status": "encrypted", "range": [0, 10], "shape": (3, 2)},
},
{
"total_pbs_count": 3 * 2 * 2,
"total_clear_addition_count": 3 * 2 * 1,
"total_encrypted_addition_count": 3 * 2 * 3,
"total_clear_multiplication_count": 0,
"total_encrypted_negation_count": 3 * 2 * 2,
"programmable_bootstrap_count": 3 * 2 * 2,
"clear_addition_count": 3 * 2 * 1,
"encrypted_addition_count": 3 * 2 * 3,
"clear_multiplication_count": 0,
"encrypted_negation_count": 3 * 2 * 2,
},
id="x * y | x.is_encrypted | x.shape == (3, 2) | y.is_encrypted | y.shape == (3, 2)",
),
pytest.param(
tagged_function,
{
"x": {"status": "encrypted", "range": [0, 2**3 - 1], "shape": ()},
"y": {"status": "encrypted", "range": [0, 2**4 - 1], "shape": ()},
"z": {"status": "encrypted", "range": [0, 2**5 - 1], "shape": ()},
},
{
"programmable_bootstrap_count_per_tag": {
"a": 3,
"a.b": 2,
"a.b.c": 1,
},
},
id="tagged_function",
),
],
)
def test_statistics(function, parameters, expected_statistics, helpers):