feat: manage signed constants

closes #688
closes #612
This commit is contained in:
Benoit Chevallier-Mames
2021-10-22 14:43:34 +02:00
committed by Benoit Chevallier
parent 624143106f
commit 9459675cfb
5 changed files with 80 additions and 45 deletions

View File

@@ -83,6 +83,19 @@ def value_is_scalar(value_to_check: BaseValue) -> bool:
return isinstance(value_to_check, TensorValue) and value_to_check.is_scalar
def value_is_integer(value_to_check: BaseValue) -> bool:
"""Check that a value is of type Integer.
Args:
value_to_check (BaseValue): The value to check
Returns:
bool: True if the passed value_to_check is of type Integer
"""
return isinstance(value_to_check.dtype, INTEGER_TYPES)
def value_is_unsigned_integer(value_to_check: BaseValue) -> bool:
"""Check that a value is of type Integer and is unsigned.

View File

@@ -17,6 +17,7 @@ from zamalang.dialects import hlfhe
from ..data_types.dtypes_helpers import (
value_is_clear_scalar_integer,
value_is_clear_tensor_integer,
value_is_encrypted_scalar_integer,
value_is_encrypted_scalar_unsigned_integer,
value_is_encrypted_tensor_integer,
)
@@ -30,18 +31,18 @@ def add(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
"""Convert an addition intermediate node."""
assert_true(len(node.inputs) == 2, "addition should have two inputs")
assert_true(len(node.outputs) == 1, "addition should have a single output")
if value_is_encrypted_scalar_unsigned_integer(node.inputs[0]) and value_is_clear_scalar_integer(
if value_is_encrypted_scalar_integer(node.inputs[0]) and value_is_clear_scalar_integer(
node.inputs[1]
):
return _add_eint_int(node, preds, ir_to_mlir_node, ctx)
if value_is_encrypted_scalar_unsigned_integer(node.inputs[1]) and value_is_clear_scalar_integer(
if value_is_encrypted_scalar_integer(node.inputs[1]) and value_is_clear_scalar_integer(
node.inputs[0]
):
# flip lhs and rhs
return _add_eint_int(node, preds[::-1], ir_to_mlir_node, ctx)
if value_is_encrypted_scalar_unsigned_integer(
node.inputs[0]
) and value_is_encrypted_scalar_unsigned_integer(node.inputs[1]):
if value_is_encrypted_scalar_integer(node.inputs[0]) and value_is_encrypted_scalar_integer(
node.inputs[1]
):
return _add_eint_eint(node, preds, ir_to_mlir_node, ctx)
raise TypeError(
f"Don't support addition between {str(node.inputs[0])} and {str(node.inputs[1])}"
@@ -74,7 +75,7 @@ def sub(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
"""Convert a subtraction intermediate node."""
assert_true(len(node.inputs) == 2, "subtraction should have two inputs")
assert_true(len(node.outputs) == 1, "subtraction should have a single output")
if value_is_clear_scalar_integer(node.inputs[0]) and value_is_encrypted_scalar_unsigned_integer(
if value_is_clear_scalar_integer(node.inputs[0]) and value_is_encrypted_scalar_integer(
node.inputs[1]
):
return _sub_int_eint(node, preds, ir_to_mlir_node, ctx)
@@ -98,11 +99,11 @@ def mul(node, preds, ir_to_mlir_node, ctx, _additional_conversion_info=None):
"""Convert a multiplication intermediate node."""
assert_true(len(node.inputs) == 2, "multiplication should have two inputs")
assert_true(len(node.outputs) == 1, "multiplication should have a single output")
if value_is_encrypted_scalar_unsigned_integer(node.inputs[0]) and value_is_clear_scalar_integer(
if value_is_encrypted_scalar_integer(node.inputs[0]) and value_is_clear_scalar_integer(
node.inputs[1]
):
return _mul_eint_int(node, preds, ir_to_mlir_node, ctx)
if value_is_encrypted_scalar_unsigned_integer(node.inputs[1]) and value_is_clear_scalar_integer(
if value_is_encrypted_scalar_integer(node.inputs[1]) and value_is_clear_scalar_integer(
node.inputs[0]
):
# flip lhs and rhs
@@ -131,8 +132,6 @@ def constant(node, _preds, _ir_to_mlir_node, ctx, _additional_conversion_info=No
value = cast(TensorValue, value)
dtype = cast(Integer, value.dtype)
if dtype.is_signed:
raise TypeError("Don't support signed constant integer")
data = node.constant_data
int_type = IntegerType.get_signless(dtype.bit_width, context=ctx)
@@ -142,8 +141,6 @@ def constant(node, _preds, _ir_to_mlir_node, ctx, _additional_conversion_info=No
value = cast(TensorValue, value)
dtype = cast(Integer, value.dtype)
if dtype.is_signed:
raise TypeError("Don't support signed constant integer tensor")
data = node.constant_data
int_type = IntegerType.get_signless(dtype.bit_width, context=ctx)

View File

@@ -7,6 +7,7 @@ from ..data_types.dtypes_helpers import (
value_is_clear_tensor_integer,
value_is_encrypted_scalar_integer,
value_is_encrypted_tensor_integer,
value_is_integer,
value_is_scalar,
value_is_unsigned_integer,
)
@@ -58,8 +59,10 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool)
elif isinstance(node, intermediate.Constant): # constraints for constants
assert_true(len(outputs) == 1)
if not value_is_unsigned_integer(outputs[0]):
return "only unsigned integer constants are supported"
# We currently can't fail on the following assert, but let it for possible changes in the
# future
if not value_is_integer(outputs[0]):
return "only integer constants are supported" # pragma: no cover
elif isinstance(node, intermediate.UnivariateFunction): # constraints for univariate functions
assert_true(len(inputs) == 1)
@@ -84,8 +87,10 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool)
return "only scalar unsigned integer outputs are supported"
else:
for out in outputs:
if not value_is_unsigned_integer(out):
return "only unsigned integer intermediates are supported"
# We currently can't fail on the following assert, but let it for possible changes in
# the future
if not value_is_integer(out):
return "only integer intermediates are supported" # pragma: no cover
# pylint: enable=too-many-branches,too-many-return-statements

View File

@@ -37,14 +37,6 @@ def test_fail_non_integer_const():
constant(MockNode(outputs=[ClearTensor(Float(32), shape=(2,))]), None, None, None)
def test_fail_signed_integer_const():
"""Test failing constant converter with non-integer"""
with pytest.raises(TypeError, match=r"Don't support signed constant integer"):
constant(MockNode(outputs=[ClearScalar(Integer(8, True))]), None, None, None)
with pytest.raises(TypeError, match=r"Don't support signed constant integer tensor"):
constant(MockNode(outputs=[ClearTensor(Integer(8, True), shape=(2,))]), None, None, None)
@pytest.mark.parametrize(
"input_node",
[

View File

@@ -482,6 +482,15 @@ def test_compile_function_multiple_outputs(
@pytest.mark.parametrize(
"function,input_ranges,list_of_arg_names",
[
pytest.param(lambda x: (-27) + 4 * (x + 8), ((0, 10),), ["x"]),
pytest.param(lambda x: x + (-33), ((40, 60),), ["x"]),
pytest.param(lambda x: 17 - (0 - x), ((0, 10),), ["x"]),
pytest.param(lambda x: 42 + x * (-3), ((0, 10),), ["x"]),
pytest.param(lambda x: 43 + (-4) * x, ((0, 10),), ["x"]),
pytest.param(lambda x: 3 - (-5) * x, ((0, 10),), ["x"]),
pytest.param(lambda x: (-2) * (-5) * x, ((0, 10),), ["x"]),
pytest.param(lambda x: (-2) * x * (-5), ((0, 10),), ["x"]),
pytest.param(lambda x, y: 40 - (-3 * x) + (-2 * y), ((0, 20), (0, 20)), ["x", "y"]),
pytest.param(lambda x: x + numpy.int32(42), ((0, 10),), ["x"]),
pytest.param(lambda x: x + 64, ((0, 10),), ["x"]),
pytest.param(lambda x: x * 3, ((0, 40),), ["x"]),
@@ -747,20 +756,6 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
"return(%2)\n"
),
),
pytest.param(
lambda x: x + (-1),
{"x": EncryptedScalar(Integer(4, is_signed=False))},
[(i,) for i in range(1, 2 ** 4)],
(
"function you are trying to compile isn't supported for MLIR lowering\n"
"\n"
"%0 = x # EncryptedScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501
"%1 = Constant(-1) # ClearScalar<Integer<signed, 2 bits>>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer constants are supported\n" # noqa: E501
"%2 = Add(%0, %1) # EncryptedScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501
"return(%2)\n"
),
),
pytest.param(
lambda x: x + 1,
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2, 2))},
@@ -869,16 +864,16 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
"function you are trying to compile isn't supported for MLIR lowering\n\n"
"%0 = x # EncryptedScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501
"%1 = Constant(2.8) # ClearScalar<Float<64 bits>>\n"
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer constants are supported\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported\n" # noqa: E501
"%2 = y # EncryptedScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501
"%3 = Constant(9.3) # ClearScalar<Float<64 bits>>\n"
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer constants are supported\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported\n" # noqa: E501
"%4 = Add(%0, %1) # EncryptedScalar<Float<64 bits>>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer intermediates are supported\n" # noqa: E501
"%5 = Add(%2, %3) # EncryptedScalar<Float<64 bits>>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer intermediates are supported\n" # noqa: E501
"%6 = Add(%4, %5) # EncryptedScalar<Float<64 bits>>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer intermediates are supported\n" # noqa: E501
"%7 = astype(int32)(%6) # EncryptedScalar<Integer<unsigned, 5 bits>>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer scalar lookup tables are supported\n" # noqa: E501
"return(%7)\n"
@@ -930,9 +925,7 @@ def test_fail_with_intermediate_signed_values(default_compilation_configuration)
"%1 = Constant(10) # ClearScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501
"%2 = x # EncryptedScalar<Integer<unsigned, 2 bits>>\n" # noqa: E501
"%3 = np.negative(%2) # EncryptedScalar<Integer<signed, 3 bits>>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501
"%4 = Mul(%3, %1) # EncryptedScalar<Integer<signed, 6 bits>>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer intermediates are supported\n" # noqa: E501
"%5 = np.absolute(%4) # EncryptedScalar<Integer<unsigned, 5 bits>>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer scalar lookup tables are supported\n" # noqa: E501
"%6 = astype(int32)(%5) # EncryptedScalar<Integer<unsigned, 5 bits>>\n" # noqa: E501
@@ -1103,3 +1096,38 @@ def test_compile_too_high_bitwidth(default_compilation_configuration):
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
default_compilation_configuration,
)
def test_failure_for_signed_output(default_compilation_configuration):
"""Test that we don't accept signed output"""
function = lambda x: x + (-3) # pylint: disable=unnecessary-lambda # noqa: E731
input_ranges = ((0, 10),)
list_of_arg_names = ["x"]
def data_gen(args):
for prod in itertools.product(*args):
yield prod
function_parameters = {
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
}
with pytest.raises(RuntimeError) as excinfo:
compile_numpy_function(
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
default_compilation_configuration,
)
# pylint: disable=line-too-long
assert (
str(excinfo.value)
== "function you are trying to compile isn't supported for MLIR lowering\n\n" # noqa: E501
"%0 = x # EncryptedScalar<Integer<unsigned, 4 bits>>\n" # noqa: E501
"%1 = Constant(-3) # ClearScalar<Integer<signed, 3 bits>>\n" # noqa: E501
"%2 = Add(%0, %1) # EncryptedScalar<Integer<signed, 4 bits>>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar unsigned integer outputs are supported\n" # noqa: E501
"return(%2)\n"
)
# pylint: enable=line-too-long