mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(frontend-python): support ctrl+c during fhe execution
This commit is contained in:
@@ -34,6 +34,7 @@ from .configuration import (
|
||||
ParameterSelectionStrategy,
|
||||
)
|
||||
from .specs import ClientSpecs
|
||||
from .utils import interruptable_native_call
|
||||
from .value import Value
|
||||
|
||||
# pylint: enable=import-error,no-member,no-name-in-module
|
||||
@@ -298,7 +299,9 @@ class Server:
|
||||
buffers.append(arg.inner)
|
||||
|
||||
public_args = PublicArguments.new(self.client_specs.client_parameters, buffers)
|
||||
public_result = self._support.server_call(self._server_lambda, public_args, evaluation_keys)
|
||||
public_result = interruptable_native_call(
|
||||
lambda: self._support.server_call(self._server_lambda, public_args, evaluation_keys)
|
||||
)
|
||||
|
||||
result = tuple(Value(public_result.get_value(i)) for i in range(public_result.n_values()))
|
||||
return result if len(result) > 1 else result[0]
|
||||
|
||||
@@ -2,7 +2,12 @@
|
||||
Declaration of various functions and constants related to compilation.
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
@@ -15,6 +20,40 @@ from .artifacts import DebugArtifacts
|
||||
# ruff: noqa: ERA001
|
||||
|
||||
|
||||
def print_ctrl_c_message(): # pragma: no cover
|
||||
"""
|
||||
Print exit message for CTRL+C.
|
||||
"""
|
||||
|
||||
print()
|
||||
print("The computation will be aborted in a few seconds.")
|
||||
print("You can force an immediate abort by pressing Ctrl+C again.")
|
||||
|
||||
|
||||
def interruptable_native_call(f):
|
||||
"""
|
||||
Run a native function `f` in a thread to support interrupts.
|
||||
|
||||
Note that `f` must release the GIL to make it work.
|
||||
"""
|
||||
|
||||
atexit.register(print_ctrl_c_message)
|
||||
|
||||
executor = ThreadPoolExecutor(1)
|
||||
try:
|
||||
f_thread = executor.submit(f)
|
||||
return f_thread.result()
|
||||
except KeyboardInterrupt as error: # pragma: no cover
|
||||
pid = os.getpid()
|
||||
|
||||
def wait_and_ctrl_c():
|
||||
time.sleep(0.2) # wait so atexit can progress to the executor
|
||||
os.kill(pid, signal.SIGINT)
|
||||
|
||||
ThreadPoolExecutor(1).submit(wait_and_ctrl_c)
|
||||
raise error
|
||||
|
||||
|
||||
def fuse(graph: Graph, artifacts: Optional[DebugArtifacts] = None):
|
||||
"""
|
||||
Fuse appropriate subgraphs in a graph to a single Operation.Generic node.
|
||||
|
||||
Reference in New Issue
Block a user