feat: add input signs to client specs

This commit is contained in:
Umut
2022-05-30 13:00:19 +02:00
parent 8a60a979cb
commit 4010fc0cbd
3 changed files with 29 additions and 5 deletions

View File

@@ -39,6 +39,13 @@ class Circuit:
assert_that(self.configuration.enable_unsafe_features)
return
input_signs = []
for i in range(len(graph.input_nodes)): # pylint: disable=consider-using-enumerate
input_value = graph.input_nodes[i].output
assert_that(isinstance(input_value.dtype, Integer))
input_dtype = cast(Integer, input_value.dtype)
input_signs.append(input_dtype.is_signed)
output_signs = []
for i in range(len(graph.output_nodes)): # pylint: disable=consider-using-enumerate
output_value = graph.output_nodes[i].output
@@ -46,7 +53,7 @@ class Circuit:
output_dtype = cast(Integer, output_value.dtype)
output_signs.append(output_dtype.is_signed)
self.server = Server.create(mlir, output_signs, self.configuration)
self.server = Server.create(mlir, input_signs, output_signs, self.configuration)
keyset_cache_directory = None
if self.configuration.use_insecure_key_cache:

View File

@@ -58,7 +58,12 @@ class Server:
)
@staticmethod
def create(mlir: str, output_signs: List[bool], configuration: Configuration) -> "Server":
def create(
mlir: str,
input_signs: List[bool],
output_signs: List[bool],
configuration: Configuration,
) -> "Server":
"""
Create a server using MLIR and output sign information.
@@ -66,6 +71,9 @@ class Server:
mlir (str):
mlir to compile
input_signs (List[bool]):
sign status of the inputs
output_signs (List[bool]):
sign status of the outputs
@@ -101,7 +109,8 @@ class Server:
compilation_result = support.compile(mlir, options)
server_lambda = support.load_server_lambda(compilation_result)
client_specs = ClientSpecs(support.load_client_parameters(compilation_result), output_signs)
client_parameters = support.load_client_parameters(compilation_result)
client_specs = ClientSpecs(input_signs, client_parameters, output_signs)
return Server(client_specs, output_dir, support, compilation_result, server_lambda)
def save(self, path: Union[str, Path]):

View File

@@ -13,10 +13,17 @@ class ClientSpecs:
ClientSpecs class, to create Client objects.
"""
input_signs: List[bool]
client_parameters: ClientParameters
output_signs: List[bool]
def __init__(self, client_parameters: ClientParameters, output_signs: List[bool]):
def __init__(
self,
input_signs: List[bool],
client_parameters: ClientParameters,
output_signs: List[bool],
):
self.input_signs = input_signs
self.client_parameters = client_parameters
self.output_signs = output_signs
@@ -32,6 +39,7 @@ class ClientSpecs:
client_parameters_json = json.loads(self.client_parameters.serialize())
return json.dumps(
{
"input_signs": self.input_signs,
"client_parameters": client_parameters_json,
"output_signs": self.output_signs,
}
@@ -56,7 +64,7 @@ class ClientSpecs:
client_parameters_bytes = json.dumps(raw_specs["client_parameters"]).encode("utf-8")
client_parameters = ClientParameters.unserialize(client_parameters_bytes)
return ClientSpecs(client_parameters, raw_specs["output_signs"])
return ClientSpecs(raw_specs["input_signs"], client_parameters, raw_specs["output_signs"])
def serialize_public_args(self, args: PublicArguments) -> bytes: # pylint: disable=no-self-use
"""