mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
@@ -102,8 +102,6 @@ def check_node_compatibility_with_mlir(
|
||||
|
||||
elif isinstance(node, intermediate.MatMul): # constraints for matrix multiplication
|
||||
assert_true(len(inputs) == 2)
|
||||
if not value_is_unsigned_integer(inputs[0]) or not value_is_unsigned_integer(inputs[1]):
|
||||
return "only unsigned integer matrix multiplication is supported"
|
||||
|
||||
else: # pragma: no cover
|
||||
assert_not_reached("Non IntermediateNode object in the OPGraph")
|
||||
|
||||
@@ -14,6 +14,7 @@ from concrete.common.extensions.table import LookupTable
|
||||
from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import tracing
|
||||
from concrete.numpy.compile import (
|
||||
FHECircuit,
|
||||
compile_numpy_function,
|
||||
compile_numpy_function_into_op_graph_and_measure_bounds,
|
||||
)
|
||||
@@ -1353,6 +1354,11 @@ def test_compile_and_run_constant_dot_correctness(
|
||||
(5, 5),
|
||||
(0, 4),
|
||||
),
|
||||
pytest.param(
|
||||
(3, 2),
|
||||
(2, 3),
|
||||
(-4, 3),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_matmul_correctness(
|
||||
@@ -1362,10 +1368,16 @@ def test_compile_and_run_matmul_correctness(
|
||||
|
||||
low, high = input_range
|
||||
|
||||
check_mod = low < 0 or high < 0
|
||||
|
||||
max_abs = max(abs(low), abs(high))
|
||||
|
||||
# Inputset for x as lhs of matmul
|
||||
lhs_inputset = [
|
||||
numpy.zeros(lhs_shape, dtype=numpy.uint32),
|
||||
numpy.ones(lhs_shape, dtype=numpy.uint32) * high,
|
||||
]
|
||||
# Inputset for x as rhs of matmul
|
||||
rhs_inputset = [
|
||||
numpy.zeros(rhs_shape, dtype=numpy.uint32),
|
||||
numpy.ones(rhs_shape, dtype=numpy.uint32) * high,
|
||||
@@ -1377,6 +1389,32 @@ def test_compile_and_run_matmul_correctness(
|
||||
left_constant = numpy.random.randint(low, high + 1, size=lhs_shape)
|
||||
right_constant = numpy.random.randint(low, high + 1, size=rhs_shape)
|
||||
|
||||
# Generate worst case inputsets for bit widths, replacing negative values by 0 and putting
|
||||
# the max value elsewhere, and then doing the same for positive values
|
||||
rhs_inputset.extend(
|
||||
[
|
||||
numpy.where(right_constant < 0, 0, max_abs),
|
||||
numpy.where(right_constant > 0, 0, max_abs),
|
||||
]
|
||||
)
|
||||
lhs_inputset.extend(
|
||||
[
|
||||
numpy.where(left_constant < 0, 0, max_abs),
|
||||
numpy.where(left_constant > 0, 0, max_abs),
|
||||
]
|
||||
)
|
||||
|
||||
# Keep inputset positive
|
||||
rhs_inputset = [numpy.clip(val, 0, high) for val in rhs_inputset]
|
||||
lhs_inputset = [numpy.clip(val, 0, high) for val in lhs_inputset]
|
||||
|
||||
def get_output_mod(circuit: FHECircuit):
|
||||
assert len(circuit.op_graph.output_nodes) == 1
|
||||
assert isinstance(
|
||||
output_dtype := circuit.op_graph.get_ordered_outputs()[0].outputs[0].dtype, Integer
|
||||
)
|
||||
return 2 ** output_dtype.bit_width
|
||||
|
||||
def using_operator_left(x):
|
||||
return x @ right_constant
|
||||
|
||||
@@ -1414,13 +1452,28 @@ def test_compile_and_run_matmul_correctness(
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
lhs_arg = numpy.random.randint(low, high + 1, size=lhs_shape, dtype=numpy.uint8)
|
||||
check_array_equality(operator_left_circuit.run(lhs_arg), using_operator_left(lhs_arg))
|
||||
check_array_equality(function_left_circuit.run(lhs_arg), using_function_left(lhs_arg))
|
||||
def check_result(circuit: FHECircuit, func, arg):
|
||||
# Stay positive for input to FHE circuit
|
||||
arg = numpy.clip(arg, 0, high).astype(numpy.uint8)
|
||||
|
||||
rhs_arg = numpy.random.randint(low, high + 1, size=rhs_shape, dtype=numpy.uint8)
|
||||
check_array_equality(operator_right_circuit.run(rhs_arg), using_operator_right(rhs_arg))
|
||||
check_array_equality(function_right_circuit.run(rhs_arg), using_function_right(rhs_arg))
|
||||
circuit_output = circuit.run(arg)
|
||||
func_output = func(arg)
|
||||
|
||||
if check_mod:
|
||||
output_mod = get_output_mod(circuit)
|
||||
|
||||
circuit_output %= output_mod
|
||||
func_output %= output_mod
|
||||
|
||||
check_array_equality(circuit_output, func_output)
|
||||
|
||||
arg = numpy.random.randint(low, high + 1, size=lhs_shape)
|
||||
check_result(operator_left_circuit, using_operator_left, arg)
|
||||
check_result(function_left_circuit, using_function_left, arg)
|
||||
|
||||
arg = numpy.random.randint(low, high + 1, size=rhs_shape)
|
||||
check_result(operator_right_circuit, using_operator_right, arg)
|
||||
check_result(function_right_circuit, using_function_right, arg)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -1643,24 +1696,6 @@ return %9
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x @ -numpy.ones(shape=(2, 3), dtype=numpy.int32),
|
||||
{"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2))},
|
||||
[numpy.random.randint(0, 2 ** 3, size=(3, 2)) for i in range(10)],
|
||||
(
|
||||
"""
|
||||
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
|
||||
%1 = [[-1 -1 -1] [-1 -1 -1]] # ClearTensor<int2, shape=(2, 3)>
|
||||
%2 = matmul(%0, %1) # EncryptedTensor<int5, shape=(3, 3)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer matrix multiplication is supported
|
||||
return %2
|
||||
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.transpose(x),
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
|
||||
|
||||
Reference in New Issue
Block a user