fix: allow signed matmuls

closes #957
This commit is contained in:
Arthur Meyre
2021-12-17 10:41:00 +01:00
parent 6697b64147
commit 31a3bdcf2e
2 changed files with 59 additions and 26 deletions

View File

@@ -102,8 +102,6 @@ def check_node_compatibility_with_mlir(
elif isinstance(node, intermediate.MatMul): # constraints for matrix multiplication
assert_true(len(inputs) == 2)
if not value_is_unsigned_integer(inputs[0]) or not value_is_unsigned_integer(inputs[1]):
return "only unsigned integer matrix multiplication is supported"
else: # pragma: no cover
assert_not_reached("Non IntermediateNode object in the OPGraph")

View File

@@ -14,6 +14,7 @@ from concrete.common.extensions.table import LookupTable
from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor
from concrete.numpy import tracing
from concrete.numpy.compile import (
FHECircuit,
compile_numpy_function,
compile_numpy_function_into_op_graph_and_measure_bounds,
)
@@ -1353,6 +1354,11 @@ def test_compile_and_run_constant_dot_correctness(
(5, 5),
(0, 4),
),
pytest.param(
(3, 2),
(2, 3),
(-4, 3),
),
],
)
def test_compile_and_run_matmul_correctness(
@@ -1362,10 +1368,16 @@ def test_compile_and_run_matmul_correctness(
low, high = input_range
check_mod = low < 0 or high < 0
max_abs = max(abs(low), abs(high))
# Inputset for x as lhs of matmul
lhs_inputset = [
numpy.zeros(lhs_shape, dtype=numpy.uint32),
numpy.ones(lhs_shape, dtype=numpy.uint32) * high,
]
# Inputset for x as rhs of matmul
rhs_inputset = [
numpy.zeros(rhs_shape, dtype=numpy.uint32),
numpy.ones(rhs_shape, dtype=numpy.uint32) * high,
@@ -1377,6 +1389,32 @@ def test_compile_and_run_matmul_correctness(
left_constant = numpy.random.randint(low, high + 1, size=lhs_shape)
right_constant = numpy.random.randint(low, high + 1, size=rhs_shape)
# Generate worst case inputsets for bit widths, replacing negative values by 0 and putting
# the max value elsewhere, and then doing the same for positive values
rhs_inputset.extend(
[
numpy.where(right_constant < 0, 0, max_abs),
numpy.where(right_constant > 0, 0, max_abs),
]
)
lhs_inputset.extend(
[
numpy.where(left_constant < 0, 0, max_abs),
numpy.where(left_constant > 0, 0, max_abs),
]
)
# Keep inputset positive
rhs_inputset = [numpy.clip(val, 0, high) for val in rhs_inputset]
lhs_inputset = [numpy.clip(val, 0, high) for val in lhs_inputset]
def get_output_mod(circuit: FHECircuit):
assert len(circuit.op_graph.output_nodes) == 1
assert isinstance(
output_dtype := circuit.op_graph.get_ordered_outputs()[0].outputs[0].dtype, Integer
)
return 2 ** output_dtype.bit_width
def using_operator_left(x):
return x @ right_constant
@@ -1414,13 +1452,28 @@ def test_compile_and_run_matmul_correctness(
default_compilation_configuration,
)
lhs_arg = numpy.random.randint(low, high + 1, size=lhs_shape, dtype=numpy.uint8)
check_array_equality(operator_left_circuit.run(lhs_arg), using_operator_left(lhs_arg))
check_array_equality(function_left_circuit.run(lhs_arg), using_function_left(lhs_arg))
def check_result(circuit: FHECircuit, func, arg):
# Stay positive for input to FHE circuit
arg = numpy.clip(arg, 0, high).astype(numpy.uint8)
rhs_arg = numpy.random.randint(low, high + 1, size=rhs_shape, dtype=numpy.uint8)
check_array_equality(operator_right_circuit.run(rhs_arg), using_operator_right(rhs_arg))
check_array_equality(function_right_circuit.run(rhs_arg), using_function_right(rhs_arg))
circuit_output = circuit.run(arg)
func_output = func(arg)
if check_mod:
output_mod = get_output_mod(circuit)
circuit_output %= output_mod
func_output %= output_mod
check_array_equality(circuit_output, func_output)
arg = numpy.random.randint(low, high + 1, size=lhs_shape)
check_result(operator_left_circuit, using_operator_left, arg)
check_result(function_left_circuit, using_function_left, arg)
arg = numpy.random.randint(low, high + 1, size=rhs_shape)
check_result(operator_right_circuit, using_operator_right, arg)
check_result(function_right_circuit, using_function_right, arg)
@pytest.mark.parametrize(
@@ -1643,24 +1696,6 @@ return %9
""".strip() # noqa: E501
),
),
pytest.param(
lambda x: x @ -numpy.ones(shape=(2, 3), dtype=numpy.int32),
{"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2))},
[numpy.random.randint(0, 2 ** 3, size=(3, 2)) for i in range(10)],
(
"""
function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
%1 = [[-1 -1 -1] [-1 -1 -1]] # ClearTensor<int2, shape=(2, 3)>
%2 = matmul(%0, %1) # EncryptedTensor<int5, shape=(3, 3)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer matrix multiplication is supported
return %2
""".strip() # noqa: E501
),
),
pytest.param(
lambda x: numpy.transpose(x),
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},