feat: support up to 16-bit signed integers with table lookups

This commit is contained in:
Umut
2022-11-21 14:27:28 +01:00
parent ad9c3beee7
commit 23337e9dbd
6 changed files with 49 additions and 59 deletions

View File

@@ -27,7 +27,7 @@ from .extensions import (
zero,
zeros,
)
from .mlir.utils import MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS, MAXIMUM_TLU_BIT_WIDTH
from .mlir.utils import MAXIMUM_TLU_BIT_WIDTH
from .representation import Graph
from .tracing.typing import (
f32,

View File

@@ -197,37 +197,58 @@ class Client:
"""
self.keygen(force=False)
results = ClientSupport.decrypt_result(self._keyset, result)
if not isinstance(results, tuple):
results = (results,)
outputs = ClientSupport.decrypt_result(self._keyset, result)
if not isinstance(outputs, tuple):
outputs = (outputs,)
sanitized_results: List[Union[int, np.ndarray]] = []
sanitized_outputs: List[Union[int, np.ndarray]] = []
client_parameters_json = json.loads(self.specs.client_parameters.serialize())
assert_that("outputs" in client_parameters_json)
output_specs = client_parameters_json["outputs"]
for index, spec in enumerate(output_specs):
n = spec["shape"]["width"]
expected_dtype = (
SignedInteger(n) if self.specs.output_signs[index] else UnsignedInteger(n)
for index, output in enumerate(outputs):
is_signed = self.specs.output_signs[index]
crt_decomposition = (
output_specs[index].get("encryption", {}).get("encoding", {}).get("crt", [])
)
result = results[index] % (2**n)
if expected_dtype.is_signed:
if isinstance(result, int):
sanititzed_result = result if result < (2 ** (n - 1)) else result - (2**n)
sanitized_results.append(sanititzed_result)
if is_signed:
if crt_decomposition:
if isinstance(output, int):
sanititzed_output = (
output
if output < (int(np.prod(crt_decomposition)) // 2)
else -int(np.prod(crt_decomposition)) + output
)
else:
output = output.astype(np.longlong) # to prevent overflows in numpy
sanititzed_output = np.where(
output < (np.prod(crt_decomposition) // 2),
output,
-np.prod(crt_decomposition) + output,
).astype(np.int64)
sanitized_outputs.append(sanititzed_output)
else:
result = result.astype(np.longlong) # to prevent overflows in numpy
sanititzed_result = np.where(result < (2 ** (n - 1)), result, result - (2**n))
sanitized_results.append(sanititzed_result.astype(np.int64))
n = output_specs[index]["shape"]["width"]
output %= 2**n
if isinstance(output, int):
sanititzed_output = output if output < (2 ** (n - 1)) else output - (2**n)
sanitized_outputs.append(sanititzed_output)
else:
output = output.astype(np.longlong) # to prevent overflows in numpy
sanititzed_output = np.where(
output < (2 ** (n - 1)), output, output - (2**n)
).astype(np.int64)
sanitized_outputs.append(sanititzed_output)
else:
sanitized_results.append(
result if isinstance(result, int) else result.astype(np.uint64)
sanitized_outputs.append(
output if isinstance(output, int) else output.astype(np.uint64)
)
return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results)
return sanitized_outputs[0] if len(sanitized_outputs) == 1 else tuple(sanitized_outputs)
@property
def evaluation_keys(self) -> EvaluationKeys:

View File

@@ -28,7 +28,7 @@ from ..internal.utils import assert_that
from ..representation import Graph, Node, Operation
from ..values import ClearScalar, EncryptedScalar
from .node_converter import NodeConverter
from .utils import MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS, MAXIMUM_TLU_BIT_WIDTH
from .utils import MAXIMUM_TLU_BIT_WIDTH
# pylint: enable=no-member,no-name-in-module
@@ -261,14 +261,6 @@ class GraphConverter:
first_tlu_node.location,
]
if first_signed_node is not None and max_bit_width > MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS:
offending_nodes[first_signed_node] = [
f"signed integers are only supported "
f"up to {MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS}-bits "
f"on circuits with table lookups",
first_signed_node.location,
]
if len(offending_nodes) != 0:
raise RuntimeError(
"Function you are trying to compile cannot be converted to MLIR:\n\n"

View File

@@ -14,7 +14,6 @@ from ..internal.utils import assert_that
from ..representation import Node, Operation
MAXIMUM_TLU_BIT_WIDTH = 16
MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS = 8
class HashableNdarray:

View File

@@ -635,6 +635,13 @@ def deterministic_unary_function(x):
},
id="np.expand_dims(x, axis=(0, 1, 2))",
),
pytest.param(
lambda x: x**3,
{
"x": {"status": "encrypted", "range": [-30, 30]},
},
id="x ** 3",
),
],
)
def test_others(function, parameters, helpers):

View File

@@ -399,35 +399,6 @@ Subgraphs:
%5 = astype(%4, dtype=int_) # EncryptedScalar<uint1>
return %5
""", # noqa: E501
),
pytest.param(
lambda x: (10 * np.sin(x + 300)).astype(np.int64),
{"x": "encrypted"},
range(2**10, 2**11),
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR:
%0 = x # EncryptedScalar<uint11> ∈ [1024, 2047]
%1 = 300 # ClearScalar<uint9> ∈ [300, 300]
%2 = add(%0, %1) # EncryptedScalar<uint12> ∈ [1324, 2347]
%3 = subgraph(%2) # EncryptedScalar<int5> ∈ [-9, 9]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ signed integers are only supported up to 8-bits on circuits with table lookups
return %3
Subgraphs:
%3 = subgraph(%2):
%0 = input # EncryptedScalar<uint2>
%1 = sin(%0) # EncryptedScalar<float64>
%2 = 10 # ClearScalar<uint4>
%3 = multiply(%2, %1) # EncryptedScalar<float64>
%4 = astype(%3, dtype=int_) # EncryptedScalar<uint1>
return %4
""", # noqa: E501
),
],