Files
concrete/frontends/concrete-python/tests/execution/test_others.py

1267 lines
37 KiB
Python

"""
Tests of execution of operations converted to table lookups.
"""
import numpy as np
import pytest
from concrete import fhe
def fusable_with_bigger_search(x, y):
"""
Fusable function that requires a single iteration for fusing.
"""
x = x + 1
x_1 = x.astype(np.int64)
x_1 = x_1 + 1.5
x_2 = x.astype(np.int64)
x_2 = x_2 + 3.4
add = x_1 + x_2
add_int = add.astype(np.int64)
return add_int + y
def fusable_with_bigger_search_needs_second_iteration(x, y):
"""
Fusable function that requires more than one iteration for fusing.
"""
x = x + 1
x = x + 0.5
x = np.cos(x)
x_1 = x.astype(np.int64)
x_1 = x_1 + 1.5
x_p = x + 1
x_p2 = x_p + 1
x_2 = (x_p + x_p2).astype(np.int64)
x_2 = x_2 + 3.4
add = x_1 + x_2
add_int = add.astype(np.int64)
return add_int + y
def fusable_with_one_of_the_start_nodes_is_lca_generator():
"""
Generator of a fusable function that has one of the start nodes as lca.
"""
# pylint: disable=invalid-name,too-many-locals,too-many-statements
def subgraph_18(x):
t0 = 0
t1 = 3
t2 = 2
t3 = 2.4688520431518555
t4 = 2.4688520431518555
t5 = x
t6 = np.multiply(t4, t5)
t7 = np.true_divide(t6, t3)
t8 = np.add(t7, t2)
t9 = np.rint(t8)
t10 = np.clip(t9, t0, t1)
t11 = t10.astype(np.int64)
return t11
def subgraph_24(x):
t0 = 0
t1 = [0.15588106, -0.01305565]
t2 = 1.3664466152828822
t3 = [[4, -4]]
t4 = 0
t5 = x
t6 = t5.astype(np.float32)
t7 = np.add(t6, t4)
t8 = np.add(t7, t3)
t9 = np.multiply(t2, t8)
t10 = np.add(t1, t9)
t11 = np.greater(t10, t0)
return t11
cst0 = np.random.randint(-2, 2, size=(10, 2))
cst1 = np.random.randint(0, 2, size=(10, 1))
def function(x):
t0 = 0
t1 = 3
t2 = 1
t3 = 1.2921873902965313
t4 = 1.0507009873554805
t5 = 1
t6 = 1.7580993408473766
t7 = [0.15588106, -0.01305565]
t8 = 1
t9 = 1.3664466152828822
t10 = [[4, -4]]
t11 = 0
t12 = cst0
t13 = 0
t14 = cst1
t15 = x
t16 = -2
t17 = np.add(t15, t16)
t18 = subgraph_18(t17)
t19 = np.matmul(t18, t12)
t20 = np.matmul(t18, t14)
t21 = np.multiply(t13, t20)
t22 = np.add(t19, t21)
t23 = t22.astype(np.float32)
t24 = subgraph_24(t22)
t25 = np.add(t23, t11)
t26 = np.subtract(t5, t24)
t27 = np.add(t25, t10)
t28 = np.multiply(t9, t27)
t29 = np.add(t7, t28)
t30 = np.multiply(t4, t29)
t31 = np.exp(t29)
t32 = np.multiply(t24, t30)
t33 = np.subtract(t31, t8)
t34 = np.multiply(t6, t33)
t35 = np.multiply(t26, t34)
t36 = np.add(t32, t35)
t37 = np.true_divide(t36, t3)
t38 = np.add(t37, t2)
t39 = np.rint(t38)
t40 = np.clip(t39, t0, t1)
t41 = t40.astype(np.int64)
return t41
return function
# pylint: enable=invalid-name,too-many-locals,too-many-statements
def fusable_with_hard_to_find_lca(x):
"""
Fusable function that requires harder lca search.
"""
a = x * 3
b = x // 3
c = a + b
return ((np.sin(a) ** 2) + (np.cos(c) ** 2)).round().astype(np.int64)
def fusable_with_hard_to_find_lca_used_twice(x):
"""
Fusable function that uses `fusable_with_hard_to_find_lca` twice.
"""
a = x @ np.array([[3, 1], [4, 2]])
b = x @ np.array([[1, 2], [3, 4]])
a = fusable_with_hard_to_find_lca(a)
b = fusable_with_hard_to_find_lca(b)
return a + b
def fusable_additional_1(x):
"""
Another fusable function for additional safety.
"""
a = x.astype(np.float64) * 3.0
b = x + 1
c = a.astype(np.int64)
return (a + b + c).astype(np.int64)
def fusable_additional_2(x):
"""
Another fusable function for additional safety.
"""
a = x.astype(np.float64) / 3.0
b = x + 1
c = a * a
return (a + b + c).astype(np.int64)
def deterministic_unary_function(x):
"""
An example deterministic unary function.
"""
def per_element(element):
result = 0
for i in range(element):
result += i
return result
return np.vectorize(per_element)(x)
def copy_modify(x):
"""
A function that used `np.copy` and then modifies the copied object.
"""
y = np.copy(x)
y[1] = np.sum(x)
return np.concatenate((x, y))
def issue650(x):
"""
Function of a reported bug in which bit widths assigned to clear values were wrong.
"""
tmp0 = x
tmp1 = [[1], [-1], [-1], [-1], [-1], [-1], [-2], [-1], [0], [0]]
tmp2 = np.matmul(tmp0, tmp1)
tmp3 = np.sum(tmp0, axis=1, keepdims=True)
tmp4 = -1
tmp5 = np.multiply(tmp4, tmp3)
tmp6 = np.subtract(tmp2, tmp5)
tmp7 = [[11]]
tmp8 = np.add(tmp6, tmp7)
return tmp8
@pytest.mark.parametrize(
"function,parameters,configuration_overrides",
[
pytest.param(
lambda x: x // 3,
{
"x": {"status": "encrypted", "range": [0, 127]},
},
{},
id="x // 3",
),
pytest.param(
lambda x: 127 // x,
{
"x": {"status": "encrypted", "range": [1, 127]},
},
{},
id="127 // x",
),
pytest.param(
lambda x: (x / 3).astype(np.int64),
{
"x": {"status": "encrypted", "range": [0, 127]},
},
{},
id="(x / 3).astype(np.int64)",
),
pytest.param(
lambda x: (127 / x).astype(np.int64),
{
"x": {"status": "encrypted", "range": [1, 127]},
},
{},
id="(127 / x).astype(np.int64)",
),
pytest.param(
lambda x: x**2,
{
"x": {"status": "encrypted", "range": [0, 11]},
},
{},
id="x ** 2",
),
pytest.param(
lambda x: 2**x,
{
"x": {"status": "encrypted", "range": [0, 6]},
},
{},
id="2 ** x",
),
pytest.param(
lambda x: x % 10,
{
"x": {"status": "encrypted", "range": [0, 127]},
},
{},
id="x % 10",
),
pytest.param(
lambda x: 121 % x,
{
"x": {"status": "encrypted", "range": [1, 127]},
},
{},
id="121 % x",
),
pytest.param(
lambda x: +x,
{
"x": {"status": "encrypted", "range": [0, 127]},
},
{},
id="+x",
),
pytest.param(
lambda x: abs(42 - x),
{
"x": {"status": "encrypted", "range": [0, 84]},
},
{},
id="abs(42 - x)",
),
pytest.param(
lambda x: ~x,
{
"x": {"status": "encrypted", "range": [0, 16]},
},
{},
id="~x",
),
pytest.param(
lambda x: x & 10,
{
"x": {"status": "encrypted", "range": [0, 16]},
},
{},
id="x & 10",
),
pytest.param(
lambda x: 5 & x,
{
"x": {"status": "encrypted", "range": [0, 16]},
},
{},
id="5 & x",
),
pytest.param(
lambda x: x | 6,
{
"x": {"status": "encrypted", "range": [0, 16]},
},
{},
id="x | 6",
),
pytest.param(
lambda x: 11 | x,
{
"x": {"status": "encrypted", "range": [0, 16]},
},
{},
id="11 | x",
),
pytest.param(
lambda x: x ^ 9,
{
"x": {"status": "encrypted", "range": [0, 16]},
},
{},
id="x ^ 9",
),
pytest.param(
lambda x: 13 ^ x,
{
"x": {"status": "encrypted", "range": [0, 16]},
},
{},
id="13 ^ x",
),
pytest.param(
lambda x: x << 2,
{
"x": {"status": "encrypted", "range": [0, 16]},
},
{},
id="x << 2",
),
pytest.param(
lambda x: 2 << x,
{
"x": {"status": "encrypted", "range": [0, 5]},
},
{},
id="2 << x",
),
pytest.param(
lambda x: x >> 2,
{
"x": {"status": "encrypted", "range": [0, 120]},
},
{},
id="x >> 2",
),
pytest.param(
lambda x: 120 >> x,
{
"x": {"status": "encrypted", "range": [0, 16]},
},
{},
id="120 >> x",
),
pytest.param(
lambda x: x > 50,
{
"x": {"status": "encrypted", "range": [0, 100]},
},
{},
id="x > 50",
),
pytest.param(
lambda x: 50 > x,
{
"x": {"status": "encrypted", "range": [0, 100]},
},
{},
id="50 > x",
),
pytest.param(
lambda x: x < 50,
{
"x": {"status": "encrypted", "range": [0, 100]},
},
{},
id="x < 50",
),
pytest.param(
lambda x: 50 < x,
{
"x": {"status": "encrypted", "range": [0, 100]},
},
{},
id="50 < x",
),
pytest.param(
lambda x: x >= 50,
{
"x": {"status": "encrypted", "range": [0, 100]},
},
{},
id="x >= 50",
),
pytest.param(
lambda x: 50 >= x,
{
"x": {"status": "encrypted", "range": [0, 100]},
},
{},
id="50 >= x",
),
pytest.param(
lambda x: x <= 50,
{
"x": {"status": "encrypted", "range": [0, 100]},
},
{},
id="x <= 50",
),
pytest.param(
lambda x: 50 <= x,
{
"x": {"status": "encrypted", "range": [0, 100]},
},
{},
id="50 <= x",
),
pytest.param(
lambda x: x == 50,
{
"x": {"status": "encrypted", "range": [0, 100]},
},
{},
id="x == 50",
),
pytest.param(
lambda x: 50 == x,
{
"x": {"status": "encrypted", "range": [0, 100]},
},
{},
id="50 == x",
),
pytest.param(
lambda x: x != 50,
{
"x": {"status": "encrypted", "range": [0, 100]},
},
{},
id="x != 50",
),
pytest.param(
lambda x: 50 != x,
{
"x": {"status": "encrypted", "range": [0, 100]},
},
{},
id="50 != x",
),
pytest.param(
lambda x: x.clip(5, 10),
{
"x": {"status": "encrypted", "range": [0, 15]},
},
{},
id="x.clip(5, 10)",
),
pytest.param(
lambda x: (60 * np.sin(x)).astype(np.int64) + 60,
{
"x": {"status": "encrypted", "range": [0, 127]},
},
{},
id="(60 * np.sin(x)).astype(np.int64) + 60",
),
pytest.param(
lambda x: ((np.sin(x) ** 2) + (np.cos(x) ** 2)).round().astype(np.int64),
{
"x": {"status": "encrypted", "range": [0, 127]},
},
{},
id="((np.sin(x) ** 2) + (np.cos(x) ** 2)).round().astype(np.int64)",
),
pytest.param(
lambda x: np.maximum(x, [[10, 20], [30, 40], [50, 60]]),
{
"x": {"status": "encrypted", "range": [0, 127], "shape": (3, 2)},
},
{},
id="np.maximum(x, [[10, 20], [30, 40], [50, 60]])",
),
pytest.param(
fusable_with_bigger_search,
{
"x": {"status": "encrypted", "range": [5, 10]},
"y": {"status": "encrypted", "range": [5, 10]},
},
{},
id="fusable_with_bigger_search",
),
pytest.param(
fusable_with_bigger_search_needs_second_iteration,
{
"x": {"status": "encrypted", "range": [5, 10]},
"y": {"status": "encrypted", "range": [5, 10]},
},
{},
id="fusable_with_bigger_search_needs_second_iteration",
),
pytest.param(
fusable_with_one_of_the_start_nodes_is_lca_generator(),
{
"x": {"status": "encrypted", "range": [0, 4], "shape": (1, 10)},
},
{},
id="fusable_with_one_of_the_start_nodes_is_lca",
),
pytest.param(
fusable_with_hard_to_find_lca,
{
"x": {"status": "encrypted", "range": [0, 10]},
},
{},
id="fusable_with_hard_to_find_lca",
),
pytest.param(
fusable_with_hard_to_find_lca_used_twice,
{
"x": {"status": "encrypted", "range": [0, 4], "shape": (2, 2)},
},
{},
id="fusable_with_hard_to_find_lca_used_twice",
),
pytest.param(
fusable_additional_1,
{
"x": {"status": "encrypted", "range": [0, 10]},
},
{},
id="fusable_additional_1",
),
pytest.param(
fusable_additional_2,
{
"x": {"status": "encrypted", "range": [0, 10]},
},
{},
id="fusable_additional_2",
),
pytest.param(
lambda x: x + x.shape[0] + x.ndim + x.size + len(x),
{
"x": {"status": "encrypted", "range": [0, 15], "shape": (3, 2)},
},
{},
id="x + x.shape[0] + x.ndim + x.size + len(x)",
),
pytest.param(
lambda x: (50 * np.sin(x.transpose())).astype(np.int64),
{
"x": {"status": "encrypted", "range": [0, 15], "shape": (3, 2)},
},
{},
id="(50 * np.sin(x.transpose())).astype(np.int64)",
),
pytest.param(
lambda x: np.where(x < 5, x * 3, x),
{
"x": {"status": "encrypted", "range": [0, 10]},
},
{},
id="np.where(x < 5, x * 3, x)",
),
pytest.param(
lambda x: x + np.ones_like(x),
{
"x": {"status": "encrypted", "range": [0, 10]},
},
{},
id="x + np.ones_like(x)",
),
pytest.param(
lambda x: x + np.zeros_like(x),
{
"x": {"status": "encrypted", "range": [0, 10]},
},
{},
id="x + np.zeros_like(x)",
),
pytest.param(
lambda x: fhe.univariate(deterministic_unary_function)(x),
{
"x": {"status": "encrypted", "range": [0, 10]},
},
{},
id="fhe.univariate(deterministic_unary_function)(x)",
),
pytest.param(
lambda x: round(np.sqrt(x)),
{
"x": {"status": "encrypted", "range": [0, 100], "shape": ()},
},
{},
id="round(np.sqrt(x))",
),
pytest.param(
lambda x: np.sqrt(x).round().astype(np.int64),
{
"x": {"status": "encrypted", "range": [0, 100]},
},
{},
id="np.sqrt(x).round().astype(np.int64)",
),
pytest.param(
lambda x: (2.5 * round(np.sqrt(x), ndigits=4)).astype(np.int64),
{
"x": {"status": "encrypted", "range": [0, 100], "shape": ()},
},
{},
id="(2.5 * round(np.sqrt(x), decimals=4)).astype(np.int64)",
),
pytest.param(
lambda x, y: fhe.LookupTable(list(range(32)))[x + y],
{
"x": {"status": "encrypted", "range": [-10, 10]},
"y": {"status": "encrypted", "range": [-10, 10]},
},
{},
id="fhe.LookupTable(list(range(32)))[x + y]",
),
pytest.param(
lambda x: np.expand_dims(x, 0),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
},
{},
id="np.expand_dims(x, 0)",
),
pytest.param(
lambda x: np.expand_dims(x, axis=0),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
},
{},
id="np.expand_dims(x, axis=0)",
),
pytest.param(
lambda x: np.expand_dims(x, axis=1),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
},
{},
id="np.expand_dims(x, axis=1)",
),
pytest.param(
lambda x: np.expand_dims(x, axis=2),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
},
{},
id="np.expand_dims(x, axis=2)",
),
pytest.param(
lambda x: np.expand_dims(x, axis=(0, 1)),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
},
{},
id="np.expand_dims(x, axis=(0, 1))",
),
pytest.param(
lambda x: np.expand_dims(x, axis=(0, 2)),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
},
{},
id="np.expand_dims(x, axis=(0, 2))",
),
pytest.param(
lambda x: np.expand_dims(x, axis=(1, 2)),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
},
{},
id="np.expand_dims(x, axis=(1, 2))",
),
pytest.param(
lambda x: np.expand_dims(x, axis=(0, 1, 2)),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
},
{},
id="np.expand_dims(x, axis=(0, 1, 2))",
),
pytest.param(
lambda x: x**3,
{
"x": {"status": "encrypted", "range": [-30, 30]},
},
{},
id="x ** 3",
),
pytest.param(
lambda x: np.squeeze(x),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": ()},
},
{},
id="np.squeeze(x)",
),
pytest.param(
lambda x: np.squeeze(x),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (1, 2, 1, 3, 1, 4)},
},
{},
id="np.squeeze(x)",
),
pytest.param(
lambda x: np.squeeze(x, axis=2),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (1, 2, 1, 3, 1, 4)},
},
{},
id="np.squeeze(x, axis=2)",
),
pytest.param(
lambda x: np.squeeze(x, axis=(0, 4)),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (1, 2, 1, 3, 1, 4)},
},
{},
id="np.squeeze(x, axis=(0, 4))",
),
pytest.param(
lambda x: np.squeeze(x),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (1, 1, 1)},
},
{},
id="np.squeeze(x) where x.shape == (1, 1, 1)",
),
pytest.param(
lambda x: np.squeeze(x, axis=1),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (1, 1, 1)},
},
{},
id="np.squeeze(x, axis=1) where x.shape == (1, 1, 1)",
),
pytest.param(
lambda x: fhe.LookupTable([10, 5])[x > 5],
{
"x": {"status": "encrypted", "range": [0, 10]},
},
{},
id="fhe.LookupTable([10, 5])[x > 5]",
),
pytest.param(
copy_modify,
{
"x": {"status": "encrypted", "range": [0, 10], "shape": (3,)},
},
{},
id="copy_modify",
),
pytest.param(
lambda x: fhe.ones_like(x) + x,
{
"x": {"status": "encrypted", "range": [0, 4]},
},
{},
id="fhe.ones_like(x) + x",
),
pytest.param(
lambda x: fhe.zeros_like(x) + x,
{
"x": {"status": "encrypted", "range": [0, 4]},
},
{},
id="fhe.zeros_like(x) + x",
),
pytest.param(
lambda x: np.minimum(x, 0),
{
"x": {"status": "encrypted", "range": [-10, 10]},
},
{},
id="np.minimum(x, 0)",
),
pytest.param(
lambda x: np.maximum(x, 0),
{
"x": {"status": "encrypted", "range": [-10, 10]},
},
{},
id="np.maximum(x, 0)",
),
pytest.param(
lambda x: x + np.zeros_like(x),
{
"x": {"status": "encrypted", "range": [0, 10]},
},
{},
id="x + np.zeros_like(x)",
),
pytest.param(
lambda x: (x**2, x + 100),
{
"x": {"range": [12, 13], "status": "encrypted", "shape": ()},
},
{},
id="(x**2, x + 100) [x: [12, 13]] ",
),
pytest.param(
lambda x: (x**2, x + 100),
{
"x": {"range": [12, 13], "status": "encrypted", "shape": ()},
},
{
"optimize_tlu_based_on_measured_bounds": True,
},
id="(x**2, x + 100) [x: [12, 13]] {optimize_tlu_based_on_measured_bounds: True}",
),
pytest.param(
lambda x: fhe.univariate(lambda x: x // [2, 3])(x),
{
"x": {"range": [-12, -11], "status": "encrypted", "shape": (2,)},
},
{
"optimize_tlu_based_on_measured_bounds": True,
},
id=(
"fhe.univariate(lambda x: x // [2, 3])(x) [x: [-12, -11]] "
"{optimize_tlu_based_on_measured_bounds: True}"
),
),
pytest.param(
lambda x: fhe.univariate(lambda x: x // np.array([2, 3]))(x),
{
"x": {"range": [12, 15], "status": "encrypted", "shape": (2,)},
},
{
"optimize_tlu_based_on_measured_bounds": True,
},
id=(
"fhe.univariate(lambda x: x // np.array([2, 3]))(x) "
"[x: [12, 15]] "
"{optimize_tlu_based_on_measured_bounds: True}"
),
),
pytest.param(
lambda x: (fhe.hint(x, bit_width=5) ** 2, x + 100),
{
"x": {"range": [12, 13], "status": "encrypted", "shape": ()},
},
{
"optimize_tlu_based_on_measured_bounds": True,
},
id=(
"(fhe.hint(x, bit_width=5)**2, x + 100) "
"[x: [12, 15]] "
"{optimize_tlu_based_on_measured_bounds: True}"
),
),
pytest.param(
lambda x: fhe.univariate(lambda x: x // np.array([2, 3]))(fhe.hint(x, bit_width=5)),
{
"x": {"range": [12, 15], "status": "encrypted", "shape": (2,)},
},
{
"optimize_tlu_based_on_measured_bounds": True,
},
id=(
"fhe.univariate(lambda x: x // np.array([2, 3]))(fhe.hint(x, bit_width=5)) "
"[x: [12, 15]] "
"{optimize_tlu_based_on_measured_bounds: True}"
),
),
pytest.param(
lambda x: (x // 2, x + 100),
{
"x": {"range": [1, 63], "status": "encrypted", "shape": ()},
},
{
"optimize_tlu_based_on_measured_bounds": True,
},
id=("(x // 2, x + 100) [x: [1, 63]] {optimize_tlu_based_on_measured_bounds: True}"),
),
pytest.param(
lambda x: (x**2, x + 100),
{
"x": {"range": [-13, -12], "status": "encrypted", "shape": ()},
},
{},
id=("(x**2, x + 100) [x: [-13, -12]]"),
),
pytest.param(
lambda x: (x**2, x + 100),
{
"x": {"range": [-13, -12], "status": "encrypted", "shape": ()},
},
{
"optimize_tlu_based_on_measured_bounds": True,
},
id=(
"(x**2, x + 100) "
"[x: [-13, -12]] "
"{optimize_tlu_based_on_measured_bounds: True}"
),
),
pytest.param(
lambda x: (x // 2, x + 100),
{
"x": {"range": [-32, 31], "status": "encrypted", "shape": ()},
},
{
"optimize_tlu_based_on_measured_bounds": True,
},
id=(
"(x // 2, x + 100) "
"[x: [-32, 31]] "
"{optimize_tlu_based_on_measured_bounds: True}"
),
),
pytest.param(
issue650,
{
"x": {"range": [-2, 1], "status": "encrypted", "shape": (1, 10)},
},
{},
id="issue-650",
),
pytest.param(
lambda x: fhe.univariate(lambda x: (-3) * (1.0 - (x.astype(np.float64) * 0.0)))(
x
).astype(np.int64),
{
"x": {"range": [-64, 63], "status": "encrypted", "shape": (1,)},
},
{},
id="issue-651",
),
pytest.param(
lambda x: x + (x // 3),
{
"x": {"range": [0, 2**14 - 1], "status": "encrypted", "shape": ()},
},
{},
id="x + (x // 3)",
),
pytest.param(
lambda x: (x**3, x + 100),
{
"x": {"range": [-(2**3), 2**3 - 1], "status": "encrypted", "shape": ()},
},
{},
id="(x ** 3, x + 100)",
),
pytest.param(
lambda x: np.min(x, 0),
{
"x": {"range": [0, 10], "status": "encrypted", "shape": (2, 2)},
},
{},
id="np.min(x, 0)",
),
pytest.param(
lambda x: (x + 20, fhe.bits(x)[1]),
{
"x": {"status": "encrypted", "range": [0, 3]},
},
{},
id="x + 20, fhe.bits(x)[1]",
),
],
)
def test_others(function, parameters, configuration_overrides, helpers):
"""
Test others.
"""
# scalar
# ------
if "shape" not in parameters["x"] or parameters["x"]["shape"] == ():
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = fhe.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset, configuration)
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample, retries=3)
# tensor
# ------
if "shape" not in parameters["x"]:
parameters["x"]["shape"] = (3, 2)
if parameters["x"]["shape"] == ():
return
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration().fork(**configuration_overrides)
compiler = fhe.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset, configuration)
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample, retries=3)
def test_others_bad_fusing(helpers):
"""
Test others with bad fusing.
"""
configuration = helpers.configuration()
# two variable inputs
# -------------------
@fhe.compiler({"x": "encrypted", "y": "clear"})
def function1(x, y):
return (10 * (np.sin(x) ** 2) + 10 * (np.cos(y) ** 2)).astype(np.int64)
with pytest.raises(RuntimeError) as excinfo:
inputset = [(i, i) for i in range(100)]
function1.compile(inputset, configuration)
helpers.check_str(
# pylint: disable=line-too-long
"""
A subgraph within the function you are trying to compile cannot be fused because it has multiple input nodes
%0 = x # EncryptedScalar<uint1>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes
%1 = y # ClearScalar<uint1>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes
%2 = sin(%0) # EncryptedScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%3 = 2 # ClearScalar<uint2>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%4 = power(%2, %3) # EncryptedScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%5 = 10 # ClearScalar<uint4>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%6 = multiply(%5, %4) # EncryptedScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%7 = cos(%1) # ClearScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%8 = 2 # ClearScalar<uint2>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%9 = power(%7, %8) # ClearScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%10 = 10 # ClearScalar<uint4>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%11 = multiply(%10, %9) # ClearScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%12 = add(%6, %11) # EncryptedScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%13 = astype(%12, dtype=int_) # EncryptedScalar<uint1>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
return %13
""", # noqa: E501
# pylint: enable=line-too-long
str(excinfo.value),
)
# intermediates with different shape
# ----------------------------------
@fhe.compiler({"x": "encrypted"})
def function2(x):
return np.abs(np.sin(x)).reshape((2, 3)).astype(np.int64)
with pytest.raises(RuntimeError) as excinfo:
inputset = [np.random.randint(2**6, 2**7, size=(3, 2)) for _ in range(100)]
function2.compile(inputset, configuration)
helpers.check_str(
# pylint: disable=line-too-long
"""
A subgraph within the function you are trying to compile cannot be fused because of a node, which is has a different shape than the input node
%0 = x # EncryptedTensor<uint7, shape=(3, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ with this input node
%1 = sin(%0) # EncryptedTensor<float64, shape=(3, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%2 = absolute(%1) # EncryptedTensor<float64, shape=(3, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%3 = reshape(%2, newshape=(2, 3)) # EncryptedTensor<float64, shape=(2, 3)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%4 = astype(%3, dtype=int_) # EncryptedTensor<uint1, shape=(2, 3)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node has a different shape than the input node
return %4
""", # noqa: E501
# pylint: enable=line-too-long
str(excinfo.value),
)
# non-fusable operation
# ---------------------
@fhe.compiler({"x": "encrypted"})
def function3(x):
return np.abs(np.sin(x)).transpose().astype(np.int64)
with pytest.raises(RuntimeError) as excinfo:
inputset = [[[0, 1], [2, 3]]]
function3.compile(inputset, configuration)
helpers.check_str(
# pylint: disable=line-too-long
"""
A subgraph within the function you are trying to compile cannot be fused because of a node, which is marked explicitly as non-fusable
%0 = x # EncryptedTensor<uint2, shape=(2, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ with this input node
%1 = sin(%0) # EncryptedTensor<float64, shape=(2, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%2 = absolute(%1) # EncryptedTensor<float64, shape=(2, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
%3 = transpose(%2) # EncryptedTensor<float64, shape=(2, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is not fusable
%4 = astype(%3, dtype=int_) # EncryptedTensor<uint1, shape=(2, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
return %4
""", # noqa: E501
# pylint: enable=line-too-long
str(excinfo.value),
)
# integer two variable inputs
# ---------------------------
@fhe.compiler({"x": "encrypted", "y": "clear"})
def function4(x, y):
return np.maximum(x, y)
with pytest.raises(RuntimeError) as excinfo:
inputset = [(i, i) for i in range(100)]
function4.compile(inputset, configuration)
helpers.check_str(
# pylint: disable=line-too-long
"""
A subgraph within the function you are trying to compile cannot be fused because it has multiple input nodes
%0 = x # EncryptedScalar<uint1>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes
%1 = y # ClearScalar<uint1>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this is one of the input nodes
%2 = maximum(%0, %1) # EncryptedScalar<uint1>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ within this subgraph
return %2
""", # noqa: E501
# pylint: enable=line-too-long
str(excinfo.value),
)
def test_others_bad_univariate(helpers):
"""
Test univariate with bad function.
"""
configuration = helpers.configuration()
def bad_univariate(x):
return np.array([x, x, x])
@fhe.compiler({"x": "encrypted"})
def f(x):
return fhe.univariate(bad_univariate)(x)
with pytest.raises(ValueError) as excinfo:
inputset = range(10)
f.compile(inputset, configuration)
helpers.check_str(
"Function bad_univariate cannot be used with fhe.univariate",
str(excinfo.value),
)
def test_dynamic_indexing_hack(helpers):
"""
Test dynamic indexing using basic operators.
"""
@fhe.compiler({"array": "encrypted", "index": "encrypted"})
def function(array, index):
all_indices = np.arange(array.size)
index_selection = index == all_indices
selection_and_zeros = array * index_selection
selection = np.sum(selection_and_zeros)
return selection
inputset = [
(
np.random.randint(0, 16, size=(4,)),
np.random.randint(0, 4, size=()),
)
for _ in range(100)
]
circuit = function.compile(inputset, helpers.configuration())
sample = np.random.randint(0, 16, size=(4,))
helpers.check_execution(circuit, function, [sample, 0], retries=3)
helpers.check_execution(circuit, function, [sample, 1], retries=3)
helpers.check_execution(circuit, function, [sample, 2], retries=3)
helpers.check_execution(circuit, function, [sample, 3], retries=3)