feat(mlir): implement mlir conversion of constant indexing

This commit is contained in:
Umut
2021-11-29 10:23:05 +03:00
parent 926597c3f6
commit ac74e94e13
4 changed files with 335 additions and 22 deletions

View File

@@ -7,11 +7,13 @@
from typing import Any, Dict, List, Tuple, cast
import numpy
from mlir.dialects import arith
from mlir.dialects import arith, tensor
from mlir.ir import (
ArrayAttr,
Attribute,
Context,
DenseElementsAttr,
IndexType,
IntegerAttr,
IntegerType,
OpResult,
@@ -21,12 +23,14 @@ from zamalang.dialects import hlfhe, hlfhelinalg
from ..data_types import Integer
from ..debugging import assert_true
from ..helpers.indexing_helpers import determine_new_dimension_size
from ..operator_graph import OPGraph
from ..representation.intermediate import (
Add,
Constant,
Dot,
GenericFunction,
IndexConstant,
IntermediateNode,
MatMul,
Mul,
@@ -101,6 +105,8 @@ class IntermediateNodeConverter:
str: textual MLIR representation corresponding to self.node
"""
# pylint: disable=too-many-branches
if isinstance(self.node, Add):
result = self.convert_add()
@@ -113,6 +119,9 @@ class IntermediateNodeConverter:
elif isinstance(self.node, GenericFunction):
result = self.convert_generic_function(additional_conversion_info)
elif isinstance(self.node, IndexConstant):
result = self.convert_index_constant()
elif isinstance(self.node, MatMul):
result = self.convert_matmul()
@@ -126,6 +135,8 @@ class IntermediateNodeConverter:
# this branch is not covered as unsupported operations fail on check mlir compatibility
raise NotImplementedError(f"{type(self.node)} nodes cannot be converted to MLIR yet")
# pylint: enable=too-many-branches
mlir_name = str(result).replace("Value(", "").split("=", maxsplit=1)[0].strip()
self.nodes_to_mlir_names[self.node] = mlir_name
@@ -319,6 +330,98 @@ class IntermediateNodeConverter:
return result
def convert_index_constant(self) -> OpResult:
    """Convert an IndexConstant node to its corresponding MLIR representation.

    When every dimension of the input is addressed by an integer, a single element is
    extracted with `tensor.extract`. Otherwise the indexing is lowered to
    `tensor.extract_slice` with static offsets, sizes and strides.

    Returns:
        OpResult: result of the MLIR operation generated for self.node

    Raises:
        NotImplementedError: if a slice selects exactly one element along a dimension,
            or if an indexing element is neither an int nor a slice
    """

    assert_true(len(self.node.inputs) == 1)
    assert_true(len(self.node.outputs) == 1)

    tensor_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
    pred = self.preds[0]

    input_value = cast(TensorValue, self.node.inputs[0])
    input_shape = input_value.shape

    index = cast(IndexConstant, self.node).index
    index_str = self.node.text_for_formatting([""], 0)

    index_type = IndexType.parse("index")

    # every dimension is indexed with an integer: a single element is extracted
    if len(index) == len(input_shape) and all(isinstance(i, int) for i in index):
        indices = []
        for value, dimension_size in zip(index, input_shape):
            assert isinstance(value, int)  # mypy
            attr = IntegerAttr.get(index_type, value if value >= 0 else value + dimension_size)
            indices.append(arith.ConstantOp(index_type, attr).result)
        return tensor.ExtractOp(tensor_type, pred, indices).result

    not_convertible_message = (
        f"Indexing of {input_value} with {index_str} cannot be converted to MLIR yet"
    )

    # dimensions which are not indexed explicitly are taken as a whole (x[1:] == x[1:, :]),
    # so the static offsets/sizes/strides below always get one entry per input dimension
    missing_dimensions = len(input_shape) - len(index)
    full_index = tuple(index) + (slice(None, None, None),) * missing_dimensions

    offsets = []
    sizes = []
    strides = []

    for dimension, (indexing_element, dimension_size) in enumerate(zip(full_index, input_shape)):
        if isinstance(indexing_element, int):
            size = 1
            stride = 1
            offset = (
                indexing_element if indexing_element >= 0 else indexing_element + dimension_size
            )

        elif isinstance(indexing_element, slice):
            size = determine_new_dimension_size(
                indexing_element,
                dimension_size,
                dimension,
                input_shape,
                index_str,
            )
            if size == 1:
                # slices that keep a single element are rejected by this converter
                raise NotImplementedError(not_convertible_message)

            # slice.indices resolves omitted bounds, negative values and out-of-range
            # starts exactly like python indexing does
            offset, _, stride = indexing_element.indices(dimension_size)

        else:  # pragma: no cover
            # this branch is impossible to reach with all the previous checks
            # but let's keep it as an extra measure
            raise NotImplementedError(not_convertible_message)

        offsets.append(offset)
        sizes.append(size)
        strides.append(stride)

    return tensor.ExtractSliceOp(
        tensor_type,
        pred,
        [],
        [],
        [],
        ArrayAttr.get([IntegerAttr.get(index_type, value) for value in offsets]),
        ArrayAttr.get([IntegerAttr.get(index_type, value) for value in sizes]),
        ArrayAttr.get([IntegerAttr.get(index_type, value) for value in strides]),
    ).result
def convert_matmul(self) -> OpResult:
"""Convert a MatMul node to its corresponding MLIR representation.

View File

@@ -99,7 +99,6 @@ def check_node_compatibility_with_mlir(
elif isinstance(node, intermediate.IndexConstant): # constraints for constant indexing
assert_true(len(outputs) == 1)
return "indexing is not supported for the time being"
elif isinstance(node, intermediate.MatMul): # constraints for matrix multiplication
assert_true(len(inputs) == 2)

View File

@@ -7,7 +7,7 @@ import numpy
import pytest
from concrete.common.compilation import CompilationConfiguration
from concrete.common.data_types.integers import Integer, UnsignedInteger
from concrete.common.data_types.integers import Integer, SignedInteger, UnsignedInteger
from concrete.common.debugging import draw_graph, format_operation_graph
from concrete.common.extensions.multi_table import MultiLookupTable
from concrete.common.extensions.table import LookupTable
@@ -1403,24 +1403,6 @@ return %2
""".strip() # noqa: E501
),
),
pytest.param(
lambda x: x[0],
{"x": EncryptedTensor(Integer(3, is_signed=True), shape=(2, 2))},
[(numpy.random.randint(-4, 2 ** 2, size=(2, 2)),) for i in range(10)],
(
"""
function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<int3, shape=(2, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported
%1 = %0[0] # EncryptedTensor<int3, shape=(2,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ indexing is not supported for the time being
return %1
""".strip() # noqa: E501
),
),
pytest.param(
no_fuse_unhandled,
{"x": EncryptedScalar(Integer(2, False)), "y": EncryptedScalar(Integer(2, False))},
@@ -1539,6 +1521,52 @@ def test_fail_compile(function, parameters, inputset, match, default_compilation
assert str(excinfo.value) == match, str(excinfo.value)
@pytest.mark.parametrize(
    "function,parameters,inputset,match",
    [
        pytest.param(
            lambda x: (x * 1.5)[0, 1],
            {"x": EncryptedTensor(SignedInteger(3), shape=(2, 2))},
            [(numpy.random.randint(-4, 3, size=(2, 2)),) for _ in range(10)],
            (
                """
function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<int3, shape=(2, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported
%1 = 1.5 # ClearScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
%2 = mul(%0, %1) # EncryptedTensor<float64, shape=(2, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer multiplication is supported
%3 = %2[0, 1] # EncryptedScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer outputs are supported
return %3
""".strip()  # noqa: E501
            ),
        ),
    ],
)
def test_fail_compile_while_fusing_is_disabled(
    function, parameters, inputset, match, default_compilation_configuration
):
    """Check the error raised by compile_numpy_function when topological optimizations are off"""

    # work on a copy so the shared default configuration fixture is left untouched
    no_fusing_configuration = deepcopy(default_compilation_configuration)
    no_fusing_configuration.enable_topological_optimizations = False

    with pytest.raises(RuntimeError) as raised:
        compile_numpy_function(function, parameters, inputset, no_fusing_configuration)

    actual_message = str(raised.value)
    assert actual_message == match, actual_message
def test_small_inputset_no_fail():
"""Test function compile_numpy_function_into_op_graph with an unacceptably small inputset"""
compile_numpy_function_into_op_graph_and_measure_bounds(

View File

@@ -5,7 +5,10 @@ import pytest
from concrete.common.data_types import UnsignedInteger
from concrete.common.values import EncryptedScalar, EncryptedTensor
from concrete.numpy import compile_numpy_function_into_op_graph_and_measure_bounds
from concrete.numpy import (
compile_numpy_function,
compile_numpy_function_into_op_graph_and_measure_bounds,
)
@pytest.mark.parametrize(
@@ -595,3 +598,183 @@ def test_invalid_constant_indexing_with_numpy_values(
except Exception as error:
assert str(error) == expected_error_message
raise
@pytest.mark.parametrize(
    "function,parameters,inputset,test_input,expected_output",
    [
        # indexing of 1-dimensional tensors
        pytest.param(
            lambda x: x[0],
            {"x": EncryptedTensor(UnsignedInteger(3), shape=(3,))},
            [(np.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)],
            ([4, 2, 6],),
            4,
        ),
        pytest.param(
            lambda x: x[-1],
            {"x": EncryptedTensor(UnsignedInteger(3), shape=(3,))},
            [(np.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)],
            ([4, 2, 6],),
            6,
        ),
        pytest.param(
            lambda x: x[:3],
            {"x": EncryptedTensor(UnsignedInteger(3), shape=(4,))},
            [(np.random.randint(0, 2 ** 3, size=(4,)),) for _ in range(10)],
            ([4, 2, 6, 1],),
            [4, 2, 6],
        ),
        pytest.param(
            lambda x: x[2:],
            {"x": EncryptedTensor(UnsignedInteger(3), shape=(4,))},
            [(np.random.randint(0, 2 ** 3, size=(4,)),) for _ in range(10)],
            ([4, 2, 6, 1],),
            [6, 1],
        ),
        pytest.param(
            lambda x: x[1:3],
            {"x": EncryptedTensor(UnsignedInteger(3), shape=(4,))},
            [(np.random.randint(0, 2 ** 3, size=(4,)),) for _ in range(10)],
            ([4, 2, 6, 1],),
            [2, 6],
        ),
        pytest.param(
            lambda x: x[::2],
            {"x": EncryptedTensor(UnsignedInteger(3), shape=(4,))},
            [(np.random.randint(0, 2 ** 3, size=(4,)),) for _ in range(10)],
            ([4, 2, 6, 1],),
            [4, 6],
        ),
        pytest.param(
            lambda x: x[::-1],
            {"x": EncryptedTensor(UnsignedInteger(3), shape=(4,))},
            [(np.random.randint(0, 2 ** 3, size=(4,)),) for _ in range(10)],
            ([4, 2, 6, 1],),
            [1, 6, 2, 4],
        ),
        # indexing of 2-dimensional tensors
        pytest.param(
            lambda x: x[1, 0],
            {"x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2))},
            [(np.random.randint(0, 2 ** 6, size=(3, 2)),) for _ in range(10)],
            ([[11, 12], [21, 22], [31, 32]],),
            21,
        ),
        pytest.param(
            lambda x: x[:, :],
            {"x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2))},
            [(np.random.randint(0, 2 ** 6, size=(3, 2)),) for _ in range(10)],
            ([[11, 12], [21, 22], [31, 32]],),
            [[11, 12], [21, 22], [31, 32]],
        ),
        # the two cases below are marked as strictly expected to fail
        pytest.param(
            lambda x: x[0, :],
            {"x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2))},
            [(np.random.randint(0, 2 ** 6, size=(3, 2)),) for _ in range(10)],
            ([[11, 12], [21, 22], [31, 32]],),
            [11, 12],
            marks=pytest.mark.xfail(strict=True),
        ),
        pytest.param(
            lambda x: x[:, 0],
            {"x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2))},
            [(np.random.randint(0, 2 ** 6, size=(3, 2)),) for _ in range(10)],
            ([[11, 12], [21, 22], [31, 32]],),
            [11, 21, 31],
            marks=pytest.mark.xfail(strict=True),
        ),
    ],
)
def test_constant_indexing_run_correctness(
    function,
    parameters,
    inputset,
    test_input,
    expected_output,
    default_compilation_configuration,
):
    """Compile a constant indexing function, run it, and compare with the expected result"""

    circuit = compile_numpy_function(
        function, parameters, inputset, default_compilation_configuration
    )

    # plain integers are passed as they are, everything else is converted to a uint8 array
    arguments = []
    for raw_argument in test_input:
        if isinstance(raw_argument, int):
            arguments.append(raw_argument)
        else:
            arguments.append(np.array(raw_argument, dtype=np.uint8))

    output = circuit.run(*arguments)
    expected = np.array(expected_output, dtype=np.uint8)

    assert np.array_equal(output, expected), f"""
Actual Output
=============
{output}
Expected Output
===============
{expected}
"""
@pytest.mark.parametrize(
    "function,parameters,inputset,match",
    [
        pytest.param(
            lambda x: x[0:1],
            {"x": EncryptedTensor(UnsignedInteger(3), shape=(3,))},
            [(np.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)],
            (
                "Indexing of EncryptedTensor<uint3, shape=(3,)> with [0:1] "
                "cannot be converted to MLIR yet"
            ),
        ),
    ],
)
def test_constant_indexing_failed_compilation(
    function,
    parameters,
    inputset,
    match,
    default_compilation_configuration,
):
    """Check the error message raised when constant indexing cannot be compiled"""

    with pytest.raises(RuntimeError) as raised:
        compile_numpy_function(function, parameters, inputset, default_compilation_configuration)

    actual_message = str(raised.value)
    assert actual_message == match, actual_message