mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
refactor(frontend-python): matmul and dot tests to be consistent with the rest of the tests
This commit is contained in:
@@ -12,9 +12,9 @@ from concrete import fhe
|
||||
"size",
|
||||
[1, 4, 6, 10],
|
||||
)
|
||||
def test_dot(size, helpers):
|
||||
def test_constant_dot(size, helpers):
|
||||
"""
|
||||
Test dot.
|
||||
Test dot where one of the operators is a constant.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
@@ -23,7 +23,7 @@ def test_dot(size, helpers):
|
||||
cst = np.random.randint(0, bound, size=(size,))
|
||||
|
||||
@fhe.compiler({"x": "encrypted"})
|
||||
def dot_enc_enc_function(x):
|
||||
def left_function(x):
|
||||
return np.dot(x, cst)
|
||||
|
||||
@fhe.compiler({"x": "encrypted"})
|
||||
@@ -36,55 +36,50 @@ def test_dot(size, helpers):
|
||||
|
||||
inputset = [np.random.randint(0, bound, size=(size,)) for i in range(100)]
|
||||
|
||||
dot_enc_enc_function_circuit = dot_enc_enc_function.compile(inputset, configuration)
|
||||
left_function_circuit = left_function.compile(inputset, configuration)
|
||||
right_function_circuit = right_function.compile(inputset, configuration)
|
||||
method_circuit = method.compile(inputset, configuration)
|
||||
|
||||
sample = np.random.randint(0, bound, size=(size,))
|
||||
|
||||
helpers.check_execution(dot_enc_enc_function_circuit, dot_enc_enc_function, sample)
|
||||
helpers.check_execution(left_function_circuit, left_function, sample)
|
||||
helpers.check_execution(right_function_circuit, right_function, sample)
|
||||
helpers.check_execution(method_circuit, method, sample)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"size",
|
||||
[1, 10],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"bitwidth",
|
||||
[2, 6],
|
||||
)
|
||||
@pytest.mark.parametrize("size", [1, 10])
|
||||
@pytest.mark.parametrize("bit_width", [2, 6])
|
||||
@pytest.mark.parametrize("signed", [True, False])
|
||||
@pytest.mark.parametrize("negative_only", [True, False])
|
||||
def test_dot_enc_enc(size, bitwidth, negative_only, signed, helpers):
|
||||
@pytest.mark.parametrize("only_negative", [True, False])
|
||||
def test_dot(size, bit_width, only_negative, signed, helpers):
|
||||
"""
|
||||
Test dot.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
minv = 0 if not signed else -(2 ** (bitwidth - 1))
|
||||
minimum = 0 if not signed else -(2 ** (bit_width - 1))
|
||||
maximum = 2**bit_width if not signed else 2 ** (bit_width - 1)
|
||||
|
||||
# +1 since randint max is not inclusive
|
||||
maxv = 2**bitwidth if not signed else 2 ** (bitwidth - 1)
|
||||
if negative_only:
|
||||
maxv = 1
|
||||
if only_negative:
|
||||
maximum = 1
|
||||
|
||||
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
|
||||
def dot_enc_enc_function(x, y):
|
||||
def function(x, y):
|
||||
return np.dot(x, y)
|
||||
|
||||
inputset = [
|
||||
(np.random.randint(minv, maxv, size=(size,)), np.random.randint(minv, maxv, size=(size,)))
|
||||
(
|
||||
np.random.randint(minimum, maximum, size=(size,)),
|
||||
np.random.randint(minimum, maximum, size=(size,)),
|
||||
)
|
||||
for i in range(100)
|
||||
]
|
||||
|
||||
dot_enc_enc_function_circuit = dot_enc_enc_function.compile(inputset, configuration)
|
||||
circuit = function.compile(inputset, configuration)
|
||||
|
||||
sample = [
|
||||
np.random.randint(minv, maxv, size=(size,)),
|
||||
np.random.randint(minv, maxv, size=(size,)),
|
||||
np.random.randint(minimum, maximum, size=(size,)),
|
||||
np.random.randint(minimum, maximum, size=(size,)),
|
||||
]
|
||||
|
||||
helpers.check_execution(dot_enc_enc_function_circuit, dot_enc_enc_function, sample, retries=3)
|
||||
helpers.check_execution(circuit, function, sample, retries=3)
|
||||
|
||||
@@ -123,9 +123,9 @@ from concrete import fhe
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_matmul(lhs_shape, rhs_shape, bounds, helpers):
|
||||
def test_constant_matmul(lhs_shape, rhs_shape, bounds, helpers):
|
||||
"""
|
||||
Test matmul.
|
||||
Test matmul where one of the operators is a constant.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
@@ -288,7 +288,7 @@ def test_matmul(lhs_shape, rhs_shape, bounds, helpers):
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_matmul_enc_enc_and_clear(lhs_shape, rhs_shape, bounds, helpers):
|
||||
def test_matmul(lhs_shape, rhs_shape, bounds, helpers):
|
||||
"""
|
||||
Test matmul.
|
||||
"""
|
||||
@@ -297,69 +297,63 @@ def test_matmul_enc_enc_and_clear(lhs_shape, rhs_shape, bounds, helpers):
|
||||
|
||||
minimum, maximum = bounds
|
||||
|
||||
# Matmul of clear values and encrypted matrices
|
||||
@fhe.compiler({"x": "encrypted", "y": "clear"})
|
||||
def lhs_operator_clear(x, y):
|
||||
def clear(x, y):
|
||||
return x @ y
|
||||
|
||||
# Matmul of two encrypted matrices
|
||||
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
|
||||
def enc_function_xy(x, y):
|
||||
def encrypted(x, y):
|
||||
return np.matmul(x, y)
|
||||
|
||||
# Put all the dual operand functions in a list
|
||||
|
||||
# FIXME: add lhs_operator_clear to this list to
|
||||
# re-enable the testing with clear values
|
||||
dual_operand_functions = [enc_function_xy]
|
||||
|
||||
# Compile each dual operand function and test it on random data
|
||||
for func in dual_operand_functions:
|
||||
dual_op_inputset = [
|
||||
for implementation in [clear, encrypted]:
|
||||
inputset = [
|
||||
(
|
||||
np.random.randint(minimum, maximum, size=lhs_shape),
|
||||
np.random.randint(minimum, maximum, size=rhs_shape),
|
||||
)
|
||||
for i in range(100)
|
||||
]
|
||||
dual_op_circuit = func.compile(dual_op_inputset, configuration)
|
||||
circuit = implementation.compile(inputset, configuration)
|
||||
|
||||
lhs_sample, rhs_sample = np.random.randint(
|
||||
minimum, maximum, size=lhs_shape
|
||||
), np.random.randint(minimum, maximum, size=rhs_shape)
|
||||
sample = [
|
||||
np.random.randint(minimum, maximum, size=lhs_shape),
|
||||
np.random.randint(minimum, maximum, size=rhs_shape),
|
||||
]
|
||||
|
||||
helpers.check_execution(dual_op_circuit, func, [lhs_sample, rhs_sample], retries=3)
|
||||
helpers.check_execution(circuit, implementation, sample, retries=3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bitwidth", [4, 10])
|
||||
@pytest.mark.parametrize("bit_width", [4, 10])
|
||||
@pytest.mark.parametrize("signed", [True, False])
|
||||
def test_matmul_zero(bitwidth, signed, helpers):
|
||||
def test_zero_matmul(bit_width, signed, helpers):
|
||||
"""
|
||||
Test matmul.
|
||||
Test matmul where one of the operators is all zeros.
|
||||
"""
|
||||
|
||||
lhs_shape = (2, 1)
|
||||
rhs_shape = (1, 2)
|
||||
range_lhs = (-(2 ** (bitwidth - 1)), 2 ** (bitwidth - 1) - 1) if signed else (0, 2**bitwidth)
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
# Matmul of two encrypted matrices
|
||||
lhs_shape = (2, 1)
|
||||
rhs_shape = (1, 2)
|
||||
|
||||
bounds = (-(2 ** (bit_width - 1)), 2 ** (bit_width - 1) - 1) if signed else (0, 2**bit_width)
|
||||
minimum, maximum = bounds
|
||||
|
||||
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
|
||||
def enc_function_xy(x, y):
|
||||
def function(x, y):
|
||||
return x * y
|
||||
|
||||
dual_op_inputset = [
|
||||
inputset = [
|
||||
(
|
||||
np.random.randint(range_lhs[0], range_lhs[1], size=lhs_shape),
|
||||
np.random.randint(minimum, maximum, size=lhs_shape),
|
||||
np.zeros(rhs_shape, dtype=np.int64),
|
||||
)
|
||||
for i in range(100)
|
||||
]
|
||||
dual_op_circuit = enc_function_xy.compile(dual_op_inputset, configuration)
|
||||
circuit = function.compile(inputset, configuration)
|
||||
|
||||
lhs_sample, rhs_sample = np.random.randint(
|
||||
range_lhs[0], range_lhs[1], size=lhs_shape
|
||||
), np.zeros(rhs_shape, dtype=np.int64)
|
||||
sample = [
|
||||
np.random.randint(minimum, maximum, size=lhs_shape),
|
||||
np.zeros(rhs_shape, dtype=np.int64),
|
||||
]
|
||||
|
||||
helpers.check_execution(dual_op_circuit, enc_function_xy, [lhs_sample, rhs_sample], retries=3)
|
||||
helpers.check_execution(circuit, function, sample, retries=3)
|
||||
|
||||
Reference in New Issue
Block a user