feat: support unsigned levelled operations on large bit-widths

This commit is contained in:
Umut
2022-08-29 10:24:05 +02:00
parent 514780f7b7
commit 7415dd07e1
7 changed files with 84 additions and 28 deletions

View File

@@ -16,5 +16,5 @@ from .compilation import (
compiler,
)
from .extensions import LookupTable, array, one, ones, univariate, zero, zeros
from .mlir.utils import MAXIMUM_BIT_WIDTH
from .mlir.utils import MAXIMUM_TLU_BIT_WIDTH
from .representation import Graph

View File

@@ -165,7 +165,7 @@ class Client:
arg >= 0,
arg,
2 * (expected_max + 1) + arg,
).astype(np.uint8)
).astype(np.uint64)
if not is_valid:
actual_value = Value.of(arg, is_encrypted=is_encrypted)
@@ -220,10 +220,10 @@ class Client:
else:
result = result.astype(np.longlong) # to prevent overflows in numpy
sanititzed_result = np.where(result < (2 ** (n - 1)), result, result - (2**n))
sanitized_results.append(sanititzed_result.astype(np.int8))
sanitized_results.append(sanititzed_result.astype(np.int64))
else:
sanitized_results.append(
result if isinstance(result, int) else result.astype(np.uint8)
result if isinstance(result, int) else result.astype(np.uint64)
)
return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results)

View File

@@ -28,7 +28,7 @@ from ..internal.utils import assert_that
from ..representation import Graph, Node, Operation
from ..values import ClearScalar, EncryptedScalar
from .node_converter import NodeConverter
from .utils import MAXIMUM_BIT_WIDTH
from .utils import MAXIMUM_TLU_BIT_WIDTH
# pylint: enable=no-member,no-name-in-module
@@ -221,6 +221,10 @@ class GraphConverter:
offending_nodes: Dict[Node, List[str]] = {}
max_bit_width = 0
first_tlu_node = None
first_signed_node = None
for node in graph.graph.nodes:
dtype = node.output.dtype
assert_that(isinstance(dtype, Integer))
@@ -230,10 +234,23 @@ class GraphConverter:
)
max_bit_width = max(max_bit_width, current_node_bit_width)
if current_node_bit_width > MAXIMUM_BIT_WIDTH:
offending_nodes[node] = [
f"only up to {MAXIMUM_BIT_WIDTH}-bit integers are supported"
]
if node.converted_to_table_lookup and first_tlu_node is None:
first_tlu_node = node
if dtype.is_signed and first_signed_node is None:
first_signed_node = node
if first_tlu_node is not None and max_bit_width > MAXIMUM_TLU_BIT_WIDTH:
offending_nodes[first_tlu_node] = [
f"table lookups are only supported on circuits with "
f"up to {MAXIMUM_TLU_BIT_WIDTH}-bit integers"
]
if first_signed_node is not None and max_bit_width > MAXIMUM_TLU_BIT_WIDTH:
offending_nodes[first_signed_node] = [
f"signed values are only supported on circuits with "
f"up to {MAXIMUM_TLU_BIT_WIDTH}-bit integers"
]
if len(offending_nodes) != 0:
raise RuntimeError(

View File

@@ -13,7 +13,7 @@ from ..dtypes import Integer
from ..internal.utils import assert_that
from ..representation import Node, Operation
MAXIMUM_BIT_WIDTH = 8
MAXIMUM_TLU_BIT_WIDTH = 8
class HashableNdarray:

View File

@@ -201,7 +201,7 @@ class Helpers:
sample = []
for description in parameters.values():
minimum, maximum = description.get("range", [0, 127])
minimum, maximum = description.get("range", [0, (2**16) - 1])
if "shape" in description:
shape = description["shape"]

View File

@@ -58,6 +58,15 @@ import concrete.numpy as cnp
{
"x": {"range": [-50, 10], "status": "encrypted", "shape": (2, 3)},
},
{
"x": {"range": [0, 1000000], "status": "encrypted"},
},
{
"x": {"range": [0, 1000000], "status": "encrypted", "shape": (3,)},
},
{
"x": {"range": [0, 1000000], "status": "encrypted", "shape": (2, 3)},
},
],
)
def test_constant_add(function, parameters, helpers):

View File

@@ -344,23 +344,6 @@ Function you are trying to compile cannot be converted to MLIR
%1 = [3] # ClearTensor<uint2, shape=(1,)>
%2 = maximum(%0, %1) # ClearTensor<uint2, shape=(1,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of the operands must be encrypted
return %2
""", # noqa: E501
),
pytest.param(
lambda x: x + 200,
{"x": "encrypted"},
range(200),
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR:
%0 = x # EncryptedScalar<uint8>
%1 = 200 # ClearScalar<uint8>
%2 = add(%0, %1) # EncryptedScalar<uint9>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only up to 8-bit integers are supported
return %2
""", # noqa: E501
@@ -413,6 +396,53 @@ Function you are trying to compile cannot be converted to MLIR
return %2
""", # noqa: E501
),
pytest.param(
lambda x: np.abs(10 * np.sin(x + 300)).astype(np.int64),
{"x": "encrypted"},
range(200),
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR:
%0 = x # EncryptedScalar<uint8>
%1 = 300 # ClearScalar<uint9>
%2 = add(%0, %1) # EncryptedScalar<uint9>
%3 = subgraph(%2) # EncryptedScalar<uint4>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ table lookups are only supported on circuits with up to 8-bit integers
return %3
Subgraphs:
%3 = subgraph(%2):
%0 = 10 # ClearScalar<uint4>
%1 = input # EncryptedScalar<uint2>
%2 = sin(%1) # EncryptedScalar<float64>
%3 = multiply(%0, %2) # EncryptedScalar<float64>
%4 = absolute(%3) # EncryptedScalar<float64>
%5 = astype(%4, dtype=int_) # EncryptedScalar<uint1>
return %5
""", # noqa: E501
),
pytest.param(
lambda x: x - 300,
{"x": "encrypted"},
range(200),
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR:
%0 = x # EncryptedScalar<uint8>
%1 = 300 # ClearScalar<uint9>
%2 = subtract(%0, %1) # EncryptedScalar<int10>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ signed values are only supported on circuits with up to 8-bit integers
return %2
""", # noqa: E501
),
],