mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: support unsigned levelled operations on large bit-widths
This commit is contained in:
@@ -16,5 +16,5 @@ from .compilation import (
|
||||
compiler,
|
||||
)
|
||||
from .extensions import LookupTable, array, one, ones, univariate, zero, zeros
|
||||
from .mlir.utils import MAXIMUM_BIT_WIDTH
|
||||
from .mlir.utils import MAXIMUM_TLU_BIT_WIDTH
|
||||
from .representation import Graph
|
||||
|
||||
@@ -165,7 +165,7 @@ class Client:
|
||||
arg >= 0,
|
||||
arg,
|
||||
2 * (expected_max + 1) + arg,
|
||||
).astype(np.uint8)
|
||||
).astype(np.uint64)
|
||||
|
||||
if not is_valid:
|
||||
actual_value = Value.of(arg, is_encrypted=is_encrypted)
|
||||
@@ -220,10 +220,10 @@ class Client:
|
||||
else:
|
||||
result = result.astype(np.longlong) # to prevent overflows in numpy
|
||||
sanititzed_result = np.where(result < (2 ** (n - 1)), result, result - (2**n))
|
||||
sanitized_results.append(sanititzed_result.astype(np.int8))
|
||||
sanitized_results.append(sanititzed_result.astype(np.int64))
|
||||
else:
|
||||
sanitized_results.append(
|
||||
result if isinstance(result, int) else result.astype(np.uint8)
|
||||
result if isinstance(result, int) else result.astype(np.uint64)
|
||||
)
|
||||
|
||||
return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results)
|
||||
|
||||
@@ -28,7 +28,7 @@ from ..internal.utils import assert_that
|
||||
from ..representation import Graph, Node, Operation
|
||||
from ..values import ClearScalar, EncryptedScalar
|
||||
from .node_converter import NodeConverter
|
||||
from .utils import MAXIMUM_BIT_WIDTH
|
||||
from .utils import MAXIMUM_TLU_BIT_WIDTH
|
||||
|
||||
# pylint: enable=no-member,no-name-in-module
|
||||
|
||||
@@ -221,6 +221,10 @@ class GraphConverter:
|
||||
offending_nodes: Dict[Node, List[str]] = {}
|
||||
|
||||
max_bit_width = 0
|
||||
|
||||
first_tlu_node = None
|
||||
first_signed_node = None
|
||||
|
||||
for node in graph.graph.nodes:
|
||||
dtype = node.output.dtype
|
||||
assert_that(isinstance(dtype, Integer))
|
||||
@@ -230,10 +234,23 @@ class GraphConverter:
|
||||
)
|
||||
max_bit_width = max(max_bit_width, current_node_bit_width)
|
||||
|
||||
if current_node_bit_width > MAXIMUM_BIT_WIDTH:
|
||||
offending_nodes[node] = [
|
||||
f"only up to {MAXIMUM_BIT_WIDTH}-bit integers are supported"
|
||||
]
|
||||
if node.converted_to_table_lookup and first_tlu_node is None:
|
||||
first_tlu_node = node
|
||||
|
||||
if dtype.is_signed and first_signed_node is None:
|
||||
first_signed_node = node
|
||||
|
||||
if first_tlu_node is not None and max_bit_width > MAXIMUM_TLU_BIT_WIDTH:
|
||||
offending_nodes[first_tlu_node] = [
|
||||
f"table lookups are only supported on circuits with "
|
||||
f"up to {MAXIMUM_TLU_BIT_WIDTH}-bit integers"
|
||||
]
|
||||
|
||||
if first_signed_node is not None and max_bit_width > MAXIMUM_TLU_BIT_WIDTH:
|
||||
offending_nodes[first_signed_node] = [
|
||||
f"signed values are only supported on circuits with "
|
||||
f"up to {MAXIMUM_TLU_BIT_WIDTH}-bit integers"
|
||||
]
|
||||
|
||||
if len(offending_nodes) != 0:
|
||||
raise RuntimeError(
|
||||
|
||||
@@ -13,7 +13,7 @@ from ..dtypes import Integer
|
||||
from ..internal.utils import assert_that
|
||||
from ..representation import Node, Operation
|
||||
|
||||
MAXIMUM_BIT_WIDTH = 8
|
||||
MAXIMUM_TLU_BIT_WIDTH = 8
|
||||
|
||||
|
||||
class HashableNdarray:
|
||||
|
||||
@@ -201,7 +201,7 @@ class Helpers:
|
||||
sample = []
|
||||
|
||||
for description in parameters.values():
|
||||
minimum, maximum = description.get("range", [0, 127])
|
||||
minimum, maximum = description.get("range", [0, (2**16) - 1])
|
||||
|
||||
if "shape" in description:
|
||||
shape = description["shape"]
|
||||
|
||||
@@ -58,6 +58,15 @@ import concrete.numpy as cnp
|
||||
{
|
||||
"x": {"range": [-50, 10], "status": "encrypted", "shape": (2, 3)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 1000000], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 1000000], "status": "encrypted", "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [0, 1000000], "status": "encrypted", "shape": (2, 3)},
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_constant_add(function, parameters, helpers):
|
||||
|
||||
@@ -344,23 +344,6 @@ Function you are trying to compile cannot be converted to MLIR
|
||||
%1 = [3] # ClearTensor<uint2, shape=(1,)>
|
||||
%2 = maximum(%0, %1) # ClearTensor<uint2, shape=(1,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of the operands must be encrypted
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x + 200,
|
||||
{"x": "encrypted"},
|
||||
range(200),
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR:
|
||||
|
||||
%0 = x # EncryptedScalar<uint8>
|
||||
%1 = 200 # ClearScalar<uint8>
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint9>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only up to 8-bit integers are supported
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
@@ -413,6 +396,53 @@ Function you are trying to compile cannot be converted to MLIR
|
||||
return %2
|
||||
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.abs(10 * np.sin(x + 300)).astype(np.int64),
|
||||
{"x": "encrypted"},
|
||||
range(200),
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR:
|
||||
|
||||
%0 = x # EncryptedScalar<uint8>
|
||||
%1 = 300 # ClearScalar<uint9>
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint9>
|
||||
%3 = subgraph(%2) # EncryptedScalar<uint4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ table lookups are only supported on circuits with up to 8-bit integers
|
||||
return %3
|
||||
|
||||
Subgraphs:
|
||||
|
||||
%3 = subgraph(%2):
|
||||
|
||||
%0 = 10 # ClearScalar<uint4>
|
||||
%1 = input # EncryptedScalar<uint2>
|
||||
%2 = sin(%1) # EncryptedScalar<float64>
|
||||
%3 = multiply(%0, %2) # EncryptedScalar<float64>
|
||||
%4 = absolute(%3) # EncryptedScalar<float64>
|
||||
%5 = astype(%4, dtype=int_) # EncryptedScalar<uint1>
|
||||
return %5
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x - 300,
|
||||
{"x": "encrypted"},
|
||||
range(200),
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR:
|
||||
|
||||
%0 = x # EncryptedScalar<uint8>
|
||||
%1 = 300 # ClearScalar<uint9>
|
||||
%2 = subtract(%0, %1) # EncryptedScalar<int10>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ signed values are only supported on circuits with up to 8-bit integers
|
||||
return %2
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
],
|
||||
|
||||
Reference in New Issue
Block a user