feat(frontend-python): support ctrl+c during fhe execution

This commit is contained in:
Umut
2023-06-26 10:17:40 +02:00
parent aff3d91278
commit 87d460e9ec
2 changed files with 43 additions and 1 deletions

View File

@@ -34,6 +34,7 @@ from .configuration import (
ParameterSelectionStrategy,
)
from .specs import ClientSpecs
from .utils import interruptable_native_call
from .value import Value
# pylint: enable=import-error,no-member,no-name-in-module
@@ -298,7 +299,9 @@ class Server:
buffers.append(arg.inner)
public_args = PublicArguments.new(self.client_specs.client_parameters, buffers)
public_result = self._support.server_call(self._server_lambda, public_args, evaluation_keys)
public_result = interruptable_native_call(
lambda: self._support.server_call(self._server_lambda, public_args, evaluation_keys)
)
result = tuple(Value(public_result.get_value(i)) for i in range(public_result.n_values()))
return result if len(result) > 1 else result[0]

View File

@@ -2,7 +2,12 @@
Declaration of various functions and constants related to compilation.
"""
import atexit
import os
import re
import signal
import time
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from typing import Dict, Iterable, List, Optional, Set, Tuple
@@ -15,6 +20,40 @@ from .artifacts import DebugArtifacts
# ruff: noqa: ERA001
def print_ctrl_c_message(): # pragma: no cover
"""
Print exit message for CTRL+C.
"""
print()
print("The computation will be aborted in a few seconds.")
print("You can force an immediate abort by pressing Ctrl+C again.")
def interruptable_native_call(f):
"""
Run a native function `f` in a thread to support interrupts.
Note that `f` must release the GIL to make it work.
"""
atexit.register(print_ctrl_c_message)
executor = ThreadPoolExecutor(1)
try:
f_thread = executor.submit(f)
return f_thread.result()
except KeyboardInterrupt as error: # pragma: no cover
pid = os.getpid()
def wait_and_ctrl_c():
time.sleep(0.2) # wait so atexit can progress to the executor
os.kill(pid, signal.SIGINT)
ThreadPoolExecutor(1).submit(wait_and_ctrl_c)
raise error
def fuse(graph: Graph, artifacts: Optional[DebugArtifacts] = None):
"""
Fuse appropriate subgraphs in a graph to a single Operation.Generic node.