mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix(compilation): raise the appropriate error for intermediate signed integers
This commit is contained in:
@@ -83,16 +83,20 @@ def value_is_scalar(value_to_check: BaseValue) -> bool:
|
||||
return isinstance(value_to_check, TensorValue) and value_to_check.is_scalar
|
||||
|
||||
|
||||
def value_is_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is of type Integer.
|
||||
def value_is_unsigned_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is of type Integer and is unsigned.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is of type Integer
|
||||
bool: True if the passed value_to_check is of type Integer and is unsigned
|
||||
"""
|
||||
return isinstance(value_to_check.dtype, INTEGER_TYPES)
|
||||
|
||||
return (
|
||||
isinstance(value_to_check.dtype, INTEGER_TYPES)
|
||||
and not cast(Integer, value_to_check.dtype).is_signed
|
||||
)
|
||||
|
||||
|
||||
def value_is_encrypted_tensor_integer(value_to_check: BaseValue) -> bool:
|
||||
|
||||
@@ -7,17 +7,86 @@ from ..data_types.dtypes_helpers import (
|
||||
value_is_clear_tensor_integer,
|
||||
value_is_encrypted_scalar_integer,
|
||||
value_is_encrypted_tensor_integer,
|
||||
value_is_integer,
|
||||
value_is_scalar,
|
||||
value_is_unsigned_integer,
|
||||
)
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from ..debugging.custom_assert import assert_not_reached, assert_true
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation import intermediate
|
||||
from ..representation.intermediate import IntermediateNode, UnivariateFunction
|
||||
|
||||
# TODO: should come from compiler, through an API, #402
|
||||
ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 7
|
||||
|
||||
|
||||
def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool) -> Optional[str]:
|
||||
"""Check if node is compatible with MLIR.
|
||||
|
||||
Args:
|
||||
node (IntermediateNode): node to check
|
||||
is_output (bool): whether the node is an output node or not
|
||||
|
||||
Returns:
|
||||
Optional[str]: None if the node is compatible else reason for incompatibility
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-branches,too-many-return-statements
|
||||
|
||||
inputs = node.inputs
|
||||
outputs = node.outputs
|
||||
|
||||
if isinstance(node, intermediate.Add): # constraints for addition
|
||||
for inp in inputs:
|
||||
if not value_is_scalar(inp):
|
||||
return "only scalar addition is supported"
|
||||
|
||||
elif isinstance(node, intermediate.Sub): # constraints for subtraction
|
||||
for inp in inputs:
|
||||
if not value_is_scalar(inp):
|
||||
return "only scalar subtraction is supported"
|
||||
|
||||
elif isinstance(node, intermediate.Mul): # constraints for multiplication
|
||||
for inp in inputs:
|
||||
if not value_is_scalar(inp):
|
||||
return "only scalar multiplication is supported"
|
||||
|
||||
elif isinstance(node, intermediate.Input): # constraints for inputs
|
||||
assert_true(len(outputs) == 1)
|
||||
if not value_is_unsigned_integer(outputs[0]):
|
||||
return "only unsigned integer inputs are supported"
|
||||
|
||||
elif isinstance(node, intermediate.Constant): # constraints for constants
|
||||
assert_true(len(outputs) == 1)
|
||||
if not value_is_unsigned_integer(outputs[0]):
|
||||
return "only unsigned integer constants are supported"
|
||||
|
||||
elif isinstance(node, intermediate.UnivariateFunction): # constraints for univariate functions
|
||||
assert_true(len(inputs) == 1)
|
||||
if not value_is_scalar(inputs[0]) or not value_is_unsigned_integer(inputs[0]):
|
||||
return "only unsigned integer scalar lookup tables are supported"
|
||||
|
||||
elif isinstance(node, intermediate.Dot): # constraints for dot product
|
||||
assert_true(len(inputs) == 2)
|
||||
if not value_is_unsigned_integer(inputs[0]) or not value_is_unsigned_integer(inputs[1]):
|
||||
return "only unsigned integer dot product is supported"
|
||||
|
||||
else: # pragma: no cover
|
||||
assert_not_reached("Non IntermediateNode object in the OPGraph")
|
||||
|
||||
if is_output:
|
||||
for out in outputs:
|
||||
if not value_is_scalar(out) or not value_is_unsigned_integer(out):
|
||||
return "only scalar unsigned integer outputs are supported"
|
||||
else:
|
||||
for out in outputs:
|
||||
if not value_is_unsigned_integer(out):
|
||||
return "only unsigned integer intermediates are supported"
|
||||
|
||||
# pylint: enable=too-many-branches,too-many-return-statements
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def check_graph_values_compatibility_with_mlir(
|
||||
op_graph: OPGraph,
|
||||
) -> Optional[Dict[IntermediateNode, str]]:
|
||||
@@ -33,13 +102,10 @@ def check_graph_values_compatibility_with_mlir(
|
||||
|
||||
offending_nodes = {}
|
||||
|
||||
for out_node in op_graph.output_nodes.values():
|
||||
for out in out_node.outputs:
|
||||
if not value_is_scalar(out):
|
||||
offending_nodes[out_node] = "non scalar outputs aren't supported"
|
||||
|
||||
if value_is_integer(out) and cast(Integer, out.dtype).is_signed:
|
||||
offending_nodes[out_node] = "signed integer outputs aren't supported"
|
||||
for node in op_graph.graph.nodes:
|
||||
is_output = node in op_graph.output_nodes.values()
|
||||
if (reason := check_node_compatibility_with_mlir(node, is_output)) is not None:
|
||||
offending_nodes[node] = reason
|
||||
|
||||
return None if len(offending_nodes) == 0 else offending_nodes
|
||||
|
||||
|
||||
@@ -163,20 +163,6 @@ def _compile_numpy_function_into_op_graph_internal(
|
||||
# Add the initial graph as an artifact
|
||||
compilation_artifacts.add_operation_graph("final", op_graph)
|
||||
|
||||
# Make sure the graph can be lowered to MLIR
|
||||
offending_nodes = check_graph_values_compatibility_with_mlir(op_graph)
|
||||
if offending_nodes is not None:
|
||||
raise RuntimeError(
|
||||
"function you are trying to compile isn't supported for MLIR lowering\n\n"
|
||||
+ get_printable_graph(op_graph, show_data_types=True, highlighted_nodes=offending_nodes)
|
||||
)
|
||||
|
||||
# Update bit_width for MLIR
|
||||
update_bit_width_for_mlir(op_graph)
|
||||
|
||||
# TODO: workaround extend LUT #359
|
||||
extend_direct_lookup_tables(op_graph)
|
||||
|
||||
return op_graph
|
||||
|
||||
|
||||
@@ -244,6 +230,33 @@ def compile_numpy_function_into_op_graph(
|
||||
raise
|
||||
|
||||
|
||||
def prepare_op_graph_for_mlir(op_graph):
|
||||
"""Prepare OPGraph for MLIR lowering.
|
||||
|
||||
This includes checking compatibility, changing bit-widths, and modifying lookup tables.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): The operation graph to prepare
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
# Make sure the graph can be lowered to MLIR
|
||||
offending_nodes = check_graph_values_compatibility_with_mlir(op_graph)
|
||||
if offending_nodes is not None:
|
||||
raise RuntimeError(
|
||||
"function you are trying to compile isn't supported for MLIR lowering\n\n"
|
||||
+ get_printable_graph(op_graph, show_data_types=True, highlighted_nodes=offending_nodes)
|
||||
)
|
||||
|
||||
# Update bit_width for MLIR
|
||||
update_bit_width_for_mlir(op_graph)
|
||||
|
||||
# TODO: workaround extend LUT #359
|
||||
extend_direct_lookup_tables(op_graph)
|
||||
|
||||
|
||||
def _compile_numpy_function_internal(
|
||||
function_to_compile: Callable,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
@@ -281,6 +294,8 @@ def _compile_numpy_function_internal(
|
||||
compilation_artifacts,
|
||||
)
|
||||
|
||||
prepare_op_graph_for_mlir(op_graph)
|
||||
|
||||
# Convert graph to an MLIR representation
|
||||
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
mlir_result = converter.convert(op_graph)
|
||||
|
||||
@@ -45,9 +45,9 @@ return(%2)
|
||||
with_types
|
||||
== """
|
||||
|
||||
%0 = x # EncryptedScalar<Integer<signed, 6 bits>>
|
||||
%0 = x # EncryptedScalar<Integer<signed, 4 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo
|
||||
%1 = Constant(42) # ClearScalar<Integer<unsigned, 7 bits>>
|
||||
%1 = Constant(42) # ClearScalar<Integer<unsigned, 6 bits>>
|
||||
%2 = Add(%0, %1) # EncryptedScalar<Integer<unsigned, 6 bits>>
|
||||
return(%2)
|
||||
|
||||
@@ -81,9 +81,9 @@ return(%2)
|
||||
with_types
|
||||
== """
|
||||
|
||||
%0 = x # EncryptedScalar<Integer<signed, 6 bits>>
|
||||
%0 = x # EncryptedScalar<Integer<signed, 4 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo
|
||||
%1 = Constant(42) # ClearScalar<Integer<unsigned, 7 bits>>
|
||||
%1 = Constant(42) # ClearScalar<Integer<unsigned, 6 bits>>
|
||||
%2 = Add(%0, %1) # EncryptedScalar<Integer<unsigned, 6 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ bar
|
||||
return(%2)
|
||||
|
||||
@@ -13,7 +13,7 @@ from concrete.common.extensions.table import LookupTable
|
||||
from concrete.common.mlir import V0_OPSET_CONVERSION_FUNCTIONS
|
||||
from concrete.common.values import ClearScalar, EncryptedScalar
|
||||
from concrete.common.values.tensors import ClearTensor, EncryptedTensor
|
||||
from concrete.numpy.compile import compile_numpy_function_into_op_graph
|
||||
from concrete.numpy.compile import compile_numpy_function_into_op_graph, prepare_op_graph_for_mlir
|
||||
from concrete.numpy.np_mlir_converter import NPMLIRConverter
|
||||
|
||||
|
||||
@@ -225,6 +225,7 @@ def test_mlir_converter(func, args_dict, args_ranges, default_compilation_config
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
prepare_op_graph_for_mlir(result_graph)
|
||||
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
mlir_result = converter.convert(result_graph)
|
||||
# testing that this doesn't raise an error
|
||||
@@ -270,6 +271,7 @@ def test_mlir_converter_dot_between_vectors(
|
||||
),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
prepare_op_graph_for_mlir(result_graph)
|
||||
converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
mlir_result = converter.convert(result_graph)
|
||||
# testing that this doesn't raise an error
|
||||
@@ -291,6 +293,7 @@ def test_mlir_converter_dot_vector_and_constant(default_compilation_configuratio
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)],
|
||||
default_compilation_configuration,
|
||||
)
|
||||
prepare_op_graph_for_mlir(left_graph)
|
||||
left_converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
left_mlir = left_converter.convert(left_graph)
|
||||
|
||||
@@ -300,6 +303,7 @@ def test_mlir_converter_dot_vector_and_constant(default_compilation_configuratio
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(2,)),) for _ in range(10)],
|
||||
default_compilation_configuration,
|
||||
)
|
||||
prepare_op_graph_for_mlir(right_graph)
|
||||
right_converter = NPMLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
right_mlir = right_converter.convert(right_graph)
|
||||
|
||||
|
||||
@@ -13,6 +13,8 @@ from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import tracing
|
||||
from concrete.numpy.compile import compile_numpy_function, compile_numpy_function_into_op_graph
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
|
||||
def no_fuse_unhandled(x, y):
|
||||
"""No fuse unhandled"""
|
||||
@@ -746,7 +748,21 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
|
||||
"%0 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"%1 = x # EncryptedScalar<Integer<unsigned, 3 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"%2 = Sub(%0, %1) # EncryptedScalar<Integer<signed, 4 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ signed integer outputs aren't supported\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar unsigned integer outputs are supported\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"return(%2)\n"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x + (-1),
|
||||
{"x": EncryptedScalar(Integer(4, is_signed=False))},
|
||||
[(i,) for i in range(1, 2 ** 4)],
|
||||
(
|
||||
"function you are trying to compile isn't supported for MLIR lowering\n"
|
||||
"\n"
|
||||
"%0 = x # EncryptedScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"%1 = Constant(-1) # ClearScalar<Integer<signed, 2 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer constants are supported\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"%2 = Add(%0, %1) # EncryptedScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"return(%2)\n"
|
||||
),
|
||||
),
|
||||
@@ -760,7 +776,79 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
|
||||
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"%1 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"%2 = Add(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ non scalar outputs aren't supported\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar addition is supported\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"return(%2)\n"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x + 1,
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2, 2))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(2, 2)),) for i in range(10)],
|
||||
(
|
||||
"function you are trying to compile isn't supported for MLIR lowering\n"
|
||||
"\n"
|
||||
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"%1 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"%2 = Add(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar addition is supported\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"return(%2)\n"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x * 1,
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2, 2))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(2, 2)),) for i in range(10)],
|
||||
(
|
||||
"function you are trying to compile isn't supported for MLIR lowering\n"
|
||||
"\n"
|
||||
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"%1 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"%2 = Mul(%0, %1) # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar multiplication is supported\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"return(%2)\n"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: 127 - x,
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2, 2))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(2, 2)),) for i in range(10)],
|
||||
(
|
||||
"function you are trying to compile isn't supported for MLIR lowering\n"
|
||||
"\n"
|
||||
"%0 = Constant(127) # ClearScalar<Integer<unsigned, 7 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"%1 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"%2 = Sub(%0, %1) # EncryptedTensor<Integer<unsigned, 7 bits>, shape=(2, 2)>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar subtraction is supported\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"return(%2)\n"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: numpy.dot(x, y), # pylint: disable=unnecessary-lambda
|
||||
{
|
||||
"x": EncryptedTensor(Integer(2, is_signed=True), shape=(1,)),
|
||||
"y": EncryptedTensor(Integer(2, is_signed=True), shape=(1,)),
|
||||
},
|
||||
[
|
||||
(numpy.array([-1]), numpy.array([-1])),
|
||||
(numpy.array([-1]), numpy.array([0])),
|
||||
(numpy.array([0]), numpy.array([-1])),
|
||||
(numpy.array([0]), numpy.array([0])),
|
||||
(numpy.array([1]), numpy.array([1])),
|
||||
(numpy.array([1]), numpy.array([0])),
|
||||
(numpy.array([0]), numpy.array([1])),
|
||||
(numpy.array([0]), numpy.array([0])),
|
||||
(numpy.array([-2]), numpy.array([-2])),
|
||||
(numpy.array([-2]), numpy.array([1])),
|
||||
],
|
||||
(
|
||||
"function you are trying to compile isn't supported for MLIR lowering\n"
|
||||
"\n"
|
||||
"%0 = x # EncryptedTensor<Integer<signed, 2 bits>, shape=(1,)>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"%1 = y # EncryptedTensor<Integer<signed, 2 bits>, shape=(1,)>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"%2 = Dot(%0, %1) # EncryptedScalar<Integer<signed, 4 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer dot product is supported\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"return(%2)\n"
|
||||
),
|
||||
),
|
||||
@@ -769,15 +857,58 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
|
||||
def test_fail_compile(function, parameters, inputset, match, default_compilation_configuration):
|
||||
"""Test function compile_numpy_function_into_op_graph for a program with signed values"""
|
||||
|
||||
try:
|
||||
compile_numpy_function(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
except RuntimeError as error:
|
||||
assert str(error) == match
|
||||
with pytest.raises(RuntimeError):
|
||||
try:
|
||||
compile_numpy_function(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
except RuntimeError as error:
|
||||
assert str(error) == match
|
||||
raise
|
||||
|
||||
|
||||
def test_fail_with_intermediate_signed_values(default_compilation_configuration):
|
||||
"""Test function with failing compilation due to intermediate signed integers."""
|
||||
|
||||
def function(x, y):
|
||||
z = numpy.abs(10 * numpy.negative(x))
|
||||
z = z.astype(numpy.int32) + y
|
||||
return z
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
try:
|
||||
compile_numpy_function(
|
||||
function,
|
||||
{
|
||||
"x": EncryptedScalar(Integer(2, is_signed=False)),
|
||||
"y": EncryptedScalar(Integer(2, is_signed=False)),
|
||||
},
|
||||
[(i, j) for i in range(2 ** 2) for j in range(2 ** 2)],
|
||||
default_compilation_configuration,
|
||||
show_mlir=True,
|
||||
)
|
||||
except RuntimeError as error:
|
||||
match = (
|
||||
"function you are trying to compile isn't supported for MLIR lowering\n"
|
||||
"\n"
|
||||
"%0 = y # EncryptedScalar<Integer<unsigned, 2 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"%1 = Constant(10) # ClearScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"%2 = x # EncryptedScalar<Integer<unsigned, 2 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"%3 = np.negative(%2) # EncryptedScalar<Integer<signed, 3 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"%4 = Mul(%3, %1) # EncryptedScalar<Integer<signed, 6 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"%5 = np.absolute(%4) # EncryptedScalar<Integer<unsigned, 5 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer scalar lookup tables are supported\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"%6 = astype(int32)(%5) # EncryptedScalar<Integer<unsigned, 5 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"%7 = Add(%6, %0) # EncryptedScalar<Integer<unsigned, 6 bits>>\n" # noqa: E501 # pylint: disable=line-too-long
|
||||
"return(%7)\n"
|
||||
)
|
||||
assert str(error) == match
|
||||
raise
|
||||
|
||||
|
||||
def test_small_inputset_no_fail():
|
||||
@@ -818,9 +949,9 @@ def test_small_inputset_treat_warnings_as_errors():
|
||||
# Remark that, when you do the dot of tensors of 4 values between 0 and 3,
|
||||
# you can get a maximal value of 4*3*3 = 36, ie something on 6 bits
|
||||
"%0 = x "
|
||||
"# EncryptedTensor<Integer<unsigned, 6 bits>, shape=(4,)>"
|
||||
"# EncryptedTensor<Integer<unsigned, 2 bits>, shape=(4,)>"
|
||||
"\n%1 = y "
|
||||
"# EncryptedTensor<Integer<unsigned, 6 bits>, shape=(4,)>"
|
||||
"# EncryptedTensor<Integer<unsigned, 2 bits>, shape=(4,)>"
|
||||
"\n%2 = Dot(%0, %1) "
|
||||
"# EncryptedScalar<Integer<unsigned, 6 bits>>"
|
||||
"\nreturn(%2)\n",
|
||||
|
||||
Reference in New Issue
Block a user