refactor(frontend): remove the need to pass the in/out types in Bridge

This commit is contained in:
youben11
2024-09-11 12:37:55 +01:00
committed by Quentin Bourgerie
parent eb72bbc53c
commit d63ff516af
2 changed files with 37 additions and 26 deletions

View File

@@ -12,18 +12,24 @@ from .dtypes import EncryptionKeyChoice, TFHERSIntegerType
class Bridge:
"""TFHErs Bridge extend a Circuit with TFHErs functionalities."""
"""TFHErs Bridge extend a Circuit with TFHErs functionalities.
input_types (List[Optional[TFHERSIntegerType]]): maps every input to a type. None means
a non-tfhers type
output_types (List[Optional[TFHERSIntegerType]]): maps every output to a type. None means
a non-tfhers type
"""
circuit: "fhe.Circuit"
input_types: Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]]
output_types: Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]]
input_types: List[Optional[TFHERSIntegerType]]
output_types: List[Optional[TFHERSIntegerType]]
func_name: str
def __init__(
self,
circuit: "fhe.Circuit",
input_types: Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]],
output_types: Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]],
input_types: List[Optional[TFHERSIntegerType]],
output_types: List[Optional[TFHERSIntegerType]],
func_name: str,
):
self.circuit = circuit
@@ -40,9 +46,7 @@ class Bridge:
Returns:
Optional[TFHERSIntegerType]: input type. None means a non-tfhers type
"""
if isinstance(self.input_types, list):
return self.input_types[input_idx] # pragma: no cover
return self.input_types
return self.input_types[input_idx]
def _output_type(self, output_idx: int) -> Optional[TFHERSIntegerType]:
"""Return the type of a certain output.
@@ -53,9 +57,7 @@ class Bridge:
Returns:
Optional[TFHERSIntegerType]: output type. None means a non-tfhers type
"""
if isinstance(self.output_types, list):
return self.output_types[output_idx] # pragma: no cover
return self.output_types
return self.output_types[output_idx]
def _input_keyid(self, input_idx: int) -> int:
return self.circuit.client.specs.client_parameters.input_keyid_at(input_idx, self.func_name)
@@ -230,23 +232,32 @@ class Bridge:
def new_bridge(
circuit: "fhe.Circuit",
input_types: Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]],
output_types: Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]],
func_name,
func_name: str = "main",
) -> Bridge:
"""Create a TFHErs bridge from a circuit.
Args:
circuit (Circuit): compiled circuit
input_types (Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]]): lists
should map every input to a type, while a single element is general for all inputs.
None means a non-tfhers type
output_types (Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]]): lists
should map every output to a type, while a single element is general for all outputs.
None means a non-tfhers type
func_name (str): name of the function to use.
func_name (str, optional): name of the function to use. Defaults to "main".
Returns:
Bridge: TFHErs bridge
"""
input_types = [
(
input_node.output.dtype
if isinstance(input_node.output.dtype, TFHERSIntegerType)
else None
)
for input_node in circuit.graph.ordered_inputs()
]
output_types = [
(
output_node.output.dtype
if isinstance(output_node.output.dtype, TFHERSIntegerType)
else None
)
for output_node in circuit.graph.ordered_outputs()
]
return Bridge(circuit, input_types, output_types, func_name)

View File

@@ -381,8 +381,8 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen(
concrete_encoded_result = circuit.encrypt_run_decrypt(*concrete_encoded_sample)
assert (dtype.decode(concrete_encoded_result) == function(*sample)).all()
###### TFHErs Encryption ######################################################
tfhers_bridge = tfhers.new_bridge(circuit, dtype, dtype, func_name="<lambda>")
###### TFHErs Encryption & Computation ########################################
tfhers_bridge = tfhers.new_bridge(circuit, func_name="main")
# serialize key
_, key_path = tempfile.mkstemp()
@@ -617,7 +617,7 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_concrete_keygen(
assert (dtype.decode(concrete_encoded_result) == function(*sample)).all()
###### TFHErs Encryption ######################################################
tfhers_bridge = tfhers.new_bridge(circuit, dtype, dtype, func_name="<lambda>")
tfhers_bridge = tfhers.new_bridge(circuit, func_name="main")
# serialize key
_, key_path = tempfile.mkstemp()
@@ -780,7 +780,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen(
)
###### Concrete Keygen ########################################################
tfhers_bridge = tfhers.new_bridge(circuit, dtype, dtype, func_name="main")
tfhers_bridge = tfhers.new_bridge(circuit, func_name="main")
with open(sk_path, "rb") as f:
sk_buff = f.read()
@@ -1044,7 +1044,7 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_tfhers_keygen(
)
###### Concrete Keygen ########################################################
tfhers_bridge = tfhers.new_bridge(circuit, dtype, dtype, func_name="main")
tfhers_bridge = tfhers.new_bridge(circuit, func_name="main")
with open(sk_path, "rb") as f:
sk_buff = f.read()