mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(mlir): implement mlir conversion of constant indexing
This commit is contained in:
@@ -7,11 +7,13 @@
|
||||
from typing import Any, Dict, List, Tuple, cast
|
||||
|
||||
import numpy
|
||||
from mlir.dialects import arith
|
||||
from mlir.dialects import arith, tensor
|
||||
from mlir.ir import (
|
||||
ArrayAttr,
|
||||
Attribute,
|
||||
Context,
|
||||
DenseElementsAttr,
|
||||
IndexType,
|
||||
IntegerAttr,
|
||||
IntegerType,
|
||||
OpResult,
|
||||
@@ -21,12 +23,14 @@ from zamalang.dialects import hlfhe, hlfhelinalg
|
||||
|
||||
from ..data_types import Integer
|
||||
from ..debugging import assert_true
|
||||
from ..helpers.indexing_helpers import determine_new_dimension_size
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import (
|
||||
Add,
|
||||
Constant,
|
||||
Dot,
|
||||
GenericFunction,
|
||||
IndexConstant,
|
||||
IntermediateNode,
|
||||
MatMul,
|
||||
Mul,
|
||||
@@ -101,6 +105,8 @@ class IntermediateNodeConverter:
|
||||
str: textual MLIR representation corresponding to self.node
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-branches
|
||||
|
||||
if isinstance(self.node, Add):
|
||||
result = self.convert_add()
|
||||
|
||||
@@ -113,6 +119,9 @@ class IntermediateNodeConverter:
|
||||
elif isinstance(self.node, GenericFunction):
|
||||
result = self.convert_generic_function(additional_conversion_info)
|
||||
|
||||
elif isinstance(self.node, IndexConstant):
|
||||
result = self.convert_index_constant()
|
||||
|
||||
elif isinstance(self.node, MatMul):
|
||||
result = self.convert_matmul()
|
||||
|
||||
@@ -126,6 +135,8 @@ class IntermediateNodeConverter:
|
||||
# this branch is not covered as unsupported opeations fail on check mlir compatibility
|
||||
raise NotImplementedError(f"{type(self.node)} nodes cannot be converted to MLIR yet")
|
||||
|
||||
# pylint: enable=too-many-branches
|
||||
|
||||
mlir_name = str(result).replace("Value(", "").split("=", maxsplit=1)[0].strip()
|
||||
|
||||
self.nodes_to_mlir_names[self.node] = mlir_name
|
||||
@@ -319,6 +330,98 @@ class IntermediateNodeConverter:
|
||||
|
||||
return result
|
||||
|
||||
def convert_index_constant(self) -> OpResult:
|
||||
"""Convert a IndexConstant node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
str: textual MLIR representation corresponding to self.node
|
||||
"""
|
||||
|
||||
assert_true(len(self.node.inputs) == 1)
|
||||
assert_true(len(self.node.outputs) == 1)
|
||||
|
||||
tensor_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
|
||||
pred = self.preds[0]
|
||||
|
||||
input_value = cast(TensorValue, self.node.inputs[0])
|
||||
input_shape = input_value.shape
|
||||
|
||||
index = cast(IndexConstant, self.node).index
|
||||
index_str = self.node.text_for_formatting([""], 0)
|
||||
|
||||
index_type = IndexType.parse("index")
|
||||
|
||||
if len(index) == len(input_shape) and all(isinstance(i, int) for i in index):
|
||||
indices = []
|
||||
for value, dimension_size in zip(index, input_shape):
|
||||
assert isinstance(value, int) # mypy
|
||||
attr = IntegerAttr.get(index_type, value if value >= 0 else value + dimension_size)
|
||||
indices.append(arith.ConstantOp(index_type, attr).result)
|
||||
return tensor.ExtractOp(tensor_type, pred, indices).result
|
||||
|
||||
offsets = []
|
||||
sizes = []
|
||||
strides = []
|
||||
|
||||
can_be_converted = True
|
||||
for dimension, (indexing_element, dimension_size) in enumerate(zip(index, input_shape)):
|
||||
|
||||
if isinstance(indexing_element, int):
|
||||
size = 1
|
||||
stride = 1
|
||||
offset = (
|
||||
indexing_element if indexing_element >= 0 else indexing_element + dimension_size
|
||||
)
|
||||
|
||||
elif isinstance(indexing_element, slice):
|
||||
size = determine_new_dimension_size(
|
||||
indexing_element,
|
||||
dimension_size,
|
||||
dimension,
|
||||
input_shape,
|
||||
index_str,
|
||||
)
|
||||
if size == 1:
|
||||
can_be_converted = False
|
||||
break
|
||||
|
||||
stride = indexing_element.step if isinstance(indexing_element.step, int) else 1
|
||||
offset = (
|
||||
(
|
||||
indexing_element.start
|
||||
if indexing_element.start >= 0
|
||||
else indexing_element.start + dimension_size
|
||||
)
|
||||
if isinstance(indexing_element.start, int)
|
||||
else (0 if stride > 0 else dimension_size - 1)
|
||||
)
|
||||
|
||||
else: # pragma: no cover
|
||||
# this branch is impossible to reach with all the previous checks
|
||||
# but let's keep it as an extra measure
|
||||
can_be_converted = False
|
||||
break
|
||||
|
||||
offsets.append(offset)
|
||||
sizes.append(size)
|
||||
strides.append(stride)
|
||||
|
||||
if not can_be_converted:
|
||||
raise NotImplementedError(
|
||||
f"Indexing of {input_value} with {index_str} cannot be converted to MLIR yet",
|
||||
)
|
||||
|
||||
return tensor.ExtractSliceOp(
|
||||
tensor_type,
|
||||
pred,
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
ArrayAttr.get([IntegerAttr.get(index_type, value) for value in offsets]),
|
||||
ArrayAttr.get([IntegerAttr.get(index_type, value) for value in sizes]),
|
||||
ArrayAttr.get([IntegerAttr.get(index_type, value) for value in strides]),
|
||||
).result
|
||||
|
||||
def convert_matmul(self) -> OpResult:
|
||||
"""Convert a MatMul node to its corresponding MLIR representation.
|
||||
|
||||
|
||||
@@ -99,7 +99,6 @@ def check_node_compatibility_with_mlir(
|
||||
|
||||
elif isinstance(node, intermediate.IndexConstant): # constraints for constant indexing
|
||||
assert_true(len(outputs) == 1)
|
||||
return "indexing is not supported for the time being"
|
||||
|
||||
elif isinstance(node, intermediate.MatMul): # constraints for matrix multiplication
|
||||
assert_true(len(inputs) == 2)
|
||||
|
||||
@@ -7,7 +7,7 @@ import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.compilation import CompilationConfiguration
|
||||
from concrete.common.data_types.integers import Integer, UnsignedInteger
|
||||
from concrete.common.data_types.integers import Integer, SignedInteger, UnsignedInteger
|
||||
from concrete.common.debugging import draw_graph, format_operation_graph
|
||||
from concrete.common.extensions.multi_table import MultiLookupTable
|
||||
from concrete.common.extensions.table import LookupTable
|
||||
@@ -1403,24 +1403,6 @@ return %2
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[0],
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=True), shape=(2, 2))},
|
||||
[(numpy.random.randint(-4, 2 ** 2, size=(2, 2)),) for i in range(10)],
|
||||
(
|
||||
"""
|
||||
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<int3, shape=(2, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported
|
||||
%1 = %0[0] # EncryptedTensor<int3, shape=(2,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ indexing is not supported for the time being
|
||||
return %1
|
||||
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
no_fuse_unhandled,
|
||||
{"x": EncryptedScalar(Integer(2, False)), "y": EncryptedScalar(Integer(2, False))},
|
||||
@@ -1539,6 +1521,52 @@ def test_fail_compile(function, parameters, inputset, match, default_compilation
|
||||
assert str(excinfo.value) == match, str(excinfo.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,inputset,match",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: (x * 1.5)[0, 1],
|
||||
{"x": EncryptedTensor(SignedInteger(3), shape=(2, 2))},
|
||||
[(numpy.random.randint(-4, 3, size=(2, 2)),) for i in range(10)],
|
||||
(
|
||||
"""
|
||||
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<int3, shape=(2, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported
|
||||
%1 = 1.5 # ClearScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
|
||||
%2 = mul(%0, %1) # EncryptedTensor<float64, shape=(2, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer multiplication is supported
|
||||
%3 = %2[0, 1] # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer outputs are supported
|
||||
return %3
|
||||
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_fail_compile_while_fusing_is_disabled(
|
||||
function, parameters, inputset, match, default_compilation_configuration
|
||||
):
|
||||
"""Test compile_numpy_function without fusing and with failing inputs"""
|
||||
|
||||
configuration_to_use = deepcopy(default_compilation_configuration)
|
||||
configuration_to_use.enable_topological_optimizations = False
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
compile_numpy_function(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
configuration_to_use,
|
||||
)
|
||||
|
||||
assert str(excinfo.value) == match, str(excinfo.value)
|
||||
|
||||
|
||||
def test_small_inputset_no_fail():
|
||||
"""Test function compile_numpy_function_into_op_graph with an unacceptably small inputset"""
|
||||
compile_numpy_function_into_op_graph_and_measure_bounds(
|
||||
|
||||
@@ -5,7 +5,10 @@ import pytest
|
||||
|
||||
from concrete.common.data_types import UnsignedInteger
|
||||
from concrete.common.values import EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import compile_numpy_function_into_op_graph_and_measure_bounds
|
||||
from concrete.numpy import (
|
||||
compile_numpy_function,
|
||||
compile_numpy_function_into_op_graph_and_measure_bounds,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -595,3 +598,183 @@ def test_invalid_constant_indexing_with_numpy_values(
|
||||
except Exception as error:
|
||||
assert str(error) == expected_error_message
|
||||
raise
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,inputset,test_input,expected_output",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: x[0],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3,)),
|
||||
},
|
||||
[(np.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)],
|
||||
([4, 2, 6],),
|
||||
4,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[-1],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3,)),
|
||||
},
|
||||
[(np.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)],
|
||||
([4, 2, 6],),
|
||||
6,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[:3],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(4,)),
|
||||
},
|
||||
[(np.random.randint(0, 2 ** 3, size=(4,)),) for _ in range(10)],
|
||||
([4, 2, 6, 1],),
|
||||
[4, 2, 6],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[2:],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(4,)),
|
||||
},
|
||||
[(np.random.randint(0, 2 ** 3, size=(4,)),) for _ in range(10)],
|
||||
([4, 2, 6, 1],),
|
||||
[6, 1],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[1:3],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(4,)),
|
||||
},
|
||||
[(np.random.randint(0, 2 ** 3, size=(4,)),) for _ in range(10)],
|
||||
([4, 2, 6, 1],),
|
||||
[2, 6],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[::2],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(4,)),
|
||||
},
|
||||
[(np.random.randint(0, 2 ** 3, size=(4,)),) for _ in range(10)],
|
||||
([4, 2, 6, 1],),
|
||||
[4, 6],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[::-1],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(4,)),
|
||||
},
|
||||
[(np.random.randint(0, 2 ** 3, size=(4,)),) for _ in range(10)],
|
||||
([4, 2, 6, 1],),
|
||||
[1, 6, 2, 4],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[1, 0],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2)),
|
||||
},
|
||||
[(np.random.randint(0, 2 ** 6, size=(3, 2)),) for _ in range(10)],
|
||||
([[11, 12], [21, 22], [31, 32]],),
|
||||
21,
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[:, :],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2)),
|
||||
},
|
||||
[(np.random.randint(0, 2 ** 6, size=(3, 2)),) for _ in range(10)],
|
||||
([[11, 12], [21, 22], [31, 32]],),
|
||||
[[11, 12], [21, 22], [31, 32]],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[0, :],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2)),
|
||||
},
|
||||
[(np.random.randint(0, 2 ** 6, size=(3, 2)),) for _ in range(10)],
|
||||
([[11, 12], [21, 22], [31, 32]],),
|
||||
[11, 12],
|
||||
marks=pytest.mark.xfail(strict=True),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x[:, 0],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(6), shape=(3, 2)),
|
||||
},
|
||||
[(np.random.randint(0, 2 ** 6, size=(3, 2)),) for _ in range(10)],
|
||||
([[11, 12], [21, 22], [31, 32]],),
|
||||
[11, 21, 31],
|
||||
marks=pytest.mark.xfail(strict=True),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_constant_indexing_run_correctness(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
test_input,
|
||||
expected_output,
|
||||
default_compilation_configuration,
|
||||
):
|
||||
"""Test correctness of results when running a compiled function with tensor operators"""
|
||||
circuit = compile_numpy_function(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
numpy_test_input = tuple(
|
||||
item if isinstance(item, int) else np.array(item, dtype=np.uint8) for item in test_input
|
||||
)
|
||||
|
||||
output = circuit.run(*numpy_test_input)
|
||||
expected = np.array(expected_output, dtype=np.uint8)
|
||||
|
||||
assert np.array_equal(
|
||||
output, expected
|
||||
), f"""
|
||||
|
||||
Actual Output
|
||||
=============
|
||||
{output}
|
||||
|
||||
Expected Output
|
||||
===============
|
||||
{expected}
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,inputset,match",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: x[0:1],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3,)),
|
||||
},
|
||||
[(np.random.randint(0, 2 ** 3, size=(3,)),) for _ in range(10)],
|
||||
(
|
||||
"Indexing of EncryptedTensor<uint3, shape=(3,)> with [0:1] "
|
||||
"cannot be converted to MLIR yet"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_constant_indexing_failed_compilation(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
match,
|
||||
default_compilation_configuration,
|
||||
):
|
||||
"""Test compilation failures of compiled function with constant indexing"""
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
compile_numpy_function(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
assert str(excinfo.value) == match, str(excinfo.value)
|
||||
|
||||
Reference in New Issue
Block a user