feat(mlir): implement MLIR conversion for clear @ encrypted

This commit is contained in:
Umut
2021-11-25 12:56:11 +03:00
parent 868f9d9d6b
commit cad7d67e22
3 changed files with 66 additions and 34 deletions

View File

@@ -1343,37 +1343,65 @@ def test_compile_and_run_matmul_correctness(
low, high = input_range
inputset = [
lhs_inputset = [
numpy.zeros(lhs_shape, dtype=numpy.uint32),
numpy.ones(lhs_shape, dtype=numpy.uint32) * high,
]
rhs_inputset = [
numpy.zeros(rhs_shape, dtype=numpy.uint32),
numpy.ones(rhs_shape, dtype=numpy.uint32) * high,
]
for _ in range(8):
inputset.append(numpy.random.randint(low, high + 1, size=lhs_shape))
lhs_inputset.append(numpy.random.randint(low, high + 1, size=lhs_shape))
rhs_inputset.append(numpy.random.randint(low, high + 1, size=rhs_shape))
constant = numpy.random.randint(low, high + 1, size=rhs_shape)
left_constant = numpy.random.randint(low, high + 1, size=lhs_shape)
right_constant = numpy.random.randint(low, high + 1, size=rhs_shape)
def using_operator(x):
return x @ constant
def using_operator_left(x):
return x @ right_constant
def using_function(x):
return numpy.matmul(x, constant)
def using_function_left(x):
return numpy.matmul(x, right_constant)
operator_circuit = compile_numpy_function(
using_operator,
def using_operator_right(x):
return left_constant @ x
def using_function_right(x):
return numpy.matmul(left_constant, x)
operator_left_circuit = compile_numpy_function(
using_operator_left,
{"x": EncryptedTensor(UnsignedInteger(3), lhs_shape)},
inputset,
lhs_inputset,
default_compilation_configuration,
)
function_circuit = compile_numpy_function(
using_function,
function_left_circuit = compile_numpy_function(
using_function_left,
{"x": EncryptedTensor(UnsignedInteger(3), lhs_shape)},
inputset,
lhs_inputset,
default_compilation_configuration,
)
operator_right_circuit = compile_numpy_function(
using_operator_right,
{"x": EncryptedTensor(UnsignedInteger(3), rhs_shape)},
rhs_inputset,
default_compilation_configuration,
)
function_right_circuit = compile_numpy_function(
using_function_right,
{"x": EncryptedTensor(UnsignedInteger(3), rhs_shape)},
rhs_inputset,
default_compilation_configuration,
)
args = (numpy.random.randint(low, high + 1, size=lhs_shape, dtype=numpy.uint8),)
check_array_equality(operator_circuit.run(*args), using_operator(*args))
check_array_equality(function_circuit.run(*args), using_function(*args))
lhs_arg = numpy.random.randint(low, high + 1, size=lhs_shape, dtype=numpy.uint8)
check_array_equality(operator_left_circuit.run(lhs_arg), using_operator_left(lhs_arg))
check_array_equality(function_left_circuit.run(lhs_arg), using_function_left(lhs_arg))
rhs_arg = numpy.random.randint(low, high + 1, size=rhs_shape, dtype=numpy.uint8)
check_array_equality(operator_right_circuit.run(rhs_arg), using_operator_right(rhs_arg))
check_array_equality(function_right_circuit.run(rhs_arg), using_function_right(rhs_arg))
@pytest.mark.parametrize(