diff --git a/concrete/numpy/compilation/circuit.py b/concrete/numpy/compilation/circuit.py index 3cdd68317..aff8f481f 100644 --- a/concrete/numpy/compilation/circuit.py +++ b/concrete/numpy/compilation/circuit.py @@ -39,6 +39,13 @@ class Circuit: assert_that(self.configuration.enable_unsafe_features) return + input_signs = [] + for i in range(len(graph.input_nodes)): # pylint: disable=consider-using-enumerate + input_value = graph.input_nodes[i].output + assert_that(isinstance(input_value.dtype, Integer)) + input_dtype = cast(Integer, input_value.dtype) + input_signs.append(input_dtype.is_signed) + output_signs = [] for i in range(len(graph.output_nodes)): # pylint: disable=consider-using-enumerate output_value = graph.output_nodes[i].output @@ -46,7 +53,7 @@ class Circuit: output_dtype = cast(Integer, output_value.dtype) output_signs.append(output_dtype.is_signed) - self.server = Server.create(mlir, output_signs, self.configuration) + self.server = Server.create(mlir, input_signs, output_signs, self.configuration) keyset_cache_directory = None if self.configuration.use_insecure_key_cache: diff --git a/concrete/numpy/compilation/server.py b/concrete/numpy/compilation/server.py index 740b51128..f9f7ca1cd 100644 --- a/concrete/numpy/compilation/server.py +++ b/concrete/numpy/compilation/server.py @@ -58,7 +58,12 @@ class Server: ) @staticmethod - def create(mlir: str, output_signs: List[bool], configuration: Configuration) -> "Server": + def create( + mlir: str, + input_signs: List[bool], + output_signs: List[bool], + configuration: Configuration, + ) -> "Server": """ Create a server using MLIR and output sign information. @@ -66,6 +71,9 @@ class Server: mlir (str): mlir to compile + input_signs (List[bool]): + sign status of the inputs + output_signs (List[bool]): sign status of the outputs @@ -101,7 +109,8 @@ class Server: compilation_result = support.compile(mlir, options) server_lambda = support.load_server_lambda(compilation_result) - client_specs = ClientSpecs(support.load_client_parameters(compilation_result), output_signs) + client_parameters = support.load_client_parameters(compilation_result) + client_specs = ClientSpecs(input_signs, client_parameters, output_signs) return Server(client_specs, output_dir, support, compilation_result, server_lambda) def save(self, path: Union[str, Path]): diff --git a/concrete/numpy/compilation/specs.py b/concrete/numpy/compilation/specs.py index 12c26a0bf..1d0c4f54d 100644 --- a/concrete/numpy/compilation/specs.py +++ b/concrete/numpy/compilation/specs.py @@ -13,10 +13,17 @@ class ClientSpecs: ClientSpecs class, to create Client objects. """ + input_signs: List[bool] client_parameters: ClientParameters output_signs: List[bool] - def __init__(self, client_parameters: ClientParameters, output_signs: List[bool]): + def __init__( + self, + input_signs: List[bool], + client_parameters: ClientParameters, + output_signs: List[bool], + ): + self.input_signs = input_signs self.client_parameters = client_parameters self.output_signs = output_signs @@ -32,6 +39,7 @@ class ClientSpecs: client_parameters_json = json.loads(self.client_parameters.serialize()) return json.dumps( { + "input_signs": self.input_signs, "client_parameters": client_parameters_json, "output_signs": self.output_signs, } @@ -56,7 +64,7 @@ class ClientSpecs: client_parameters_bytes = json.dumps(raw_specs["client_parameters"]).encode("utf-8") client_parameters = ClientParameters.unserialize(client_parameters_bytes) - return ClientSpecs(client_parameters, raw_specs["output_signs"]) + return ClientSpecs(raw_specs["input_signs"], client_parameters, raw_specs["output_signs"]) def serialize_public_args(self, args: PublicArguments) -> bytes: # pylint: disable=no-self-use """