mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat: support up to 16-bit signed integers with table lookups
This commit is contained in:
@@ -27,7 +27,7 @@ from .extensions import (
|
||||
zero,
|
||||
zeros,
|
||||
)
|
||||
from .mlir.utils import MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS, MAXIMUM_TLU_BIT_WIDTH
|
||||
from .mlir.utils import MAXIMUM_TLU_BIT_WIDTH
|
||||
from .representation import Graph
|
||||
from .tracing.typing import (
|
||||
f32,
|
||||
|
||||
@@ -197,37 +197,58 @@ class Client:
|
||||
"""
|
||||
|
||||
self.keygen(force=False)
|
||||
results = ClientSupport.decrypt_result(self._keyset, result)
|
||||
if not isinstance(results, tuple):
|
||||
results = (results,)
|
||||
outputs = ClientSupport.decrypt_result(self._keyset, result)
|
||||
if not isinstance(outputs, tuple):
|
||||
outputs = (outputs,)
|
||||
|
||||
sanitized_results: List[Union[int, np.ndarray]] = []
|
||||
sanitized_outputs: List[Union[int, np.ndarray]] = []
|
||||
|
||||
client_parameters_json = json.loads(self.specs.client_parameters.serialize())
|
||||
assert_that("outputs" in client_parameters_json)
|
||||
output_specs = client_parameters_json["outputs"]
|
||||
|
||||
for index, spec in enumerate(output_specs):
|
||||
n = spec["shape"]["width"]
|
||||
expected_dtype = (
|
||||
SignedInteger(n) if self.specs.output_signs[index] else UnsignedInteger(n)
|
||||
for index, output in enumerate(outputs):
|
||||
is_signed = self.specs.output_signs[index]
|
||||
crt_decomposition = (
|
||||
output_specs[index].get("encryption", {}).get("encoding", {}).get("crt", [])
|
||||
)
|
||||
|
||||
result = results[index] % (2**n)
|
||||
if expected_dtype.is_signed:
|
||||
if isinstance(result, int):
|
||||
sanititzed_result = result if result < (2 ** (n - 1)) else result - (2**n)
|
||||
sanitized_results.append(sanititzed_result)
|
||||
if is_signed:
|
||||
if crt_decomposition:
|
||||
if isinstance(output, int):
|
||||
sanititzed_output = (
|
||||
output
|
||||
if output < (int(np.prod(crt_decomposition)) // 2)
|
||||
else -int(np.prod(crt_decomposition)) + output
|
||||
)
|
||||
else:
|
||||
output = output.astype(np.longlong) # to prevent overflows in numpy
|
||||
sanititzed_output = np.where(
|
||||
output < (np.prod(crt_decomposition) // 2),
|
||||
output,
|
||||
-np.prod(crt_decomposition) + output,
|
||||
).astype(np.int64)
|
||||
|
||||
sanitized_outputs.append(sanititzed_output)
|
||||
|
||||
else:
|
||||
result = result.astype(np.longlong) # to prevent overflows in numpy
|
||||
sanititzed_result = np.where(result < (2 ** (n - 1)), result, result - (2**n))
|
||||
sanitized_results.append(sanititzed_result.astype(np.int64))
|
||||
n = output_specs[index]["shape"]["width"]
|
||||
output %= 2**n
|
||||
if isinstance(output, int):
|
||||
sanititzed_output = output if output < (2 ** (n - 1)) else output - (2**n)
|
||||
sanitized_outputs.append(sanititzed_output)
|
||||
else:
|
||||
output = output.astype(np.longlong) # to prevent overflows in numpy
|
||||
sanititzed_output = np.where(
|
||||
output < (2 ** (n - 1)), output, output - (2**n)
|
||||
).astype(np.int64)
|
||||
sanitized_outputs.append(sanititzed_output)
|
||||
else:
|
||||
sanitized_results.append(
|
||||
result if isinstance(result, int) else result.astype(np.uint64)
|
||||
sanitized_outputs.append(
|
||||
output if isinstance(output, int) else output.astype(np.uint64)
|
||||
)
|
||||
|
||||
return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results)
|
||||
return sanitized_outputs[0] if len(sanitized_outputs) == 1 else tuple(sanitized_outputs)
|
||||
|
||||
@property
|
||||
def evaluation_keys(self) -> EvaluationKeys:
|
||||
|
||||
@@ -28,7 +28,7 @@ from ..internal.utils import assert_that
|
||||
from ..representation import Graph, Node, Operation
|
||||
from ..values import ClearScalar, EncryptedScalar
|
||||
from .node_converter import NodeConverter
|
||||
from .utils import MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS, MAXIMUM_TLU_BIT_WIDTH
|
||||
from .utils import MAXIMUM_TLU_BIT_WIDTH
|
||||
|
||||
# pylint: enable=no-member,no-name-in-module
|
||||
|
||||
@@ -261,14 +261,6 @@ class GraphConverter:
|
||||
first_tlu_node.location,
|
||||
]
|
||||
|
||||
if first_signed_node is not None and max_bit_width > MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS:
|
||||
offending_nodes[first_signed_node] = [
|
||||
f"signed integers are only supported "
|
||||
f"up to {MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS}-bits "
|
||||
f"on circuits with table lookups",
|
||||
first_signed_node.location,
|
||||
]
|
||||
|
||||
if len(offending_nodes) != 0:
|
||||
raise RuntimeError(
|
||||
"Function you are trying to compile cannot be converted to MLIR:\n\n"
|
||||
|
||||
@@ -14,7 +14,6 @@ from ..internal.utils import assert_that
|
||||
from ..representation import Node, Operation
|
||||
|
||||
MAXIMUM_TLU_BIT_WIDTH = 16
|
||||
MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS = 8
|
||||
|
||||
|
||||
class HashableNdarray:
|
||||
|
||||
@@ -635,6 +635,13 @@ def deterministic_unary_function(x):
|
||||
},
|
||||
id="np.expand_dims(x, axis=(0, 1, 2))",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x**3,
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [-30, 30]},
|
||||
},
|
||||
id="x ** 3",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_others(function, parameters, helpers):
|
||||
|
||||
@@ -399,35 +399,6 @@ Subgraphs:
|
||||
%5 = astype(%4, dtype=int_) # EncryptedScalar<uint1>
|
||||
return %5
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: (10 * np.sin(x + 300)).astype(np.int64),
|
||||
{"x": "encrypted"},
|
||||
range(2**10, 2**11),
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR:
|
||||
|
||||
%0 = x # EncryptedScalar<uint11> ∈ [1024, 2047]
|
||||
%1 = 300 # ClearScalar<uint9> ∈ [300, 300]
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint12> ∈ [1324, 2347]
|
||||
%3 = subgraph(%2) # EncryptedScalar<int5> ∈ [-9, 9]
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ signed integers are only supported up to 8-bits on circuits with table lookups
|
||||
return %3
|
||||
|
||||
Subgraphs:
|
||||
|
||||
%3 = subgraph(%2):
|
||||
|
||||
%0 = input # EncryptedScalar<uint2>
|
||||
%1 = sin(%0) # EncryptedScalar<float64>
|
||||
%2 = 10 # ClearScalar<uint4>
|
||||
%3 = multiply(%2, %1) # EncryptedScalar<float64>
|
||||
%4 = astype(%3, dtype=int_) # EncryptedScalar<uint1>
|
||||
return %4
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
],
|
||||
|
||||
Reference in New Issue
Block a user