mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
committed by
Benoit Chevallier
parent
624143106f
commit
9459675cfb
@@ -83,6 +83,19 @@ def value_is_scalar(value_to_check: BaseValue) -> bool:
|
||||
return isinstance(value_to_check, TensorValue) and value_to_check.is_scalar
|
||||
|
||||
|
||||
def value_is_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is of type Integer.
|
||||
|
||||
Args:
|
||||
value_to_check (BaseValue): The value to check
|
||||
|
||||
Returns:
|
||||
bool: True if the passed value_to_check is of type Integer
|
||||
"""
|
||||
|
||||
return isinstance(value_to_check.dtype, INTEGER_TYPES)
|
||||
|
||||
|
||||
def value_is_unsigned_integer(value_to_check: BaseValue) -> bool:
|
||||
"""Check that a value is of type Integer and is unsigned.
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from zamalang.dialects import hlfhe
|
||||
from ..data_types.dtypes_helpers import (
|
||||
value_is_clear_scalar_integer,
|
||||
value_is_clear_tensor_integer,
|
||||
value_is_encrypted_scalar_integer,
|
||||
value_is_encrypted_scalar_unsigned_integer,
|
||||
value_is_encrypted_tensor_integer,
|
||||
)
|
||||
@@ -30,18 +31,18 @@ def add(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
|
||||
"""Convert an addition intermediate node."""
|
||||
assert_true(len(node.inputs) == 2, "addition should have two inputs")
|
||||
assert_true(len(node.outputs) == 1, "addition should have a single output")
|
||||
if value_is_encrypted_scalar_unsigned_integer(node.inputs[0]) and value_is_clear_scalar_integer(
|
||||
if value_is_encrypted_scalar_integer(node.inputs[0]) and value_is_clear_scalar_integer(
|
||||
node.inputs[1]
|
||||
):
|
||||
return _add_eint_int(node, preds, ir_to_mlir_node, ctx)
|
||||
if value_is_encrypted_scalar_unsigned_integer(node.inputs[1]) and value_is_clear_scalar_integer(
|
||||
if value_is_encrypted_scalar_integer(node.inputs[1]) and value_is_clear_scalar_integer(
|
||||
node.inputs[0]
|
||||
):
|
||||
# flip lhs and rhs
|
||||
return _add_eint_int(node, preds[::-1], ir_to_mlir_node, ctx)
|
||||
if value_is_encrypted_scalar_unsigned_integer(
|
||||
node.inputs[0]
|
||||
) and value_is_encrypted_scalar_unsigned_integer(node.inputs[1]):
|
||||
if value_is_encrypted_scalar_integer(node.inputs[0]) and value_is_encrypted_scalar_integer(
|
||||
node.inputs[1]
|
||||
):
|
||||
return _add_eint_eint(node, preds, ir_to_mlir_node, ctx)
|
||||
raise TypeError(
|
||||
f"Don't support addition between {str(node.inputs[0])} and {str(node.inputs[1])}"
|
||||
@@ -74,7 +75,7 @@ def sub(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
|
||||
"""Convert a subtraction intermediate node."""
|
||||
assert_true(len(node.inputs) == 2, "subtraction should have two inputs")
|
||||
assert_true(len(node.outputs) == 1, "subtraction should have a single output")
|
||||
if value_is_clear_scalar_integer(node.inputs[0]) and value_is_encrypted_scalar_unsigned_integer(
|
||||
if value_is_clear_scalar_integer(node.inputs[0]) and value_is_encrypted_scalar_integer(
|
||||
node.inputs[1]
|
||||
):
|
||||
return _sub_int_eint(node, preds, ir_to_mlir_node, ctx)
|
||||
@@ -98,11 +99,11 @@ def mul(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
|
||||
"""Convert a multiplication intermediate node."""
|
||||
assert_true(len(node.inputs) == 2, "multiplication should have two inputs")
|
||||
assert_true(len(node.outputs) == 1, "multiplication should have a single output")
|
||||
if value_is_encrypted_scalar_unsigned_integer(node.inputs[0]) and value_is_clear_scalar_integer(
|
||||
if value_is_encrypted_scalar_integer(node.inputs[0]) and value_is_clear_scalar_integer(
|
||||
node.inputs[1]
|
||||
):
|
||||
return _mul_eint_int(node, preds, ir_to_mlir_node, ctx)
|
||||
if value_is_encrypted_scalar_unsigned_integer(node.inputs[1]) and value_is_clear_scalar_integer(
|
||||
if value_is_encrypted_scalar_integer(node.inputs[1]) and value_is_clear_scalar_integer(
|
||||
node.inputs[0]
|
||||
):
|
||||
# flip lhs and rhs
|
||||
@@ -131,8 +132,6 @@ def constant(node, _preds, _ir_to_mlir_node, ctx, _additional_conversion_info=No
|
||||
value = cast(TensorValue, value)
|
||||
|
||||
dtype = cast(Integer, value.dtype)
|
||||
if dtype.is_signed:
|
||||
raise TypeError("Don't support signed constant integer")
|
||||
data = node.constant_data
|
||||
|
||||
int_type = IntegerType.get_signless(dtype.bit_width, context=ctx)
|
||||
@@ -142,8 +141,6 @@ def constant(node, _preds, _ir_to_mlir_node, ctx, _additional_conversion_info=No
|
||||
value = cast(TensorValue, value)
|
||||
|
||||
dtype = cast(Integer, value.dtype)
|
||||
if dtype.is_signed:
|
||||
raise TypeError("Don't support signed constant integer tensor")
|
||||
data = node.constant_data
|
||||
|
||||
int_type = IntegerType.get_signless(dtype.bit_width, context=ctx)
|
||||
|
||||
@@ -7,6 +7,7 @@ from ..data_types.dtypes_helpers import (
|
||||
value_is_clear_tensor_integer,
|
||||
value_is_encrypted_scalar_integer,
|
||||
value_is_encrypted_tensor_integer,
|
||||
value_is_integer,
|
||||
value_is_scalar,
|
||||
value_is_unsigned_integer,
|
||||
)
|
||||
@@ -58,8 +59,10 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool)
|
||||
|
||||
elif isinstance(node, intermediate.Constant): # constraints for constants
|
||||
assert_true(len(outputs) == 1)
|
||||
if not value_is_unsigned_integer(outputs[0]):
|
||||
return "only unsigned integer constants are supported"
|
||||
# We currently can't fail on the following assert, but let it for possible changes in the
|
||||
# future
|
||||
if not value_is_integer(outputs[0]):
|
||||
return "only integer constants are supported" # pragma: no cover
|
||||
|
||||
elif isinstance(node, intermediate.UnivariateFunction): # constraints for univariate functions
|
||||
assert_true(len(inputs) == 1)
|
||||
@@ -84,8 +87,10 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool)
|
||||
return "only scalar unsigned integer outputs are supported"
|
||||
else:
|
||||
for out in outputs:
|
||||
if not value_is_unsigned_integer(out):
|
||||
return "only unsigned integer intermediates are supported"
|
||||
# We currently can't fail on the following assert, but let it for possible changes in
|
||||
# the future
|
||||
if not value_is_integer(out):
|
||||
return "only integer intermediates are supported" # pragma: no cover
|
||||
|
||||
# pylint: enable=too-many-branches,too-many-return-statements
|
||||
|
||||
|
||||
@@ -37,14 +37,6 @@ def test_fail_non_integer_const():
|
||||
constant(MockNode(outputs=[ClearTensor(Float(32), shape=(2,))]), None, None, None)
|
||||
|
||||
|
||||
def test_fail_signed_integer_const():
|
||||
"""Test failing constant converter with non-integer"""
|
||||
with pytest.raises(TypeError, match=r"Don't support signed constant integer"):
|
||||
constant(MockNode(outputs=[ClearScalar(Integer(8, True))]), None, None, None)
|
||||
with pytest.raises(TypeError, match=r"Don't support signed constant integer tensor"):
|
||||
constant(MockNode(outputs=[ClearTensor(Integer(8, True), shape=(2,))]), None, None, None)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_node",
|
||||
[
|
||||
|
||||
@@ -482,6 +482,15 @@ def test_compile_function_multiple_outputs(
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
[
|
||||
pytest.param(lambda x: (-27) + 4 * (x + 8), ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x: x + (-33), ((40, 60),), ["x"]),
|
||||
pytest.param(lambda x: 17 - (0 - x), ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x: 42 + x * (-3), ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x: 43 + (-4) * x, ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x: 3 - (-5) * x, ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x: (-2) * (-5) * x, ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x: (-2) * x * (-5), ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x, y: 40 - (-3 * x) + (-2 * y), ((0, 20), (0, 20)), ["x", "y"]),
|
||||
pytest.param(lambda x: x + numpy.int32(42), ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x: x + 64, ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x: x * 3, ((0, 40),), ["x"]),
|
||||
@@ -747,20 +756,6 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
|
||||
"return(%2)\n"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x + (-1),
|
||||
{"x": EncryptedScalar(Integer(4, is_signed=False))},
|
||||
[(i,) for i in range(1, 2 ** 4)],
|
||||
(
|
||||
"function you are trying to compile isn't supported for MLIR lowering\n"
|
||||
"\n"
|
||||
"%0 = x # EncryptedScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501
|
||||
"%1 = Constant(-1) # ClearScalar<Integer<signed, 2 bits>>\n" # noqa: E501
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer constants are supported\n" # noqa: E501
|
||||
"%2 = Add(%0, %1) # EncryptedScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501
|
||||
"return(%2)\n"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x + 1,
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2, 2))},
|
||||
@@ -869,16 +864,16 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
|
||||
"function you are trying to compile isn't supported for MLIR lowering\n\n"
|
||||
"%0 = x # EncryptedScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501
|
||||
"%1 = Constant(2.8) # ClearScalar<Float<64 bits>>\n"
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer constants are supported\n" # noqa: E501
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported\n" # noqa: E501
|
||||
"%2 = y # EncryptedScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501
|
||||
"%3 = Constant(9.3) # ClearScalar<Float<64 bits>>\n"
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer constants are supported\n" # noqa: E501
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported\n" # noqa: E501
|
||||
"%4 = Add(%0, %1) # EncryptedScalar<Float<64 bits>>\n" # noqa: E501
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer intermediates are supported\n" # noqa: E501
|
||||
"%5 = Add(%2, %3) # EncryptedScalar<Float<64 bits>>\n" # noqa: E501
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer intermediates are supported\n" # noqa: E501
|
||||
"%6 = Add(%4, %5) # EncryptedScalar<Float<64 bits>>\n" # noqa: E501
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer intermediates are supported\n" # noqa: E501
|
||||
"%7 = astype(int32)(%6) # EncryptedScalar<Integer<unsigned, 5 bits>>\n" # noqa: E501
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer scalar lookup tables are supported\n" # noqa: E501
|
||||
"return(%7)\n"
|
||||
@@ -930,9 +925,7 @@ def test_fail_with_intermediate_signed_values(default_compilation_configuration)
|
||||
"%1 = Constant(10) # ClearScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501
|
||||
"%2 = x # EncryptedScalar<Integer<unsigned, 2 bits>>\n" # noqa: E501
|
||||
"%3 = np.negative(%2) # EncryptedScalar<Integer<signed, 3 bits>>\n" # noqa: E501
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501
|
||||
"%4 = Mul(%3, %1) # EncryptedScalar<Integer<signed, 6 bits>>\n" # noqa: E501
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501
|
||||
"%5 = np.absolute(%4) # EncryptedScalar<Integer<unsigned, 5 bits>>\n" # noqa: E501
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer scalar lookup tables are supported\n" # noqa: E501
|
||||
"%6 = astype(int32)(%5) # EncryptedScalar<Integer<unsigned, 5 bits>>\n" # noqa: E501
|
||||
@@ -1103,3 +1096,38 @@ def test_compile_too_high_bitwidth(default_compilation_configuration):
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
|
||||
def test_failure_for_signed_output(default_compilation_configuration):
|
||||
"""Test that we don't accept signed output"""
|
||||
function = lambda x: x + (-3) # pylint: disable=unnecessary-lambda # noqa: E731
|
||||
input_ranges = ((0, 10),)
|
||||
list_of_arg_names = ["x"]
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {
|
||||
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
|
||||
}
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
compile_numpy_function(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
assert (
|
||||
str(excinfo.value)
|
||||
== "function you are trying to compile isn't supported for MLIR lowering\n\n" # noqa: E501
|
||||
"%0 = x # EncryptedScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501
|
||||
"%1 = Constant(-3) # ClearScalar<Integer<signed, 3 bits>>\n" # noqa: E501
|
||||
"%2 = Add(%0, %1) # EncryptedScalar<Integer<signed, 4 bits>>\n" # noqa: E501
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar unsigned integer outputs are supported\n" # noqa: E501
|
||||
"return(%2)\n"
|
||||
)
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
Reference in New Issue
Block a user