refactor(frontend-python): matmul and dot tests to be consistent with the rest of the tests

This commit is contained in:
Umut
2023-05-17 14:24:29 +02:00
parent 9e73a20d1e
commit fdfd4760f1
2 changed files with 53 additions and 64 deletions

View File

@@ -12,9 +12,9 @@ from concrete import fhe
"size",
[1, 4, 6, 10],
)
def test_dot(size, helpers):
def test_constant_dot(size, helpers):
"""
Test dot.
Test dot where one of the operators is a constant.
"""
configuration = helpers.configuration()
@@ -23,7 +23,7 @@ def test_dot(size, helpers):
cst = np.random.randint(0, bound, size=(size,))
@fhe.compiler({"x": "encrypted"})
def dot_enc_enc_function(x):
def left_function(x):
return np.dot(x, cst)
@fhe.compiler({"x": "encrypted"})
@@ -36,55 +36,50 @@ def test_dot(size, helpers):
inputset = [np.random.randint(0, bound, size=(size,)) for i in range(100)]
dot_enc_enc_function_circuit = dot_enc_enc_function.compile(inputset, configuration)
left_function_circuit = left_function.compile(inputset, configuration)
right_function_circuit = right_function.compile(inputset, configuration)
method_circuit = method.compile(inputset, configuration)
sample = np.random.randint(0, bound, size=(size,))
helpers.check_execution(dot_enc_enc_function_circuit, dot_enc_enc_function, sample)
helpers.check_execution(left_function_circuit, left_function, sample)
helpers.check_execution(right_function_circuit, right_function, sample)
helpers.check_execution(method_circuit, method, sample)
@pytest.mark.parametrize(
"size",
[1, 10],
)
@pytest.mark.parametrize(
"bitwidth",
[2, 6],
)
@pytest.mark.parametrize("size", [1, 10])
@pytest.mark.parametrize("bit_width", [2, 6])
@pytest.mark.parametrize("signed", [True, False])
@pytest.mark.parametrize("negative_only", [True, False])
def test_dot_enc_enc(size, bitwidth, negative_only, signed, helpers):
@pytest.mark.parametrize("only_negative", [True, False])
def test_dot(size, bit_width, only_negative, signed, helpers):
"""
Test dot.
"""
configuration = helpers.configuration()
minv = 0 if not signed else -(2 ** (bitwidth - 1))
minimum = 0 if not signed else -(2 ** (bit_width - 1))
maximum = 2**bit_width if not signed else 2 ** (bit_width - 1)
# +1 since randint max is not inclusive
maxv = 2**bitwidth if not signed else 2 ** (bitwidth - 1)
if negative_only:
maxv = 1
if only_negative:
maximum = 1
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def dot_enc_enc_function(x, y):
def function(x, y):
return np.dot(x, y)
inputset = [
(np.random.randint(minv, maxv, size=(size,)), np.random.randint(minv, maxv, size=(size,)))
(
np.random.randint(minimum, maximum, size=(size,)),
np.random.randint(minimum, maximum, size=(size,)),
)
for i in range(100)
]
dot_enc_enc_function_circuit = dot_enc_enc_function.compile(inputset, configuration)
circuit = function.compile(inputset, configuration)
sample = [
np.random.randint(minv, maxv, size=(size,)),
np.random.randint(minv, maxv, size=(size,)),
np.random.randint(minimum, maximum, size=(size,)),
np.random.randint(minimum, maximum, size=(size,)),
]
helpers.check_execution(dot_enc_enc_function_circuit, dot_enc_enc_function, sample, retries=3)
helpers.check_execution(circuit, function, sample, retries=3)

View File

@@ -123,9 +123,9 @@ from concrete import fhe
),
],
)
def test_matmul(lhs_shape, rhs_shape, bounds, helpers):
def test_constant_matmul(lhs_shape, rhs_shape, bounds, helpers):
"""
Test matmul.
Test matmul where one of the operators is a constant.
"""
configuration = helpers.configuration()
@@ -288,7 +288,7 @@ def test_matmul(lhs_shape, rhs_shape, bounds, helpers):
),
],
)
def test_matmul_enc_enc_and_clear(lhs_shape, rhs_shape, bounds, helpers):
def test_matmul(lhs_shape, rhs_shape, bounds, helpers):
"""
Test matmul.
"""
@@ -297,69 +297,63 @@ def test_matmul_enc_enc_and_clear(lhs_shape, rhs_shape, bounds, helpers):
minimum, maximum = bounds
# Matmul of clear values and encrypted matrices
@fhe.compiler({"x": "encrypted", "y": "clear"})
def lhs_operator_clear(x, y):
def clear(x, y):
return x @ y
# Matmul of two encrypted matrices
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def enc_function_xy(x, y):
def encrypted(x, y):
return np.matmul(x, y)
# Put all the dual operand functions in a list
# FIXME: add lhs_operator_clear to this list to
# re-enable the testing with clear values
dual_operand_functions = [enc_function_xy]
# Compile each dual operand function and test it on random data
for func in dual_operand_functions:
dual_op_inputset = [
for implementation in [clear, encrypted]:
inputset = [
(
np.random.randint(minimum, maximum, size=lhs_shape),
np.random.randint(minimum, maximum, size=rhs_shape),
)
for i in range(100)
]
dual_op_circuit = func.compile(dual_op_inputset, configuration)
circuit = implementation.compile(inputset, configuration)
lhs_sample, rhs_sample = np.random.randint(
minimum, maximum, size=lhs_shape
), np.random.randint(minimum, maximum, size=rhs_shape)
sample = [
np.random.randint(minimum, maximum, size=lhs_shape),
np.random.randint(minimum, maximum, size=rhs_shape),
]
helpers.check_execution(dual_op_circuit, func, [lhs_sample, rhs_sample], retries=3)
helpers.check_execution(circuit, implementation, sample, retries=3)
@pytest.mark.parametrize("bitwidth", [4, 10])
@pytest.mark.parametrize("bit_width", [4, 10])
@pytest.mark.parametrize("signed", [True, False])
def test_matmul_zero(bitwidth, signed, helpers):
def test_zero_matmul(bit_width, signed, helpers):
"""
Test matmul.
Test matmul where one of the operators is all zeros.
"""
lhs_shape = (2, 1)
rhs_shape = (1, 2)
range_lhs = (-(2 ** (bitwidth - 1)), 2 ** (bitwidth - 1) - 1) if signed else (0, 2**bitwidth)
configuration = helpers.configuration()
# Matmul of two encrypted matrices
lhs_shape = (2, 1)
rhs_shape = (1, 2)
bounds = (-(2 ** (bit_width - 1)), 2 ** (bit_width - 1) - 1) if signed else (0, 2**bit_width)
minimum, maximum = bounds
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def enc_function_xy(x, y):
def function(x, y):
return x * y
dual_op_inputset = [
inputset = [
(
np.random.randint(range_lhs[0], range_lhs[1], size=lhs_shape),
np.random.randint(minimum, maximum, size=lhs_shape),
np.zeros(rhs_shape, dtype=np.int64),
)
for i in range(100)
]
dual_op_circuit = enc_function_xy.compile(dual_op_inputset, configuration)
circuit = function.compile(inputset, configuration)
lhs_sample, rhs_sample = np.random.randint(
range_lhs[0], range_lhs[1], size=lhs_shape
), np.zeros(rhs_shape, dtype=np.int64)
sample = [
np.random.randint(minimum, maximum, size=lhs_shape),
np.zeros(rhs_shape, dtype=np.int64),
]
helpers.check_execution(dual_op_circuit, enc_function_xy, [lhs_sample, rhs_sample], retries=3)
helpers.check_execution(circuit, function, sample, retries=3)