feat: support up to 16-bit table lookups

This commit is contained in:
Umut
2022-10-12 10:56:50 +02:00
parent 191150b36d
commit 21a0923e2e
3 changed files with 49 additions and 11 deletions

View File

@@ -28,7 +28,7 @@ from ..internal.utils import assert_that
from ..representation import Graph, Node, Operation
from ..values import ClearScalar, EncryptedScalar
from .node_converter import NodeConverter
from .utils import MAXIMUM_TLU_BIT_WIDTH
from .utils import MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS, MAXIMUM_TLU_BIT_WIDTH
# pylint: enable=no-member,no-name-in-module
@@ -240,11 +240,19 @@ class GraphConverter:
if dtype.is_signed and first_signed_node is None:
first_signed_node = node
if first_tlu_node is not None and max_bit_width > MAXIMUM_TLU_BIT_WIDTH:
offending_nodes[first_tlu_node] = [
f"table lookups are only supported on circuits with "
f"up to {MAXIMUM_TLU_BIT_WIDTH}-bit integers"
]
if first_tlu_node is not None:
if max_bit_width > MAXIMUM_TLU_BIT_WIDTH:
offending_nodes[first_tlu_node] = [
f"table lookups are only supported on circuits with "
f"up to {MAXIMUM_TLU_BIT_WIDTH}-bit integers"
]
if first_signed_node is not None and max_bit_width > MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS:
offending_nodes[first_signed_node] = [
f"signed integers are only supported "
f"up to {MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS}-bits "
f"on circuits with table lookups"
]
if len(offending_nodes) != 0:
raise RuntimeError(

View File

@@ -13,7 +13,8 @@ from ..dtypes import Integer
from ..internal.utils import assert_that
from ..representation import Node, Operation
MAXIMUM_TLU_BIT_WIDTH = 8
MAXIMUM_TLU_BIT_WIDTH = 16
MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS = 8
class HashableNdarray:

View File

@@ -401,17 +401,17 @@ return %2
pytest.param(
lambda x: np.abs(10 * np.sin(x + 300)).astype(np.int64),
{"x": "encrypted"},
range(200),
[200000],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR:
%0 = x # EncryptedScalar<uint8>
%0 = x # EncryptedScalar<uint18>
%1 = 300 # ClearScalar<uint9>
%2 = add(%0, %1) # EncryptedScalar<uint9>
%2 = add(%0, %1) # EncryptedScalar<uint18>
%3 = subgraph(%2) # EncryptedScalar<uint4>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ table lookups are only supported on circuits with up to 8-bit integers
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ table lookups are only supported on circuits with up to 16-bit integers
return %3
Subgraphs:
@@ -426,6 +426,35 @@ Subgraphs:
%5 = astype(%4, dtype=int_) # EncryptedScalar<uint1>
return %5
""", # noqa: E501
),
pytest.param(
lambda x: (10 * np.sin(x + 300)).astype(np.int64),
{"x": "encrypted"},
range(2**10, 2**11),
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR:
%0 = x # EncryptedScalar<uint11>
%1 = 300 # ClearScalar<uint9>
%2 = add(%0, %1) # EncryptedScalar<uint12>
%3 = subgraph(%2) # EncryptedScalar<int5>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ signed integers are only supported up to 8-bits on circuits with table lookups
return %3
Subgraphs:
%3 = subgraph(%2):
%0 = 10 # ClearScalar<uint4>
%1 = input # EncryptedScalar<uint2>
%2 = sin(%1) # EncryptedScalar<float64>
%3 = multiply(%0, %2) # EncryptedScalar<float64>
%4 = astype(%3, dtype=int_) # EncryptedScalar<uint1>
return %4
""", # noqa: E501
),
],