mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-10 04:35:03 -05:00
fix: don't allow signed clear inputs as tlu cannot be applied to them, add more signed input tests
This commit is contained in:
@@ -69,6 +69,8 @@ class GraphConverter:
|
||||
assert_that(inputs[0] == output)
|
||||
if not isinstance(output.dtype, Integer):
|
||||
return "only integer inputs are supported"
|
||||
if output.dtype.is_signed and output.is_clear:
|
||||
return "only encrypted signed integer inputs are supported"
|
||||
|
||||
else:
|
||||
assert_that(node.operation == Operation.Generic)
|
||||
@@ -345,6 +347,8 @@ class GraphConverter:
|
||||
input_dtype = cast(Integer, input_value.dtype)
|
||||
|
||||
if input_dtype.is_signed:
|
||||
assert_that(input_value.is_encrypted)
|
||||
|
||||
n = input_dtype.bit_width
|
||||
lut_range = np.arange(2**n)
|
||||
|
||||
|
||||
@@ -49,6 +49,15 @@ import concrete.numpy as cnp
|
||||
{
|
||||
"x": {"range": [0, 85], "status": "encrypted", "shape": (2, 3)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [-50, 10], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [-50, 10], "status": "encrypted", "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [-50, 10], "status": "encrypted", "shape": (2, 3)},
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_constant_add(function, parameters, helpers):
|
||||
@@ -140,6 +149,10 @@ def test_constant_add(function, parameters, helpers):
|
||||
"x": {"range": [0, 60], "status": "encrypted", "shape": (2, 1)},
|
||||
"y": {"range": [0, 60], "status": "encrypted", "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [-30, 30], "status": "encrypted", "shape": (3, 2)},
|
||||
"y": {"range": [-30, 30], "status": "encrypted", "shape": (3, 2)},
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_add(function, parameters, helpers):
|
||||
|
||||
@@ -578,6 +578,14 @@ def deterministic_unary_function(x):
|
||||
},
|
||||
id="(2.5 * round(np.sqrt(x), decimals=4)).astype(np.int64)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: cnp.LookupTable(list(range(32)))[x + y],
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [-10, 10]},
|
||||
"y": {"status": "encrypted", "range": [-10, 10]},
|
||||
},
|
||||
id="cnp.LookupTable(list(range(32)))[x + y]",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_others(function, parameters, helpers):
|
||||
|
||||
@@ -33,6 +33,21 @@ return (%2, %3)
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x,
|
||||
{"x": "clear"},
|
||||
range(-10, 10),
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearScalar<int5>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted signed integer inputs are supported
|
||||
return %0
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x * 1.5,
|
||||
{"x": "encrypted"},
|
||||
|
||||
Reference in New Issue
Block a user