mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(mlir): implement MLIR conversion for clear @ encrypted
This commit is contained in:
@@ -432,24 +432,20 @@ class IntermediateNodeConverter:
|
||||
assert_true(len(self.node.inputs) == 2)
|
||||
assert_true(len(self.node.outputs) == 1)
|
||||
|
||||
if self.all_of_the_inputs_are_encrypted or self.node.inputs[0].is_clear:
|
||||
if self.all_of_the_inputs_are_encrypted:
|
||||
lhs = self.node.inputs[0]
|
||||
rhs = self.node.inputs[1]
|
||||
|
||||
additional_error_info = (
|
||||
" (notice the encrypted value is in the right hand side which is not supported)"
|
||||
if self.node.inputs[0].is_clear
|
||||
else ""
|
||||
)
|
||||
raise NotImplementedError(
|
||||
f"Matrix multiplication between {lhs} and {rhs} cannot be converted to MLIR yet"
|
||||
f"{additional_error_info}",
|
||||
f"Matrix multiplication between {lhs} and {rhs} cannot be converted to MLIR yet",
|
||||
)
|
||||
|
||||
resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
|
||||
preds = self.preds
|
||||
|
||||
result = hlfhelinalg.MatMulEintIntOp(resulting_type, *preds).result
|
||||
if self.node.inputs[0].is_clear:
|
||||
result = hlfhelinalg.MatMulIntEintOp(resulting_type, *preds).result
|
||||
else:
|
||||
result = hlfhelinalg.MatMulEintIntOp(resulting_type, *preds).result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -66,18 +66,26 @@ from concrete.numpy import compile_numpy_function
|
||||
"cannot be converted to MLIR yet",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.ones(shape=(2, 3), dtype=numpy.uint32) @ x,
|
||||
{"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2))},
|
||||
[numpy.random.randint(0, 2 ** 3, size=(3, 2)) for i in range(10)]
|
||||
+ [numpy.array([[7, 7], [7, 7], [7, 7]])],
|
||||
lambda x, y: x @ y,
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
|
||||
"y": EncryptedTensor(UnsignedInteger(3), shape=(2, 1)),
|
||||
},
|
||||
[
|
||||
(
|
||||
numpy.random.randint(0, 2 ** 3, size=(3, 2)),
|
||||
numpy.random.randint(0, 2 ** 3, size=(2, 1)),
|
||||
)
|
||||
for i in range(10)
|
||||
]
|
||||
+ [(numpy.array([[7, 7], [7, 7], [7, 7]]), numpy.array([[7], [7]]))],
|
||||
NotImplementedError,
|
||||
"Matrix multiplication "
|
||||
"between "
|
||||
"ClearTensor<uint6, shape=(2, 3)> "
|
||||
"EncryptedTensor<uint7, shape=(3, 2)> "
|
||||
"and "
|
||||
"EncryptedTensor<uint5, shape=(3, 2)> "
|
||||
"cannot be converted to MLIR yet "
|
||||
"(notice the encrypted value is in the right hand side which is not supported)",
|
||||
"EncryptedTensor<uint7, shape=(2, 1)> "
|
||||
"cannot be converted to MLIR yet",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1343,37 +1343,65 @@ def test_compile_and_run_matmul_correctness(
|
||||
|
||||
low, high = input_range
|
||||
|
||||
inputset = [
|
||||
lhs_inputset = [
|
||||
numpy.zeros(lhs_shape, dtype=numpy.uint32),
|
||||
numpy.ones(lhs_shape, dtype=numpy.uint32) * high,
|
||||
]
|
||||
rhs_inputset = [
|
||||
numpy.zeros(rhs_shape, dtype=numpy.uint32),
|
||||
numpy.ones(rhs_shape, dtype=numpy.uint32) * high,
|
||||
]
|
||||
for _ in range(8):
|
||||
inputset.append(numpy.random.randint(low, high + 1, size=lhs_shape))
|
||||
lhs_inputset.append(numpy.random.randint(low, high + 1, size=lhs_shape))
|
||||
rhs_inputset.append(numpy.random.randint(low, high + 1, size=rhs_shape))
|
||||
|
||||
constant = numpy.random.randint(low, high + 1, size=rhs_shape)
|
||||
left_constant = numpy.random.randint(low, high + 1, size=lhs_shape)
|
||||
right_constant = numpy.random.randint(low, high + 1, size=rhs_shape)
|
||||
|
||||
def using_operator(x):
|
||||
return x @ constant
|
||||
def using_operator_left(x):
|
||||
return x @ right_constant
|
||||
|
||||
def using_function(x):
|
||||
return numpy.matmul(x, constant)
|
||||
def using_function_left(x):
|
||||
return numpy.matmul(x, right_constant)
|
||||
|
||||
operator_circuit = compile_numpy_function(
|
||||
using_operator,
|
||||
def using_operator_right(x):
|
||||
return left_constant @ x
|
||||
|
||||
def using_function_right(x):
|
||||
return numpy.matmul(left_constant, x)
|
||||
|
||||
operator_left_circuit = compile_numpy_function(
|
||||
using_operator_left,
|
||||
{"x": EncryptedTensor(UnsignedInteger(3), lhs_shape)},
|
||||
inputset,
|
||||
lhs_inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
function_circuit = compile_numpy_function(
|
||||
using_function,
|
||||
function_left_circuit = compile_numpy_function(
|
||||
using_function_left,
|
||||
{"x": EncryptedTensor(UnsignedInteger(3), lhs_shape)},
|
||||
inputset,
|
||||
lhs_inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
operator_right_circuit = compile_numpy_function(
|
||||
using_operator_right,
|
||||
{"x": EncryptedTensor(UnsignedInteger(3), rhs_shape)},
|
||||
rhs_inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
function_right_circuit = compile_numpy_function(
|
||||
using_function_right,
|
||||
{"x": EncryptedTensor(UnsignedInteger(3), rhs_shape)},
|
||||
rhs_inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
args = (numpy.random.randint(low, high + 1, size=lhs_shape, dtype=numpy.uint8),)
|
||||
check_array_equality(operator_circuit.run(*args), using_operator(*args))
|
||||
check_array_equality(function_circuit.run(*args), using_function(*args))
|
||||
lhs_arg = numpy.random.randint(low, high + 1, size=lhs_shape, dtype=numpy.uint8)
|
||||
check_array_equality(operator_left_circuit.run(lhs_arg), using_operator_left(lhs_arg))
|
||||
check_array_equality(function_left_circuit.run(lhs_arg), using_function_left(lhs_arg))
|
||||
|
||||
rhs_arg = numpy.random.randint(low, high + 1, size=rhs_shape, dtype=numpy.uint8)
|
||||
check_array_equality(operator_right_circuit.run(rhs_arg), using_operator_right(rhs_arg))
|
||||
check_array_equality(function_right_circuit.run(rhs_arg), using_function_right(rhs_arg))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user