diff --git a/frontends/concrete-python/concrete/fhe/tfhers/bridge.py b/frontends/concrete-python/concrete/fhe/tfhers/bridge.py index b6b04a1b1..1738141a5 100644 --- a/frontends/concrete-python/concrete/fhe/tfhers/bridge.py +++ b/frontends/concrete-python/concrete/fhe/tfhers/bridge.py @@ -12,18 +12,24 @@ from .dtypes import EncryptionKeyChoice, TFHERSIntegerType class Bridge: - """TFHErs Bridge extend a Circuit with TFHErs functionalities.""" + """TFHErs Bridge extend a Circuit with TFHErs functionalities. + + input_types (List[Optional[TFHERSIntegerType]]): maps every input to a type. None means + a non-tfhers type + output_types (List[Optional[TFHERSIntegerType]]): maps every output to a type. None means + a non-tfhers type + """ circuit: "fhe.Circuit" - input_types: Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]] - output_types: Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]] + input_types: List[Optional[TFHERSIntegerType]] + output_types: List[Optional[TFHERSIntegerType]] func_name: str def __init__( self, circuit: "fhe.Circuit", - input_types: Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]], - output_types: Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]], + input_types: List[Optional[TFHERSIntegerType]], + output_types: List[Optional[TFHERSIntegerType]], func_name: str, ): self.circuit = circuit @@ -40,9 +46,7 @@ class Bridge: Returns: Optional[TFHERSIntegerType]: input type. None means a non-tfhers type """ - if isinstance(self.input_types, list): - return self.input_types[input_idx] # pragma: no cover - return self.input_types + return self.input_types[input_idx] def _output_type(self, output_idx: int) -> Optional[TFHERSIntegerType]: """Return the type of a certain output. @@ -53,9 +57,7 @@ class Bridge: Returns: Optional[TFHERSIntegerType]: output type. None means a non-tfhers type """ - if isinstance(self.output_types, list): - return self.output_types[output_idx] # pragma: no cover - return self.output_types + return self.output_types[output_idx] def _input_keyid(self, input_idx: int) -> int: return self.circuit.client.specs.client_parameters.input_keyid_at(input_idx, self.func_name) @@ -230,23 +232,32 @@ class Bridge: def new_bridge( circuit: "fhe.Circuit", - input_types: Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]], - output_types: Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]], - func_name, + func_name: str = "main", ) -> Bridge: """Create a TFHErs bridge from a circuit. Args: circuit (Circuit): compiled circuit - input_types (Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]]): lists - should map every input to a type, while a single element is general for all inputs. - None means a non-tfhers type - output_types (Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]]): lists - should map every output to a type, while a single element is general for all outputs. - None means a non-tfhers type - func_name (str): name of the function to use. + func_name (str, optional): name of the function to use. Defaults to "main". Returns: Bridge: TFHErs bridge """ + input_types = [ + ( + input_node.output.dtype + if isinstance(input_node.output.dtype, TFHERSIntegerType) + else None + ) + for input_node in circuit.graph.ordered_inputs() + ] + output_types = [ + ( + output_node.output.dtype + if isinstance(output_node.output.dtype, TFHERSIntegerType) + else None + ) + for output_node in circuit.graph.ordered_outputs() + ] + return Bridge(circuit, input_types, output_types, func_name) diff --git a/frontends/concrete-python/tests/execution/test_tfhers.py b/frontends/concrete-python/tests/execution/test_tfhers.py index 676351921..1fdc01875 100644 --- a/frontends/concrete-python/tests/execution/test_tfhers.py +++ b/frontends/concrete-python/tests/execution/test_tfhers.py @@ -381,8 +381,8 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( concrete_encoded_result = circuit.encrypt_run_decrypt(*concrete_encoded_sample) assert (dtype.decode(concrete_encoded_result) == function(*sample)).all() - ###### TFHErs Encryption ###################################################### - tfhers_bridge = tfhers.new_bridge(circuit, dtype, dtype, func_name="") + ###### TFHErs Encryption & Computation ######################################## + tfhers_bridge = tfhers.new_bridge(circuit, func_name="main") # serialize key _, key_path = tempfile.mkstemp() @@ -617,7 +617,7 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_concrete_keygen( assert (dtype.decode(concrete_encoded_result) == function(*sample)).all() ###### TFHErs Encryption ###################################################### - tfhers_bridge = tfhers.new_bridge(circuit, dtype, dtype, func_name="") + tfhers_bridge = tfhers.new_bridge(circuit, func_name="main") # serialize key _, key_path = tempfile.mkstemp() @@ -780,7 +780,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( ) ###### Concrete Keygen ######################################################## - tfhers_bridge = tfhers.new_bridge(circuit, dtype, dtype, func_name="main") + tfhers_bridge = tfhers.new_bridge(circuit, func_name="main") with open(sk_path, "rb") as f: sk_buff = f.read() @@ -1044,7 +1044,7 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_tfhers_keygen( ) ###### Concrete Keygen ######################################################## - tfhers_bridge = tfhers.new_bridge(circuit, dtype, dtype, func_name="main") + tfhers_bridge = tfhers.new_bridge(circuit, func_name="main") with open(sk_path, "rb") as f: sk_buff = f.read()