mirror of
https://github.com/zama-ai/concrete.git
synced 2026-01-09 21:07:56 -05:00
refactor(frontend): remove the need to pass the in/out types in Bridge
This commit is contained in:
committed by
Quentin Bourgerie
parent
eb72bbc53c
commit
d63ff516af
@@ -12,18 +12,24 @@ from .dtypes import EncryptionKeyChoice, TFHERSIntegerType
|
||||
|
||||
|
||||
class Bridge:
|
||||
"""TFHErs Bridge extend a Circuit with TFHErs functionalities."""
|
||||
"""TFHErs Bridge extend a Circuit with TFHErs functionalities.
|
||||
|
||||
input_types (List[Optional[TFHERSIntegerType]]): maps every input to a type. None means
|
||||
a non-tfhers type
|
||||
output_types (List[Optional[TFHERSIntegerType]]): maps every output to a type. None means
|
||||
a non-tfhers type
|
||||
"""
|
||||
|
||||
circuit: "fhe.Circuit"
|
||||
input_types: Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]]
|
||||
output_types: Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]]
|
||||
input_types: List[Optional[TFHERSIntegerType]]
|
||||
output_types: List[Optional[TFHERSIntegerType]]
|
||||
func_name: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
circuit: "fhe.Circuit",
|
||||
input_types: Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]],
|
||||
output_types: Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]],
|
||||
input_types: List[Optional[TFHERSIntegerType]],
|
||||
output_types: List[Optional[TFHERSIntegerType]],
|
||||
func_name: str,
|
||||
):
|
||||
self.circuit = circuit
|
||||
@@ -40,9 +46,7 @@ class Bridge:
|
||||
Returns:
|
||||
Optional[TFHERSIntegerType]: input type. None means a non-tfhers type
|
||||
"""
|
||||
if isinstance(self.input_types, list):
|
||||
return self.input_types[input_idx] # pragma: no cover
|
||||
return self.input_types
|
||||
return self.input_types[input_idx]
|
||||
|
||||
def _output_type(self, output_idx: int) -> Optional[TFHERSIntegerType]:
|
||||
"""Return the type of a certain output.
|
||||
@@ -53,9 +57,7 @@ class Bridge:
|
||||
Returns:
|
||||
Optional[TFHERSIntegerType]: output type. None means a non-tfhers type
|
||||
"""
|
||||
if isinstance(self.output_types, list):
|
||||
return self.output_types[output_idx] # pragma: no cover
|
||||
return self.output_types
|
||||
return self.output_types[output_idx]
|
||||
|
||||
def _input_keyid(self, input_idx: int) -> int:
|
||||
return self.circuit.client.specs.client_parameters.input_keyid_at(input_idx, self.func_name)
|
||||
@@ -230,23 +232,32 @@ class Bridge:
|
||||
|
||||
def new_bridge(
|
||||
circuit: "fhe.Circuit",
|
||||
input_types: Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]],
|
||||
output_types: Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]],
|
||||
func_name,
|
||||
func_name: str = "main",
|
||||
) -> Bridge:
|
||||
"""Create a TFHErs bridge from a circuit.
|
||||
|
||||
Args:
|
||||
circuit (Circuit): compiled circuit
|
||||
input_types (Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]]): lists
|
||||
should map every input to a type, while a single element is general for all inputs.
|
||||
None means a non-tfhers type
|
||||
output_types (Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]]): lists
|
||||
should map every output to a type, while a single element is general for all outputs.
|
||||
None means a non-tfhers type
|
||||
func_name (str): name of the function to use.
|
||||
func_name (str, optional): name of the function to use. Defaults to "main".
|
||||
|
||||
Returns:
|
||||
Bridge: TFHErs bridge
|
||||
"""
|
||||
input_types = [
|
||||
(
|
||||
input_node.output.dtype
|
||||
if isinstance(input_node.output.dtype, TFHERSIntegerType)
|
||||
else None
|
||||
)
|
||||
for input_node in circuit.graph.ordered_inputs()
|
||||
]
|
||||
output_types = [
|
||||
(
|
||||
output_node.output.dtype
|
||||
if isinstance(output_node.output.dtype, TFHERSIntegerType)
|
||||
else None
|
||||
)
|
||||
for output_node in circuit.graph.ordered_outputs()
|
||||
]
|
||||
|
||||
return Bridge(circuit, input_types, output_types, func_name)
|
||||
|
||||
@@ -381,8 +381,8 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen(
|
||||
concrete_encoded_result = circuit.encrypt_run_decrypt(*concrete_encoded_sample)
|
||||
assert (dtype.decode(concrete_encoded_result) == function(*sample)).all()
|
||||
|
||||
###### TFHErs Encryption ######################################################
|
||||
tfhers_bridge = tfhers.new_bridge(circuit, dtype, dtype, func_name="<lambda>")
|
||||
###### TFHErs Encryption & Computation ########################################
|
||||
tfhers_bridge = tfhers.new_bridge(circuit, func_name="main")
|
||||
|
||||
# serialize key
|
||||
_, key_path = tempfile.mkstemp()
|
||||
@@ -617,7 +617,7 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_concrete_keygen(
|
||||
assert (dtype.decode(concrete_encoded_result) == function(*sample)).all()
|
||||
|
||||
###### TFHErs Encryption ######################################################
|
||||
tfhers_bridge = tfhers.new_bridge(circuit, dtype, dtype, func_name="<lambda>")
|
||||
tfhers_bridge = tfhers.new_bridge(circuit, func_name="main")
|
||||
|
||||
# serialize key
|
||||
_, key_path = tempfile.mkstemp()
|
||||
@@ -780,7 +780,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen(
|
||||
)
|
||||
|
||||
###### Concrete Keygen ########################################################
|
||||
tfhers_bridge = tfhers.new_bridge(circuit, dtype, dtype, func_name="main")
|
||||
tfhers_bridge = tfhers.new_bridge(circuit, func_name="main")
|
||||
|
||||
with open(sk_path, "rb") as f:
|
||||
sk_buff = f.read()
|
||||
@@ -1044,7 +1044,7 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_tfhers_keygen(
|
||||
)
|
||||
|
||||
###### Concrete Keygen ########################################################
|
||||
tfhers_bridge = tfhers.new_bridge(circuit, dtype, dtype, func_name="main")
|
||||
tfhers_bridge = tfhers.new_bridge(circuit, func_name="main")
|
||||
|
||||
with open(sk_path, "rb") as f:
|
||||
sk_buff = f.read()
|
||||
|
||||
Reference in New Issue
Block a user