refactor(frontend): use func_name from circuit

This commit is contained in:
youben11
2024-09-26 08:45:54 +01:00
committed by Quentin Bourgerie
parent 11bf8d9066
commit 2d341aaf81
2 changed files with 13 additions and 11 deletions

View File

@@ -39,6 +39,10 @@ class Circuit:
def _function(self) -> FheFunction:
return getattr(self._module, self._name)
@property
def function_name(self) -> str:
return self._name
def __str__(self):
return self._function.graph.format()
@@ -148,7 +152,10 @@ class Circuit:
initial keys to set before keygen
"""
self._module.keygen(
force=force, seed=seed, encryption_seed=encryption_seed, initial_keys=initial_keys
force=force,
seed=seed,
encryption_seed=encryption_seed,
initial_keys=initial_keys,
)
def encrypt(

View File

@@ -23,19 +23,16 @@ class Bridge:
circuit: "fhe.Circuit"
input_types: List[Optional[TFHERSIntegerType]]
output_types: List[Optional[TFHERSIntegerType]]
func_name: str
def __init__(
self,
circuit: "fhe.Circuit",
input_types: List[Optional[TFHERSIntegerType]],
output_types: List[Optional[TFHERSIntegerType]],
func_name: str,
):
self.circuit = circuit
self.input_types = input_types
self.output_types = output_types
self.func_name = func_name
def _input_type(self, input_idx: int) -> Optional[TFHERSIntegerType]:
"""Return the type of a certain input.
@@ -60,7 +57,9 @@ class Bridge:
return self.output_types[output_idx]
def _input_keyid(self, input_idx: int) -> int:
return self.circuit.client.specs.client_parameters.input_keyid_at(input_idx, self.func_name)
return self.circuit.client.specs.client_parameters.input_keyid_at(
input_idx, self.circuit.function_name
)
def _input_variance(self, input_idx: int) -> float:
input_type = self._input_type(input_idx)
@@ -230,15 +229,11 @@ class Bridge:
)
def new_bridge(
circuit: "fhe.Circuit",
func_name: str = "<lambda>",
) -> Bridge:
def new_bridge(circuit: "fhe.Circuit") -> Bridge:
"""Create a TFHErs bridge from a circuit.
Args:
circuit (Circuit): compiled circuit
func_name (str, optional): name of the function to use. Defaults to "main".
Returns:
Bridge: TFHErs bridge
@@ -260,4 +255,4 @@ def new_bridge(
for output_node in circuit.graph.ordered_outputs()
]
return Bridge(circuit, input_types, output_types, func_name)
return Bridge(circuit, input_types, output_types)