fix: don't allow signed clear inputs as tlu cannot be applied to them, add more signed input tests

This commit is contained in:
Umut
2022-06-01 14:49:58 +02:00
parent a6b09ddf09
commit 1e86c3b1e4
4 changed files with 40 additions and 0 deletions

View File

@@ -69,6 +69,8 @@ class GraphConverter:
assert_that(inputs[0] == output)
if not isinstance(output.dtype, Integer):
return "only integer inputs are supported"
if output.dtype.is_signed and output.is_clear:
return "only encrypted signed integer inputs are supported"
else:
assert_that(node.operation == Operation.Generic)
@@ -345,6 +347,8 @@ class GraphConverter:
input_dtype = cast(Integer, input_value.dtype)
if input_dtype.is_signed:
assert_that(input_value.is_encrypted)
n = input_dtype.bit_width
lut_range = np.arange(2**n)

View File

@@ -49,6 +49,15 @@ import concrete.numpy as cnp
{
"x": {"range": [0, 85], "status": "encrypted", "shape": (2, 3)},
},
{
"x": {"range": [-50, 10], "status": "encrypted"},
},
{
"x": {"range": [-50, 10], "status": "encrypted", "shape": (3,)},
},
{
"x": {"range": [-50, 10], "status": "encrypted", "shape": (2, 3)},
},
],
)
def test_constant_add(function, parameters, helpers):
@@ -140,6 +149,10 @@ def test_constant_add(function, parameters, helpers):
"x": {"range": [0, 60], "status": "encrypted", "shape": (2, 1)},
"y": {"range": [0, 60], "status": "encrypted", "shape": (3,)},
},
{
"x": {"range": [-30, 30], "status": "encrypted", "shape": (3, 2)},
"y": {"range": [-30, 30], "status": "encrypted", "shape": (3, 2)},
},
],
)
def test_add(function, parameters, helpers):

View File

@@ -578,6 +578,14 @@ def deterministic_unary_function(x):
},
id="(2.5 * round(np.sqrt(x), decimals=4)).astype(np.int64)",
),
pytest.param(
lambda x, y: cnp.LookupTable(list(range(32)))[x + y],
{
"x": {"status": "encrypted", "range": [-10, 10]},
"y": {"status": "encrypted", "range": [-10, 10]},
},
id="cnp.LookupTable(list(range(32)))[x + y]",
),
],
)
def test_others(function, parameters, helpers):

View File

@@ -33,6 +33,21 @@ return (%2, %3)
""", # noqa: E501
),
pytest.param(
lambda x: x,
{"x": "clear"},
range(-10, 10),
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # ClearScalar<int5>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted signed integer inputs are supported
return %0
""", # noqa: E501
),
pytest.param(
lambda x: x * 1.5,
{"x": "encrypted"},