feat: support signed inputs

This commit is contained in:
Umut
2022-05-30 13:01:19 +02:00
parent 4010fc0cbd
commit a6b09ddf09
6 changed files with 150 additions and 20 deletions

View File

@@ -6,7 +6,7 @@ import json
import shutil
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from concrete.compiler import (
@@ -125,7 +125,7 @@ class Client:
if len(args) != len(input_specs):
raise ValueError(f"Expected {len(input_specs)} inputs but got {len(args)}")
sanitized_args = {}
sanitized_args: Dict[int, Union[int, np.ndarray]] = {}
for index, spec in enumerate(input_specs):
arg = args[index]
is_valid = isinstance(arg, (int, np.integer)) or (
@@ -136,7 +136,9 @@ class Client:
shape = tuple(spec["shape"]["dimensions"])
is_encrypted = spec["encryption"] is not None
expected_dtype = UnsignedInteger(width)
expected_dtype = (
SignedInteger(width) if self.specs.input_signs[index] else UnsignedInteger(width)
)
expected_value = Value(expected_dtype, shape, is_encrypted)
if is_valid:
expected_min = expected_dtype.min()
@@ -153,7 +155,14 @@ class Client:
)
if is_valid:
sanitized_args[index] = arg if isinstance(arg, int) else arg.astype(np.uint8)
if isinstance(arg, int) and arg < 0:
sanitized_args[index] = 2 * (expected_max + 1) + arg
else:
sanitized_args[index] = np.where(
arg >= 0,
arg,
2 * (expected_max + 1) + arg,
).astype(np.uint8)
if not is_valid:
actual_value = Value.of(arg, is_encrypted=is_encrypted)

View File

@@ -5,13 +5,22 @@ Declaration of `GraphConverter` class.
# pylint: disable=no-member,no-name-in-module
from copy import deepcopy
from typing import Dict, List, Optional, cast
from typing import Any, Dict, List, Optional, Tuple, cast
import concrete.lang as concretelang
import networkx as nx
import numpy as np
from mlir.dialects import builtin
from mlir.ir import Context, InsertionPoint, Location, Module
from concrete.lang.dialects import fhe, fhelinalg
from mlir.dialects import arith, builtin
from mlir.ir import (
Context,
DenseElementsAttr,
InsertionPoint,
IntegerType,
Location,
Module,
RankedTensorType,
)
from ..dtypes import Integer, SignedInteger
from ..internal.utils import assert_that
@@ -58,8 +67,8 @@ class GraphConverter:
elif node.operation == Operation.Input:
assert_that(len(inputs) == 1)
assert_that(inputs[0] == output)
if not isinstance(output.dtype, Integer) or output.dtype.is_signed:
return "only unsigned integer inputs are supported"
if not isinstance(output.dtype, Integer):
return "only integer inputs are supported"
else:
assert_that(node.operation == Operation.Generic)
@@ -283,6 +292,85 @@ class GraphConverter:
nx_graph.add_edge(add_offset, node, input_idx=variable_input_index)
@staticmethod
def _sanitize_signed_inputs(
graph: Graph,
args: List[Any],
ctx: Context,
) -> Tuple[List[Any], List[str]]:
"""
Apply table lookup to signed inputs in the beginning of evaluation to sanitize them.
Sanitization in this context means to apply a table lookup to obtain proper input values.
"encrypt" method of "Client" class will convert negative inputs to their corresponding
unsigned value in 2s complement representation.
Here is an example for 3 bits:
000 = 0 represents 0
001 = 1 represents 1
010 = 2 represents 2
011 = 3 represents 3
100 = 4 represents -4
101 = 5 represents -3
110 = 6 represents -2
111 = 7 represents -1
And, the following table lookup is applied before anything else to sanitize the inputs:
[0, 1, 2, 3, -4, -3, -2, -1]
Args:
graph (Graph):
computation graph being converted
args (List[Any]):
list of arguments from mlir main
ctx (Context):
mlir context where the conversion is being performed
Returns:
Tuple[List[str], List[Any]]:
sanitized args and name of the sanitized variables in MLIR
"""
sanitized_args = []
arg_mlir_names = []
for i, arg in enumerate(args):
input_node = graph.input_nodes[i]
input_value = input_node.output
assert_that(isinstance(input_value.dtype, Integer))
input_dtype = cast(Integer, input_value.dtype)
if input_dtype.is_signed:
n = input_dtype.bit_width
lut_range = np.arange(2**n)
lut_values = np.where(lut_range < (2 ** (n - 1)), lut_range, lut_range - (2**n))
lut_type = RankedTensorType.get(
(2**n,), IntegerType.get_signless(64, context=ctx)
)
lut_attr = DenseElementsAttr.get(lut_values, context=ctx)
lut = arith.ConstantOp(lut_type, lut_attr).result
resulting_type = NodeConverter.value_to_mlir_type(ctx, input_value)
if input_value.is_scalar:
sanitized = fhe.ApplyLookupTableEintOp(resulting_type, arg, lut).result
else:
sanitized = fhelinalg.ApplyLookupTableEintOp(resulting_type, arg, lut).result
sanitized_args.append(sanitized)
mlir_name = str(sanitized).replace("Value(", "").split("=", maxsplit=1)[0].strip()
else:
sanitized_args.append(arg)
mlir_name = f"%arg{i}"
arg_mlir_names.append(mlir_name)
return sanitized_args, arg_mlir_names
@staticmethod
def convert(graph: Graph, virtual: bool = False) -> str:
"""
@@ -335,12 +423,18 @@ class GraphConverter:
]
@builtin.FuncOp.from_py_func(*parameters)
def main(*arg):
def main(*args):
sanitized_args, arg_mlir_names = GraphConverter._sanitize_signed_inputs(
graph,
args,
ctx,
)
ir_to_mlir = {}
for arg_num, node in graph.input_nodes.items():
ir_to_mlir[node] = arg[arg_num]
ir_to_mlir[node] = sanitized_args[arg_num]
mlir_name = f"%arg{arg_num}"
mlir_name = arg_mlir_names[arg_num]
nodes_to_mlir_names[node] = mlir_name
mlir_names_to_mlir_types[mlir_name] = str(parameters[arg_num])

View File

@@ -203,9 +203,6 @@ class Helpers:
for description in parameters.values():
minimum, maximum = description.get("range", [0, 127])
assert minimum >= 0
assert maximum <= 127
if "shape" in description:
shape = description["shape"]
sample.append(np.random.randint(minimum, maximum + 1, size=shape, dtype=np.int64))

View File

@@ -179,10 +179,34 @@ def test_direct_table_lookup(bits, function, helpers):
compiler = cnp.Compiler(function, {"x": "encrypted"})
inputset = [np.random.randint(0, 2**bits, size=(3, 2), dtype=np.uint8) for _ in range(100)]
inputset = [np.random.randint(0, 2**bits, size=(3, 2)) for _ in range(100)]
circuit = compiler.compile(inputset, configuration)
sample = np.random.randint(0, 2**bits, size=(3, 2), dtype=np.uint8)
sample = np.random.randint(0, 2**bits, size=(3, 2))
helpers.check_execution(circuit, function, sample, retries=10)
# negative scalar
# ---------------
compiler = cnp.Compiler(function, {"x": "encrypted"})
inputset = range(-(2 ** (bits - 1)), 2 ** (bits - 1))
circuit = compiler.compile(inputset, configuration)
sample = int(np.random.randint(-(2 ** (bits - 1)), 2 ** (bits - 1)))
helpers.check_execution(circuit, function, sample, retries=10)
# negative tensor
# ---------------
compiler = cnp.Compiler(function, {"x": "encrypted"})
inputset = [
np.random.randint(-(2 ** (bits - 1)), 2 ** (bits - 1), size=(3, 2)) for _ in range(100)
]
circuit = compiler.compile(inputset, configuration)
sample = np.random.randint(-(2 ** (bits - 1)), 2 ** (bits - 1), size=(3, 2))
helpers.check_execution(circuit, function, sample, retries=10)
@@ -209,10 +233,10 @@ def test_direct_multi_table_lookup(helpers):
compiler = cnp.Compiler(function, {"x": "encrypted"})
inputset = [np.random.randint(0, 2**2, size=(3, 2), dtype=np.uint8) for _ in range(100)]
inputset = [np.random.randint(0, 2**2, size=(3, 2)) for _ in range(100)]
circuit = compiler.compile(inputset, configuration)
sample = np.random.randint(0, 2**2, size=(3, 2), dtype=np.uint8)
sample = np.random.randint(0, 2**2, size=(3, 2))
helpers.check_execution(circuit, function, sample, retries=10)

View File

@@ -17,6 +17,12 @@ import concrete.numpy as cnp
{
"x": {"range": [0, 64], "status": "encrypted", "shape": (3, 2)},
},
{
"x": {"range": [-63, 0], "status": "encrypted"},
},
{
"x": {"range": [-63, 0], "status": "encrypted", "shape": (3, 2)},
},
],
)
def test_neg(parameters, helpers):

View File

@@ -43,7 +43,7 @@ return (%2, %3)
Function you are trying to compile cannot be converted to MLIR
%0 = x # EncryptedScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer inputs are supported
%1 = 1.5 # ClearScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
%2 = multiply(%0, %1) # EncryptedScalar<float64>