mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: add input signs to client specs
This commit is contained in:
@@ -39,6 +39,13 @@ class Circuit:
|
||||
assert_that(self.configuration.enable_unsafe_features)
|
||||
return
|
||||
|
||||
input_signs = []
|
||||
for i in range(len(graph.input_nodes)): # pylint: disable=consider-using-enumerate
|
||||
input_value = graph.input_nodes[i].output
|
||||
assert_that(isinstance(input_value.dtype, Integer))
|
||||
input_dtype = cast(Integer, input_value.dtype)
|
||||
input_signs.append(input_dtype.is_signed)
|
||||
|
||||
output_signs = []
|
||||
for i in range(len(graph.output_nodes)): # pylint: disable=consider-using-enumerate
|
||||
output_value = graph.output_nodes[i].output
|
||||
@@ -46,7 +53,7 @@ class Circuit:
|
||||
output_dtype = cast(Integer, output_value.dtype)
|
||||
output_signs.append(output_dtype.is_signed)
|
||||
|
||||
self.server = Server.create(mlir, output_signs, self.configuration)
|
||||
self.server = Server.create(mlir, input_signs, output_signs, self.configuration)
|
||||
|
||||
keyset_cache_directory = None
|
||||
if self.configuration.use_insecure_key_cache:
|
||||
|
||||
@@ -58,7 +58,12 @@ class Server:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create(mlir: str, output_signs: List[bool], configuration: Configuration) -> "Server":
|
||||
def create(
|
||||
mlir: str,
|
||||
input_signs: List[bool],
|
||||
output_signs: List[bool],
|
||||
configuration: Configuration,
|
||||
) -> "Server":
|
||||
"""
|
||||
Create a server using MLIR and output sign information.
|
||||
|
||||
@@ -66,6 +71,9 @@ class Server:
|
||||
mlir (str):
|
||||
mlir to compile
|
||||
|
||||
input_signs (List[bool]):
|
||||
sign status of the inputs
|
||||
|
||||
output_signs (List[bool]):
|
||||
sign status of the outputs
|
||||
|
||||
@@ -101,7 +109,8 @@ class Server:
|
||||
compilation_result = support.compile(mlir, options)
|
||||
server_lambda = support.load_server_lambda(compilation_result)
|
||||
|
||||
client_specs = ClientSpecs(support.load_client_parameters(compilation_result), output_signs)
|
||||
client_parameters = support.load_client_parameters(compilation_result)
|
||||
client_specs = ClientSpecs(input_signs, client_parameters, output_signs)
|
||||
return Server(client_specs, output_dir, support, compilation_result, server_lambda)
|
||||
|
||||
def save(self, path: Union[str, Path]):
|
||||
|
||||
@@ -13,10 +13,17 @@ class ClientSpecs:
|
||||
ClientSpecs class, to create Client objects.
|
||||
"""
|
||||
|
||||
input_signs: List[bool]
|
||||
client_parameters: ClientParameters
|
||||
output_signs: List[bool]
|
||||
|
||||
def __init__(self, client_parameters: ClientParameters, output_signs: List[bool]):
|
||||
def __init__(
|
||||
self,
|
||||
input_signs: List[bool],
|
||||
client_parameters: ClientParameters,
|
||||
output_signs: List[bool],
|
||||
):
|
||||
self.input_signs = input_signs
|
||||
self.client_parameters = client_parameters
|
||||
self.output_signs = output_signs
|
||||
|
||||
@@ -32,6 +39,7 @@ class ClientSpecs:
|
||||
client_parameters_json = json.loads(self.client_parameters.serialize())
|
||||
return json.dumps(
|
||||
{
|
||||
"input_signs": self.input_signs,
|
||||
"client_parameters": client_parameters_json,
|
||||
"output_signs": self.output_signs,
|
||||
}
|
||||
@@ -56,7 +64,7 @@ class ClientSpecs:
|
||||
client_parameters_bytes = json.dumps(raw_specs["client_parameters"]).encode("utf-8")
|
||||
client_parameters = ClientParameters.unserialize(client_parameters_bytes)
|
||||
|
||||
return ClientSpecs(client_parameters, raw_specs["output_signs"])
|
||||
return ClientSpecs(raw_specs["input_signs"], client_parameters, raw_specs["output_signs"])
|
||||
|
||||
def serialize_public_args(self, args: PublicArguments) -> bytes: # pylint: disable=no-self-use
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user