feat(compilation): implement MLIR conversion of constant arrays

This commit is contained in:
Umut
2021-10-06 16:47:16 +03:00
parent 674b86cf62
commit 5fce0d2920
4 changed files with 152 additions and 25 deletions

View File

@@ -11,18 +11,19 @@ from typing import cast
# pylint: disable=no-name-in-module,no-member
import numpy as np
from mlir.dialects import std as std_dialect
from mlir.ir import DenseElementsAttr, IntegerAttr, IntegerType, RankedTensorType
from mlir.ir import Attribute, DenseElementsAttr, IntegerAttr, IntegerType, RankedTensorType
from zamalang.dialects import hlfhe
from ...common.data_types.integers import Integer
from ..data_types.dtypes_helpers import (
value_is_clear_scalar_integer,
value_is_clear_tensor_integer,
value_is_encrypted_scalar_unsigned_integer,
value_is_encrypted_tensor_integer,
)
from ..data_types.integers import Integer
from ..debugging.custom_assert import custom_assert
from ..representation.intermediate import Add, ArbitraryFunction, Constant, Dot, Mul, Sub
from ..values import TensorValue
def add(node, preds, ir_to_mlir_node, ctx):
@@ -123,14 +124,44 @@ def _mul_eint_int(node, preds, ir_to_mlir_node, ctx):
def constant(node, _, __, ctx):
"""Convert a constant inputs."""
if not value_is_clear_scalar_integer(node.outputs[0]):
raise TypeError("Don't support non-integer constants")
dtype = cast(Integer, node.outputs[0].dtype)
if dtype.is_signed:
raise TypeError("Don't support signed constant integer")
int_type = IntegerType.get_signless(dtype.bit_width, context=ctx)
return std_dialect.ConstantOp(int_type, IntegerAttr.get(int_type, node.constant_data)).result
"""Convert a constant input."""
value = node.outputs[0]
if value_is_clear_scalar_integer(value):
value = cast(TensorValue, value)
dtype = cast(Integer, value.dtype)
if dtype.is_signed:
raise TypeError("Don't support signed constant integer")
data = node.constant_data
int_type = IntegerType.get_signless(dtype.bit_width, context=ctx)
return std_dialect.ConstantOp(int_type, IntegerAttr.get(int_type, data)).result
if value_is_clear_tensor_integer(value):
value = cast(TensorValue, value)
dtype = cast(Integer, value.dtype)
if dtype.is_signed:
raise TypeError("Don't support signed constant integer tensor")
data = node.constant_data
int_type = IntegerType.get_signless(dtype.bit_width, context=ctx)
vec_type = RankedTensorType.get(value.shape, int_type)
# usage of `Attribute.parse` is the result of some limitations in the MLIR module
# provided by LLVM
# `DenseElementsAttr` should have been used instead but it's impossible to assign
# custom bit-widths using it (e.g., uint5)
# since we coudn't create a `DenseElementsAttr` with a custom bit width using python api
# we use `Attribute.parse` to let the underlying library do it by itself
value_attr = Attribute.parse(f"dense<{str(data.tolist())}> : {vec_type}")
return std_dialect.ConstantOp(vec_type, value_attr).result
raise TypeError(f"Don't support {value} constants")
def apply_lut(node, preds, ir_to_mlir_node, ctx):

View File

@@ -4,7 +4,7 @@ import pytest
from concrete.common.data_types.floats import Float
from concrete.common.data_types.integers import Integer
from concrete.common.mlir.converters import add, apply_lut, constant, dot, mul, sub
from concrete.common.values import ClearScalar, EncryptedScalar
from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar
class MockNode:
@@ -30,14 +30,19 @@ def test_failing_converter(converter):
def test_fail_non_integer_const():
"""Test failing constant converter with non-integer"""
with pytest.raises(TypeError, match=r"Don't support non-integer constants"):
with pytest.raises(TypeError, match=r"Don't support .* constants"):
constant(MockNode(outputs=[ClearScalar(Float(32))]), None, None, None)
with pytest.raises(TypeError, match=r"Don't support .* constants"):
constant(MockNode(outputs=[ClearTensor(Float(32), shape=(2,))]), None, None, None)
def test_fail_signed_integer_const():
"""Test failing constant converter with non-integer"""
with pytest.raises(TypeError, match=r"Don't support signed constant integer"):
constant(MockNode(outputs=[ClearScalar(Integer(8, True))]), None, None, None)
with pytest.raises(TypeError, match=r"Don't support signed constant integer tensor"):
constant(MockNode(outputs=[ClearTensor(Integer(8, True), shape=(2,))]), None, None, None)
@pytest.mark.parametrize(

View File

@@ -267,6 +267,36 @@ def test_mlir_converter_dot_between_vectors(func, args_dict, args_ranges):
compiler.round_trip(mlir_result)
def test_mlir_converter_dot_vector_and_constant():
"""Test the conversion to MLIR by calling the parser from the compiler"""
def left_dot_with_constant(x):
return numpy.dot(x, numpy.array([1, 2]))
def right_dot_with_constant(x):
return numpy.dot(numpy.array([1, 2]), x)
left_graph = compile_numpy_function_into_op_graph(
left_dot_with_constant,
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2,))},
[(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)],
)
left_converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
left_mlir = left_converter.convert(left_graph)
right_graph = compile_numpy_function_into_op_graph(
right_dot_with_constant,
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2,))},
[(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)],
)
right_converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
right_mlir = right_converter.convert(right_graph)
# testing that this doesn't raise an error
compiler.round_trip(left_mlir)
compiler.round_trip(right_mlir)
def test_concrete_encrypted_integer_to_mlir_type():
"""Test conversion of EncryptedScalar into MLIR"""
value = EncryptedScalar(Integer(7, is_signed=False))

View File

@@ -161,11 +161,11 @@ def test_compile_and_run_correctness(function, input_ranges, list_of_arg_names):
(0, 5),
),
pytest.param(
8,
6,
(0, 4),
),
pytest.param(
16,
10,
(0, 3),
),
],
@@ -173,18 +173,22 @@ def test_compile_and_run_correctness(function, input_ranges, list_of_arg_names):
def test_compile_and_run_dot_correctness(size, input_range):
"""Test correctness of results when running a compiled function"""
def data_gen(input_range, size):
for _ in range(1000):
low, high = input_range
args = [
numpy.array([random.randint(low, high) for _ in range(size)]) for __ in range(2)
]
low, high = input_range
shape = (size,)
yield args
inputset = [
(numpy.zeros(shape, dtype=numpy.uint32), numpy.zeros(shape, dtype=numpy.uint32)),
(
numpy.ones(shape, dtype=numpy.uint32) * high,
numpy.ones(shape, dtype=numpy.uint32) * high,
),
]
for _ in range(8):
inputset.append((numpy.random.randint(low, high + 1), numpy.random.randint(low, high + 1)))
function_parameters = {
"x": EncryptedTensor(Integer(64, False), (size,)),
"y": ClearTensor(Integer(64, False), (size,)),
"x": EncryptedTensor(Integer(64, False), shape),
"y": ClearTensor(Integer(64, False), shape),
}
def function(x, y):
@@ -193,14 +197,71 @@ def test_compile_and_run_dot_correctness(size, input_range):
compiler_engine = compile_numpy_function(
function,
function_parameters,
data_gen(input_range, size),
inputset,
)
low, high = input_range
args = [[random.randint(low, high) for _ in range(size)] for __ in range(2)]
assert compiler_engine.run(*args) == function(*args)
@pytest.mark.parametrize(
"size,input_range",
[
pytest.param(
1,
(0, 8),
),
pytest.param(
4,
(0, 5),
),
pytest.param(
6,
(0, 4),
),
pytest.param(
10,
(0, 3),
),
],
)
def test_compile_and_run_constant_dot_correctness(size, input_range):
"""Test correctness of results when running a compiled function"""
low, high = input_range
shape = (size,)
inputset = [
(numpy.zeros(shape, dtype=numpy.uint32),),
(numpy.ones(shape, dtype=numpy.uint32) * high,),
]
for _ in range(8):
inputset.append((numpy.random.randint(low, high + 1),))
constant = numpy.random.randint(low, high + 1, size=shape)
def left(x):
return numpy.dot(x, constant)
def right(x):
return numpy.dot(constant, x)
left_circuit = compile_numpy_function(
left,
{"x": EncryptedTensor(Integer(64, False), shape)},
inputset,
)
right_circuit = compile_numpy_function(
left,
{"x": EncryptedTensor(Integer(64, False), shape)},
inputset,
)
args = (numpy.random.randint(low, high + 1, size=shape).tolist(),)
assert left_circuit.run(*args) == left(*args)
assert right_circuit.run(*args) == right(*args)
def test_compile_function_with_direct_tlu():
"""Test compile_numpy_function_into_op_graph for a program with direct table lookup"""