mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: support signed inputs
This commit is contained in:
@@ -6,7 +6,7 @@ import json
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from concrete.compiler import (
|
||||
@@ -125,7 +125,7 @@ class Client:
|
||||
if len(args) != len(input_specs):
|
||||
raise ValueError(f"Expected {len(input_specs)} inputs but got {len(args)}")
|
||||
|
||||
sanitized_args = {}
|
||||
sanitized_args: Dict[int, Union[int, np.ndarray]] = {}
|
||||
for index, spec in enumerate(input_specs):
|
||||
arg = args[index]
|
||||
is_valid = isinstance(arg, (int, np.integer)) or (
|
||||
@@ -136,7 +136,9 @@ class Client:
|
||||
shape = tuple(spec["shape"]["dimensions"])
|
||||
is_encrypted = spec["encryption"] is not None
|
||||
|
||||
expected_dtype = UnsignedInteger(width)
|
||||
expected_dtype = (
|
||||
SignedInteger(width) if self.specs.input_signs[index] else UnsignedInteger(width)
|
||||
)
|
||||
expected_value = Value(expected_dtype, shape, is_encrypted)
|
||||
if is_valid:
|
||||
expected_min = expected_dtype.min()
|
||||
@@ -153,7 +155,14 @@ class Client:
|
||||
)
|
||||
|
||||
if is_valid:
|
||||
sanitized_args[index] = arg if isinstance(arg, int) else arg.astype(np.uint8)
|
||||
if isinstance(arg, int) and arg < 0:
|
||||
sanitized_args[index] = 2 * (expected_max + 1) + arg
|
||||
else:
|
||||
sanitized_args[index] = np.where(
|
||||
arg >= 0,
|
||||
arg,
|
||||
2 * (expected_max + 1) + arg,
|
||||
).astype(np.uint8)
|
||||
|
||||
if not is_valid:
|
||||
actual_value = Value.of(arg, is_encrypted=is_encrypted)
|
||||
|
||||
@@ -5,13 +5,22 @@ Declaration of `GraphConverter` class.
|
||||
# pylint: disable=no-member,no-name-in-module
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Optional, cast
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
|
||||
import concrete.lang as concretelang
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from mlir.dialects import builtin
|
||||
from mlir.ir import Context, InsertionPoint, Location, Module
|
||||
from concrete.lang.dialects import fhe, fhelinalg
|
||||
from mlir.dialects import arith, builtin
|
||||
from mlir.ir import (
|
||||
Context,
|
||||
DenseElementsAttr,
|
||||
InsertionPoint,
|
||||
IntegerType,
|
||||
Location,
|
||||
Module,
|
||||
RankedTensorType,
|
||||
)
|
||||
|
||||
from ..dtypes import Integer, SignedInteger
|
||||
from ..internal.utils import assert_that
|
||||
@@ -58,8 +67,8 @@ class GraphConverter:
|
||||
elif node.operation == Operation.Input:
|
||||
assert_that(len(inputs) == 1)
|
||||
assert_that(inputs[0] == output)
|
||||
if not isinstance(output.dtype, Integer) or output.dtype.is_signed:
|
||||
return "only unsigned integer inputs are supported"
|
||||
if not isinstance(output.dtype, Integer):
|
||||
return "only integer inputs are supported"
|
||||
|
||||
else:
|
||||
assert_that(node.operation == Operation.Generic)
|
||||
@@ -283,6 +292,85 @@ class GraphConverter:
|
||||
|
||||
nx_graph.add_edge(add_offset, node, input_idx=variable_input_index)
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_signed_inputs(
|
||||
graph: Graph,
|
||||
args: List[Any],
|
||||
ctx: Context,
|
||||
) -> Tuple[List[Any], List[str]]:
|
||||
"""
|
||||
Apply table lookup to signed inputs in the beginning of evaluation to sanitize them.
|
||||
|
||||
Sanitization in this context means to apply a table lookup to obtain proper input values.
|
||||
|
||||
"encrypt" method of "Client" class will convert negative inputs to their corresponding
|
||||
unsigned value in 2s complement representation.
|
||||
|
||||
Here is an example for 3 bits:
|
||||
000 = 0 represents 0
|
||||
001 = 1 represents 1
|
||||
010 = 2 represents 2
|
||||
011 = 3 represents 3
|
||||
100 = 4 represents -4
|
||||
101 = 5 represents -3
|
||||
110 = 6 represents -2
|
||||
111 = 7 represents -1
|
||||
|
||||
And, the following table lookup is applied before anything else to sanitize the inputs:
|
||||
[0, 1, 2, 3, -4, -3, -2, -1]
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
computation graph being converted
|
||||
|
||||
args (List[Any]):
|
||||
list of arguments from mlir main
|
||||
|
||||
ctx (Context):
|
||||
mlir context where the conversion is being performed
|
||||
|
||||
Returns:
|
||||
Tuple[List[str], List[Any]]:
|
||||
sanitized args and name of the sanitized variables in MLIR
|
||||
"""
|
||||
|
||||
sanitized_args = []
|
||||
arg_mlir_names = []
|
||||
|
||||
for i, arg in enumerate(args):
|
||||
input_node = graph.input_nodes[i]
|
||||
input_value = input_node.output
|
||||
|
||||
assert_that(isinstance(input_value.dtype, Integer))
|
||||
input_dtype = cast(Integer, input_value.dtype)
|
||||
|
||||
if input_dtype.is_signed:
|
||||
n = input_dtype.bit_width
|
||||
lut_range = np.arange(2**n)
|
||||
|
||||
lut_values = np.where(lut_range < (2 ** (n - 1)), lut_range, lut_range - (2**n))
|
||||
lut_type = RankedTensorType.get(
|
||||
(2**n,), IntegerType.get_signless(64, context=ctx)
|
||||
)
|
||||
lut_attr = DenseElementsAttr.get(lut_values, context=ctx)
|
||||
lut = arith.ConstantOp(lut_type, lut_attr).result
|
||||
|
||||
resulting_type = NodeConverter.value_to_mlir_type(ctx, input_value)
|
||||
if input_value.is_scalar:
|
||||
sanitized = fhe.ApplyLookupTableEintOp(resulting_type, arg, lut).result
|
||||
else:
|
||||
sanitized = fhelinalg.ApplyLookupTableEintOp(resulting_type, arg, lut).result
|
||||
|
||||
sanitized_args.append(sanitized)
|
||||
mlir_name = str(sanitized).replace("Value(", "").split("=", maxsplit=1)[0].strip()
|
||||
else:
|
||||
sanitized_args.append(arg)
|
||||
mlir_name = f"%arg{i}"
|
||||
|
||||
arg_mlir_names.append(mlir_name)
|
||||
|
||||
return sanitized_args, arg_mlir_names
|
||||
|
||||
@staticmethod
|
||||
def convert(graph: Graph, virtual: bool = False) -> str:
|
||||
"""
|
||||
@@ -335,12 +423,18 @@ class GraphConverter:
|
||||
]
|
||||
|
||||
@builtin.FuncOp.from_py_func(*parameters)
|
||||
def main(*arg):
|
||||
def main(*args):
|
||||
sanitized_args, arg_mlir_names = GraphConverter._sanitize_signed_inputs(
|
||||
graph,
|
||||
args,
|
||||
ctx,
|
||||
)
|
||||
|
||||
ir_to_mlir = {}
|
||||
for arg_num, node in graph.input_nodes.items():
|
||||
ir_to_mlir[node] = arg[arg_num]
|
||||
ir_to_mlir[node] = sanitized_args[arg_num]
|
||||
|
||||
mlir_name = f"%arg{arg_num}"
|
||||
mlir_name = arg_mlir_names[arg_num]
|
||||
nodes_to_mlir_names[node] = mlir_name
|
||||
mlir_names_to_mlir_types[mlir_name] = str(parameters[arg_num])
|
||||
|
||||
|
||||
@@ -203,9 +203,6 @@ class Helpers:
|
||||
for description in parameters.values():
|
||||
minimum, maximum = description.get("range", [0, 127])
|
||||
|
||||
assert minimum >= 0
|
||||
assert maximum <= 127
|
||||
|
||||
if "shape" in description:
|
||||
shape = description["shape"]
|
||||
sample.append(np.random.randint(minimum, maximum + 1, size=shape, dtype=np.int64))
|
||||
|
||||
@@ -179,10 +179,34 @@ def test_direct_table_lookup(bits, function, helpers):
|
||||
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"})
|
||||
|
||||
inputset = [np.random.randint(0, 2**bits, size=(3, 2), dtype=np.uint8) for _ in range(100)]
|
||||
inputset = [np.random.randint(0, 2**bits, size=(3, 2)) for _ in range(100)]
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = np.random.randint(0, 2**bits, size=(3, 2), dtype=np.uint8)
|
||||
sample = np.random.randint(0, 2**bits, size=(3, 2))
|
||||
helpers.check_execution(circuit, function, sample, retries=10)
|
||||
|
||||
# negative scalar
|
||||
# ---------------
|
||||
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"})
|
||||
|
||||
inputset = range(-(2 ** (bits - 1)), 2 ** (bits - 1))
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = int(np.random.randint(-(2 ** (bits - 1)), 2 ** (bits - 1)))
|
||||
helpers.check_execution(circuit, function, sample, retries=10)
|
||||
|
||||
# negative tensor
|
||||
# ---------------
|
||||
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"})
|
||||
|
||||
inputset = [
|
||||
np.random.randint(-(2 ** (bits - 1)), 2 ** (bits - 1), size=(3, 2)) for _ in range(100)
|
||||
]
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = np.random.randint(-(2 ** (bits - 1)), 2 ** (bits - 1), size=(3, 2))
|
||||
helpers.check_execution(circuit, function, sample, retries=10)
|
||||
|
||||
|
||||
@@ -209,10 +233,10 @@ def test_direct_multi_table_lookup(helpers):
|
||||
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"})
|
||||
|
||||
inputset = [np.random.randint(0, 2**2, size=(3, 2), dtype=np.uint8) for _ in range(100)]
|
||||
inputset = [np.random.randint(0, 2**2, size=(3, 2)) for _ in range(100)]
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = np.random.randint(0, 2**2, size=(3, 2), dtype=np.uint8)
|
||||
sample = np.random.randint(0, 2**2, size=(3, 2))
|
||||
helpers.check_execution(circuit, function, sample, retries=10)
|
||||
|
||||
|
||||
|
||||
@@ -17,6 +17,12 @@ import concrete.numpy as cnp
|
||||
{
|
||||
"x": {"range": [0, 64], "status": "encrypted", "shape": (3, 2)},
|
||||
},
|
||||
{
|
||||
"x": {"range": [-63, 0], "status": "encrypted"},
|
||||
},
|
||||
{
|
||||
"x": {"range": [-63, 0], "status": "encrypted", "shape": (3, 2)},
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_neg(parameters, helpers):
|
||||
|
||||
@@ -43,7 +43,7 @@ return (%2, %3)
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer inputs are supported
|
||||
%1 = 1.5 # ClearScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
|
||||
%2 = multiply(%0, %1) # EncryptedScalar<float64>
|
||||
|
||||
Reference in New Issue
Block a user