mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: support up to 16-bit table lookups
This commit is contained in:
@@ -28,7 +28,7 @@ from ..internal.utils import assert_that
|
||||
from ..representation import Graph, Node, Operation
|
||||
from ..values import ClearScalar, EncryptedScalar
|
||||
from .node_converter import NodeConverter
|
||||
from .utils import MAXIMUM_TLU_BIT_WIDTH
|
||||
from .utils import MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS, MAXIMUM_TLU_BIT_WIDTH
|
||||
|
||||
# pylint: enable=no-member,no-name-in-module
|
||||
|
||||
@@ -240,11 +240,19 @@ class GraphConverter:
|
||||
if dtype.is_signed and first_signed_node is None:
|
||||
first_signed_node = node
|
||||
|
||||
if first_tlu_node is not None and max_bit_width > MAXIMUM_TLU_BIT_WIDTH:
|
||||
offending_nodes[first_tlu_node] = [
|
||||
f"table lookups are only supported on circuits with "
|
||||
f"up to {MAXIMUM_TLU_BIT_WIDTH}-bit integers"
|
||||
]
|
||||
if first_tlu_node is not None:
|
||||
if max_bit_width > MAXIMUM_TLU_BIT_WIDTH:
|
||||
offending_nodes[first_tlu_node] = [
|
||||
f"table lookups are only supported on circuits with "
|
||||
f"up to {MAXIMUM_TLU_BIT_WIDTH}-bit integers"
|
||||
]
|
||||
|
||||
if first_signed_node is not None and max_bit_width > MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS:
|
||||
offending_nodes[first_signed_node] = [
|
||||
f"signed integers are only supported "
|
||||
f"up to {MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS}-bits "
|
||||
f"on circuits with table lookups"
|
||||
]
|
||||
|
||||
if len(offending_nodes) != 0:
|
||||
raise RuntimeError(
|
||||
|
||||
@@ -13,7 +13,8 @@ from ..dtypes import Integer
|
||||
from ..internal.utils import assert_that
|
||||
from ..representation import Node, Operation
|
||||
|
||||
MAXIMUM_TLU_BIT_WIDTH = 8
|
||||
MAXIMUM_TLU_BIT_WIDTH = 16
|
||||
MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS = 8
|
||||
|
||||
|
||||
class HashableNdarray:
|
||||
|
||||
@@ -401,17 +401,17 @@ return %2
|
||||
pytest.param(
|
||||
lambda x: np.abs(10 * np.sin(x + 300)).astype(np.int64),
|
||||
{"x": "encrypted"},
|
||||
range(200),
|
||||
[200000],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR:
|
||||
|
||||
%0 = x # EncryptedScalar<uint8>
|
||||
%0 = x # EncryptedScalar<uint18>
|
||||
%1 = 300 # ClearScalar<uint9>
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint9>
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint18>
|
||||
%3 = subgraph(%2) # EncryptedScalar<uint4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ table lookups are only supported on circuits with up to 8-bit integers
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ table lookups are only supported on circuits with up to 16-bit integers
|
||||
return %3
|
||||
|
||||
Subgraphs:
|
||||
@@ -426,6 +426,35 @@ Subgraphs:
|
||||
%5 = astype(%4, dtype=int_) # EncryptedScalar<uint1>
|
||||
return %5
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: (10 * np.sin(x + 300)).astype(np.int64),
|
||||
{"x": "encrypted"},
|
||||
range(2**10, 2**11),
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR:
|
||||
|
||||
%0 = x # EncryptedScalar<uint11>
|
||||
%1 = 300 # ClearScalar<uint9>
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint12>
|
||||
%3 = subgraph(%2) # EncryptedScalar<int5>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ signed integers are only supported up to 8-bits on circuits with table lookups
|
||||
return %3
|
||||
|
||||
Subgraphs:
|
||||
|
||||
%3 = subgraph(%2):
|
||||
|
||||
%0 = 10 # ClearScalar<uint4>
|
||||
%1 = input # EncryptedScalar<uint2>
|
||||
%2 = sin(%1) # EncryptedScalar<float64>
|
||||
%3 = multiply(%0, %2) # EncryptedScalar<float64>
|
||||
%4 = astype(%3, dtype=int_) # EncryptedScalar<uint1>
|
||||
return %4
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
],
|
||||
|
||||
Reference in New Issue
Block a user