|
|
|
|
@@ -3,100 +3,118 @@ Declaration of `tfhers.Bridge` class.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# pylint: disable=import-error,no-member,no-name-in-module
|
|
|
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
|
from typing import Dict, List, Optional, Tuple, Union
|
|
|
|
|
|
|
|
|
|
from concrete.compiler import LweSecretKey, TfhersExporter, TfhersFheIntDescription
|
|
|
|
|
|
|
|
|
|
from concrete import fhe
|
|
|
|
|
import concrete.fhe as fhe
|
|
|
|
|
from concrete.fhe.compilation.value import Value
|
|
|
|
|
|
|
|
|
|
from .dtypes import EncryptionKeyChoice, TFHERSIntegerType
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Bridge:
|
|
|
|
|
"""TFHErs Bridge extend a Circuit with TFHErs functionalities.
|
|
|
|
|
"""TFHErs Bridge extend an Module with TFHErs functionalities.
|
|
|
|
|
|
|
|
|
|
input_types (List[Optional[TFHERSIntegerType]]): maps every input to a type. None means
|
|
|
|
|
a non-tfhers type
|
|
|
|
|
output_types (List[Optional[TFHERSIntegerType]]): maps every output to a type. None means
|
|
|
|
|
a non-tfhers type
|
|
|
|
|
input_shapes (List[Optional[Tuple[int, ...]]]): maps every input to a shape. None means
|
|
|
|
|
a non-tfhers type
|
|
|
|
|
output_shapes (List[Optional[Tuple[int, ...]]]): maps every output to a shape. None means
|
|
|
|
|
a non-tfhers type
|
|
|
|
|
input_types_per_func (Dict[str, List[Optional[TFHERSIntegerType]]]):
|
|
|
|
|
maps every input to a type for every function in the module. None means a non-tfhers type
|
|
|
|
|
output_types_per_func (Dict[str, List[Optional[TFHERSIntegerType]]]):
|
|
|
|
|
maps every output to a type for every function in the module. None means a non-tfhers type
|
|
|
|
|
input_shapes_per_func (Dict[str, List[Optional[Tuple[int, ...]]]]):
|
|
|
|
|
maps every input to a shape for every function in the module. None means a non-tfhers type
|
|
|
|
|
output_shapes_per_func (Dict[str, List[Optional[Tuple[int, ...]]]]):
|
|
|
|
|
maps every output to a shape for every function in the module. None means a non-tfhers type
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
circuit: "fhe.Circuit"
|
|
|
|
|
input_types: List[Optional[TFHERSIntegerType]]
|
|
|
|
|
output_types: List[Optional[TFHERSIntegerType]]
|
|
|
|
|
input_shapes: List[Optional[Tuple[int, ...]]]
|
|
|
|
|
output_shapes: List[Optional[Tuple[int, ...]]]
|
|
|
|
|
module: "fhe.Module"
|
|
|
|
|
default_function: Optional[str]
|
|
|
|
|
input_types_per_func: Dict[str, List[Optional[TFHERSIntegerType]]]
|
|
|
|
|
output_types_per_func: Dict[str, List[Optional[TFHERSIntegerType]]]
|
|
|
|
|
input_shapes_per_func: Dict[str, List[Optional[Tuple[int, ...]]]]
|
|
|
|
|
output_shapes_per_func: Dict[str, List[Optional[Tuple[int, ...]]]]
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
circuit: "fhe.Circuit",
|
|
|
|
|
input_types: List[Optional[TFHERSIntegerType]],
|
|
|
|
|
output_types: List[Optional[TFHERSIntegerType]],
|
|
|
|
|
input_shapes: List[Optional[Tuple[int, ...]]],
|
|
|
|
|
output_shapes: List[Optional[Tuple[int, ...]]],
|
|
|
|
|
module: "fhe.Module",
|
|
|
|
|
input_types_per_func: Dict[str, List[Optional[TFHERSIntegerType]]],
|
|
|
|
|
output_types_per_func: Dict[str, List[Optional[TFHERSIntegerType]]],
|
|
|
|
|
input_shapes_per_func: Dict[str, List[Optional[Tuple[int, ...]]]],
|
|
|
|
|
output_shapes_per_func: Dict[str, List[Optional[Tuple[int, ...]]]],
|
|
|
|
|
):
|
|
|
|
|
self.circuit = circuit
|
|
|
|
|
self.input_types = input_types
|
|
|
|
|
self.output_types = output_types
|
|
|
|
|
self.input_shapes = input_shapes
|
|
|
|
|
self.output_shapes = output_shapes
|
|
|
|
|
if module.function_count == 1:
|
|
|
|
|
self.default_function = next(iter(module.graphs.keys()))
|
|
|
|
|
else:
|
|
|
|
|
self.default_function = None
|
|
|
|
|
self.module = module
|
|
|
|
|
self.input_types_per_func = input_types_per_func
|
|
|
|
|
self.output_types_per_func = output_types_per_func
|
|
|
|
|
self.input_shapes_per_func = input_shapes_per_func
|
|
|
|
|
self.output_shapes_per_func = output_shapes_per_func
|
|
|
|
|
|
|
|
|
|
def _input_type(self, input_idx: int) -> Optional[TFHERSIntegerType]:
|
|
|
|
|
def _get_default_func_or_raise_error(self, calling_func: str) -> str:
|
|
|
|
|
if self.default_function is not None:
|
|
|
|
|
return self.default_function
|
|
|
|
|
else:
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
"Module contains more than one function, so please provide 'func_name' while "
|
|
|
|
|
f"calling '{calling_func}'"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _input_type(self, func_name: str, input_idx: int) -> Optional[TFHERSIntegerType]:
|
|
|
|
|
"""Return the type of a certain input.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
func_name (str): name of the function the input belongs to
|
|
|
|
|
input_idx (int): the input index to get the type of
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Optional[TFHERSIntegerType]: input type. None means a non-tfhers type
|
|
|
|
|
"""
|
|
|
|
|
return self.input_types[input_idx]
|
|
|
|
|
return self.input_types_per_func[func_name][input_idx]
|
|
|
|
|
|
|
|
|
|
def _output_type(self, output_idx: int) -> Optional[TFHERSIntegerType]:
|
|
|
|
|
def _output_type(self, func_name: str, output_idx: int) -> Optional[TFHERSIntegerType]:
|
|
|
|
|
"""Return the type of a certain output.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
func_name (str): name of the function the output belongs to
|
|
|
|
|
output_idx (int): the output index to get the type of
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Optional[TFHERSIntegerType]: output type. None means a non-tfhers type
|
|
|
|
|
"""
|
|
|
|
|
return self.output_types[output_idx]
|
|
|
|
|
return self.output_types_per_func[func_name][output_idx]
|
|
|
|
|
|
|
|
|
|
def _input_shape(self, input_idx: int) -> Optional[Tuple[int, ...]]:
|
|
|
|
|
def _input_shape(self, func_name: str, input_idx: int) -> Optional[Tuple[int, ...]]:
|
|
|
|
|
"""Return the shape of a certain input.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
func_name (str): name of the function the input belongs to
|
|
|
|
|
input_idx (int): the input index to get the shape of
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Optional[Tuple[int, ...]]: input shape. None means a non-tfhers type
|
|
|
|
|
"""
|
|
|
|
|
return self.input_shapes[input_idx]
|
|
|
|
|
return self.input_shapes_per_func[func_name][input_idx]
|
|
|
|
|
|
|
|
|
|
def _output_shape(self, output_idx: int) -> Optional[Tuple[int, ...]]: # pragma: no cover
|
|
|
|
|
def _output_shape(
|
|
|
|
|
self, func_name: str, output_idx: int
|
|
|
|
|
) -> Optional[Tuple[int, ...]]: # pragma: no cover
|
|
|
|
|
"""Return the shape of a certain output.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
func_name (str): name of the function the output belongs to
|
|
|
|
|
output_idx (int): the output index to get the shape of
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Optional[Tuple[int, ...]]: output shape. None means a non-tfhers type
|
|
|
|
|
"""
|
|
|
|
|
return self.output_shapes[output_idx]
|
|
|
|
|
return self.output_shapes_per_func[func_name][output_idx]
|
|
|
|
|
|
|
|
|
|
def _input_keyid(self, input_idx: int) -> int:
|
|
|
|
|
return self.circuit.client.specs.program_info.input_keyid_at(
|
|
|
|
|
input_idx, self.circuit.function_name
|
|
|
|
|
)
|
|
|
|
|
def _input_keyid(self, func_name: str, input_idx: int) -> int:
|
|
|
|
|
return self.module.client.specs.program_info.input_keyid_at(input_idx, func_name)
|
|
|
|
|
|
|
|
|
|
def _input_variance(self, input_idx: int) -> float:
|
|
|
|
|
input_type = self._input_type(input_idx)
|
|
|
|
|
def _input_variance(self, func_name: str, input_idx: int) -> float:
|
|
|
|
|
input_type = self._input_type(func_name, input_idx)
|
|
|
|
|
if input_type is None: # pragma: no cover
|
|
|
|
|
msg = "input at 'input_idx' is not a TFHErs value"
|
|
|
|
|
raise ValueError(msg)
|
|
|
|
|
@@ -133,38 +151,48 @@ class Bridge:
|
|
|
|
|
ks_first,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def import_value(self, buffer: bytes, input_idx: int) -> Value:
|
|
|
|
|
def import_value(self, buffer: bytes, input_idx: int, func_name: Optional[str] = None) -> Value:
|
|
|
|
|
"""Import a serialized TFHErs integer as a Value.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
buffer (bytes): serialized integer
|
|
|
|
|
input_idx (int): the index of the input expecting this value
|
|
|
|
|
func_name (Optional[str]): name of the function the value belongs to.
|
|
|
|
|
Doesn't need to be provided if there is a single function.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
fhe.TransportValue: imported value
|
|
|
|
|
"""
|
|
|
|
|
input_type = self._input_type(input_idx)
|
|
|
|
|
input_shape = self._input_shape(input_idx)
|
|
|
|
|
if func_name is None:
|
|
|
|
|
func_name = self._get_default_func_or_raise_error("import_value")
|
|
|
|
|
|
|
|
|
|
input_type = self._input_type(func_name, input_idx)
|
|
|
|
|
input_shape = self._input_shape(func_name, input_idx)
|
|
|
|
|
if input_type is None or input_shape is None: # pragma: no cover
|
|
|
|
|
msg = "input at 'input_idx' is not a TFHErs value"
|
|
|
|
|
raise ValueError(msg)
|
|
|
|
|
|
|
|
|
|
fheint_desc = self._description_from_type(input_type)
|
|
|
|
|
keyid = self._input_keyid(input_idx)
|
|
|
|
|
variance = self._input_variance(input_idx)
|
|
|
|
|
keyid = self._input_keyid(func_name, input_idx)
|
|
|
|
|
variance = self._input_variance(func_name, input_idx)
|
|
|
|
|
return Value(TfhersExporter.import_int(buffer, fheint_desc, keyid, variance, input_shape))
|
|
|
|
|
|
|
|
|
|
def export_value(self, value: Value, output_idx: int) -> bytes:
|
|
|
|
|
def export_value(self, value: Value, output_idx: int, func_name: Optional[str] = None) -> bytes:
|
|
|
|
|
"""Export a value as a serialized TFHErs integer.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
value (TransportValue): value to export
|
|
|
|
|
output_idx (int): the index corresponding to this output
|
|
|
|
|
func_name (Optional[str]): name of the function the value belongs to.
|
|
|
|
|
Doesn't need to be provided if there is a single function.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
bytes: serialized fheuint8
|
|
|
|
|
"""
|
|
|
|
|
output_type = self._output_type(output_idx)
|
|
|
|
|
if func_name is None:
|
|
|
|
|
func_name = self._get_default_func_or_raise_error("export_value")
|
|
|
|
|
|
|
|
|
|
output_type = self._output_type(func_name, output_idx)
|
|
|
|
|
if output_type is None: # pragma: no cover
|
|
|
|
|
msg = "output at 'output_idx' is not a TFHErs value"
|
|
|
|
|
raise ValueError(msg)
|
|
|
|
|
@@ -174,18 +202,23 @@ class Bridge:
|
|
|
|
|
value._inner, fheint_desc # pylint: disable=protected-access
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def serialize_input_secret_key(self, input_idx: int) -> bytes:
|
|
|
|
|
def serialize_input_secret_key(self, input_idx: int, func_name: Optional[str] = None) -> bytes:
|
|
|
|
|
"""Serialize secret key used for a specific input.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
input_idx (int): input index corresponding to the key to serialize
|
|
|
|
|
func_name (Optional[str]): name of the function the key belongs to.
|
|
|
|
|
Doesn't need to be provided if there is a single function.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
bytes: serialized key
|
|
|
|
|
"""
|
|
|
|
|
keyid = self._input_keyid(input_idx)
|
|
|
|
|
if func_name is None:
|
|
|
|
|
func_name = self._get_default_func_or_raise_error("serialize_input_secret_key")
|
|
|
|
|
|
|
|
|
|
keyid = self._input_keyid(func_name, input_idx)
|
|
|
|
|
# pylint: disable=protected-access
|
|
|
|
|
keys = self.circuit.client.keys
|
|
|
|
|
keys = self.module.client.keys
|
|
|
|
|
assert keys is not None
|
|
|
|
|
secret_key = keys._keyset.get_client_keys().get_secret_keys()[keyid] # type: ignore
|
|
|
|
|
# pylint: enable=protected-access
|
|
|
|
|
@@ -193,7 +226,7 @@ class Bridge:
|
|
|
|
|
|
|
|
|
|
def keygen_with_initial_keys(
|
|
|
|
|
self,
|
|
|
|
|
input_idx_to_key_buffer: Dict[int, bytes],
|
|
|
|
|
input_idx_to_key_buffer: Dict[Union[Tuple[str, int], int], bytes],
|
|
|
|
|
force: bool = False,
|
|
|
|
|
seed: Optional[int] = None,
|
|
|
|
|
encryption_seed: Optional[int] = None,
|
|
|
|
|
@@ -210,30 +243,45 @@ class Bridge:
|
|
|
|
|
encryption_seed (Optional[int], default = None):
|
|
|
|
|
seed for encryption randomness
|
|
|
|
|
|
|
|
|
|
input_idx_to_key_buffer (Dict[int, bytes]): initial keys to set before keygen
|
|
|
|
|
input_idx_to_key_buffer (Dict[Union[Tuple[str, int], int], bytes]):
|
|
|
|
|
initial keys to set before keygen. Two possible formats: the first is when you have
|
|
|
|
|
a single function. Here you can just provide the position of the input as index.
|
|
|
|
|
The second is when you have multiple functions. You will need to provide both the
|
|
|
|
|
name of the function and the input's position as index.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
RuntimeError: if failed to deserialize the key
|
|
|
|
|
"""
|
|
|
|
|
initial_keys: Dict[int, LweSecretKey] = {}
|
|
|
|
|
for input_idx in input_idx_to_key_buffer:
|
|
|
|
|
key_id = self._input_keyid(input_idx)
|
|
|
|
|
for idx in input_idx_to_key_buffer:
|
|
|
|
|
if isinstance(idx, tuple):
|
|
|
|
|
func_name, input_idx = idx
|
|
|
|
|
elif isinstance(idx, int) and self.default_function is not None:
|
|
|
|
|
input_idx = idx
|
|
|
|
|
func_name = self.default_function
|
|
|
|
|
else:
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
"Module contains more than one function, so please make sure to mention "
|
|
|
|
|
"the function name (not just the position) in input_idx_to_key_buffer. "
|
|
|
|
|
"An example index would be a tuple ('my_func', 1)."
|
|
|
|
|
)
|
|
|
|
|
key_id = self._input_keyid(func_name, input_idx)
|
|
|
|
|
# no need to deserialize the same key again
|
|
|
|
|
if key_id in initial_keys: # pragma: no cover
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
key_buffer = input_idx_to_key_buffer[input_idx]
|
|
|
|
|
param = self.circuit.client.specs.program_info.get_keyset_info().secret_keys()[key_id]
|
|
|
|
|
key_buffer = input_idx_to_key_buffer[idx]
|
|
|
|
|
param = self.module.client.specs.program_info.get_keyset_info().secret_keys()[key_id]
|
|
|
|
|
try:
|
|
|
|
|
initial_keys[key_id] = LweSecretKey.deserialize(key_buffer, param)
|
|
|
|
|
except Exception as e: # pragma: no cover
|
|
|
|
|
msg = (
|
|
|
|
|
f"failed deserializing key for input with index {input_idx}. Make sure the key"
|
|
|
|
|
f"failed deserializing key for input with index {idx}. Make sure the key"
|
|
|
|
|
" is for the right input"
|
|
|
|
|
)
|
|
|
|
|
raise RuntimeError(msg) from e
|
|
|
|
|
|
|
|
|
|
self.circuit.keygen(
|
|
|
|
|
self.module.keygen(
|
|
|
|
|
force=force,
|
|
|
|
|
seed=seed,
|
|
|
|
|
encryption_seed=encryption_seed,
|
|
|
|
|
@@ -241,33 +289,57 @@ class Bridge:
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def new_bridge(circuit: "fhe.Circuit") -> Bridge:
|
|
|
|
|
"""Create a TFHErs bridge from a circuit.
|
|
|
|
|
def new_bridge(circuit_or_module: Union["fhe.Circuit", "fhe.Module"]) -> Bridge:
|
|
|
|
|
"""Create a TFHErs bridge from a circuit or module.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
circuit (Circuit): compiled circuit
|
|
|
|
|
circuit (Union[Circuit, Module]): compiled circuit or module
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Bridge: TFHErs bridge
|
|
|
|
|
"""
|
|
|
|
|
input_types: List[Optional[TFHERSIntegerType]] = []
|
|
|
|
|
input_shapes: List[Optional[Tuple[int, ...]]] = []
|
|
|
|
|
for input_node in circuit.graph.ordered_inputs():
|
|
|
|
|
if isinstance(input_node.output.dtype, TFHERSIntegerType):
|
|
|
|
|
input_types.append(input_node.output.dtype)
|
|
|
|
|
input_shapes.append(input_node.output.shape)
|
|
|
|
|
else:
|
|
|
|
|
input_types.append(None)
|
|
|
|
|
input_shapes.append(None)
|
|
|
|
|
if isinstance(circuit_or_module, fhe.Module):
|
|
|
|
|
module = circuit_or_module
|
|
|
|
|
else:
|
|
|
|
|
assert isinstance(circuit_or_module, fhe.Circuit)
|
|
|
|
|
module = circuit_or_module._module
|
|
|
|
|
|
|
|
|
|
output_types: List[Optional[TFHERSIntegerType]] = []
|
|
|
|
|
output_shapes: List[Optional[Tuple[int, ...]]] = []
|
|
|
|
|
for output_node in circuit.graph.ordered_outputs():
|
|
|
|
|
if isinstance(output_node.output.dtype, TFHERSIntegerType):
|
|
|
|
|
output_types.append(output_node.output.dtype)
|
|
|
|
|
output_shapes.append(output_node.output.shape)
|
|
|
|
|
else: # pragma: no cover
|
|
|
|
|
output_types.append(None)
|
|
|
|
|
output_shapes.append(None)
|
|
|
|
|
input_types_per_func = {}
|
|
|
|
|
output_types_per_func = {}
|
|
|
|
|
input_shapes_per_func = {}
|
|
|
|
|
output_shapes_per_func = {}
|
|
|
|
|
|
|
|
|
|
return Bridge(circuit, input_types, output_types, input_shapes, output_shapes)
|
|
|
|
|
for func_name, graph in module.graphs.items():
|
|
|
|
|
input_types: List[Optional[TFHERSIntegerType]] = []
|
|
|
|
|
input_shapes: List[Optional[Tuple[int, ...]]] = []
|
|
|
|
|
for input_node in graph.ordered_inputs():
|
|
|
|
|
if isinstance(input_node.output.dtype, TFHERSIntegerType):
|
|
|
|
|
input_types.append(input_node.output.dtype)
|
|
|
|
|
input_shapes.append(input_node.output.shape)
|
|
|
|
|
else:
|
|
|
|
|
input_types.append(None)
|
|
|
|
|
input_shapes.append(None)
|
|
|
|
|
|
|
|
|
|
input_types_per_func[func_name] = input_types
|
|
|
|
|
input_shapes_per_func[func_name] = input_shapes
|
|
|
|
|
|
|
|
|
|
output_types: List[Optional[TFHERSIntegerType]] = []
|
|
|
|
|
output_shapes: List[Optional[Tuple[int, ...]]] = []
|
|
|
|
|
for output_node in graph.ordered_outputs():
|
|
|
|
|
if isinstance(output_node.output.dtype, TFHERSIntegerType):
|
|
|
|
|
output_types.append(output_node.output.dtype)
|
|
|
|
|
output_shapes.append(output_node.output.shape)
|
|
|
|
|
else: # pragma: no cover
|
|
|
|
|
output_types.append(None)
|
|
|
|
|
output_shapes.append(None)
|
|
|
|
|
|
|
|
|
|
output_types_per_func[func_name] = output_types
|
|
|
|
|
output_shapes_per_func[func_name] = output_shapes
|
|
|
|
|
|
|
|
|
|
return Bridge(
|
|
|
|
|
module,
|
|
|
|
|
input_types_per_func,
|
|
|
|
|
output_types_per_func,
|
|
|
|
|
input_shapes_per_func,
|
|
|
|
|
output_shapes_per_func,
|
|
|
|
|
)
|
|
|
|
|
|