Merge pull request #1144 from zama-ai/module-auto-schedule-rebased

feat(frontend-python): module run are scheduled and parallelized in a…
This commit is contained in:
Quentin Bourgerie
2024-12-06 17:35:34 +01:00
committed by GitHub
5 changed files with 237 additions and 10 deletions

View File

@@ -210,3 +210,11 @@ When options are specified both in the `configuration` and as kwargs in the `com
#### verbose: bool = False
- Print details related to compilation.
#### auto_schedule_run: bool = False
- Enable automatic scheduling of `run` method calls. When enabled, fhe function are computated in parallel in a background threads pool. When several `run` are composed, they are automatically synchronized.
- For now, it only works for the `run` method of a `FheModule`, in that case you obtain a `Future[Value]` immediately instead of a `Value` when computation is finished.
- E.g. `my_module.f3.run( my_module.f1.run(a), my_module.f1.run(b) )` will runs `f1` and `f2` in parallel in the background and `f3` in background when both `f1` and `f2` intermediate results are available.
- If you want to manually synchronize on the termination of a full computation, e.g. you want to return the encrypted result, you can call explicitely `value.result()` to wait for the result. To simplify testing, decryption does it automatically.
- Automatic scheduling behavior can be override locally by calling directly a variant of `run`:
- `run_sync`: forces the fhe function to occur in the current thread, not in the background,
- `run_async`: forces the fhe function to occur in a background thread, returning immediately a `Future[Value]`

View File

@@ -195,7 +195,7 @@ class Circuit:
result(s) of evaluation
"""
return self._function.run(*args)
return self._function.run_sync(*args)
def decrypt(
self,

View File

@@ -997,6 +997,7 @@ class Configuration:
composable: bool
range_restriction: Optional[RangeRestriction]
keyset_restriction: Optional[KeysetRestriction]
auto_schedule_run: bool
def __init__(
self,
@@ -1068,6 +1069,7 @@ class Configuration:
simulate_encrypt_run_decrypt: bool = False,
range_restriction: Optional[RangeRestriction] = None,
keyset_restriction: Optional[KeysetRestriction] = None,
auto_schedule_run: bool = False,
):
self.verbose = verbose
self.compiler_debug_mode = compiler_debug_mode
@@ -1177,6 +1179,8 @@ class Configuration:
self.range_restriction = range_restriction
self.keyset_restriction = keyset_restriction
self.auto_schedule_run = auto_schedule_run
self._validate()
class Keep:
@@ -1254,6 +1258,7 @@ class Configuration:
simulate_encrypt_run_decrypt: Union[Keep, bool] = KEEP,
range_restriction: Union[Keep, Optional[RangeRestriction]] = KEEP,
keyset_restriction: Union[Keep, Optional[KeysetRestriction]] = KEEP,
auto_schedule_run: Union[Keep, bool] = KEEP,
) -> "Configuration":
"""
Get a new configuration from another one specified changes.

View File

@@ -4,8 +4,11 @@ Declaration of `FheModule` classes.
# pylint: disable=import-error,no-member,no-name-in-module
import asyncio
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
from threading import Thread
from typing import Any, Awaitable, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
import numpy as np
from concrete.compiler import CompilationContext, LweSecretKey, Parameter
@@ -24,13 +27,40 @@ from .value import Value
# pylint: enable=import-error,no-member,no-name-in-module
class ExecutionRt(NamedTuple):
class ExecutionRt:
"""
Runtime object class for execution.
"""
client: Client
server: Server
auto_schedule_run: bool
fhe_executor_pool: ThreadPoolExecutor
fhe_waiter_loop: asyncio.BaseEventLoop
fhe_waiter_thread: Thread # daemon thread
def __init__(self, client, server, auto_schedule_run):
self.client = client
self.server = server
self.auto_schedule_run = auto_schedule_run
if auto_schedule_run:
self.fhe_executor_pool = ThreadPoolExecutor()
self.fhe_waiter_loop = asyncio.new_event_loop()
def loop_thread():
asyncio.set_event_loop(self.fhe_waiter_loop)
self.fhe_waiter_loop.run_forever()
self.fhe_waiter_thread = Thread(target=loop_thread, args=(), daemon=True)
self.fhe_waiter_thread.start()
else:
self.fhe_executor_pool = None
self.fhe_waiter_loop = None
self.fhe_waiter_thread = None
def __del__(self):
if self.fhe_waiter_loop:
self.fhe_waiter_loop.stop() # pragma: no cover
class SimulationRt(NamedTuple):
@@ -177,12 +207,12 @@ class FheFunction:
return tuple(args) if len(args) > 1 else args[0] # type: ignore
return self.execution_runtime.val.client.encrypt(*args, function_name=self.name)
def run(
def run_sync(
self,
*args: Optional[Union[Value, Tuple[Optional[Value], ...]]],
) -> Union[Value, Tuple[Value, ...]]:
) -> Any:
"""
Evaluate the function.
Evaluate the function synchronuously.
Args:
*args (Value):
@@ -193,17 +223,115 @@ class FheFunction:
result(s) of evaluation
"""
return self._run(True, *args)
def run_async(
self, *args: Optional[Union[Value, Tuple[Optional[Value], ...]]]
) -> Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]:
"""
Evaluate the function asynchronuously.
Args:
*args (Value):
argument(s) for evaluation
Returns:
Union[Awaitable[Value], Awaitable[Tuple[Value, ...]]]:
result(s) a future of the evaluation
"""
if (
isinstance(self.execution_runtime.val, ExecutionRt)
and not self.execution_runtime.val.fhe_executor_pool
):
client = self.execution_runtime.val.client
server = self.execution_runtime.val.server
self.execution_runtime = Lazy(lambda: ExecutionRt(client, server, True))
self.execution_runtime.val.auto_schedule_run = False
return self._run(False, *args)
def run(
self,
*args: Optional[Union[Value, Tuple[Optional[Value], ...]]],
) -> Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]:
"""
Evaluate the function.
Args:
*args (Value):
argument(s) for evaluation
Returns:
Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]:
result(s) of evaluation or future of result(s) of evaluation if configured with async_run=True
"""
if isinstance(self.execution_runtime.val, ExecutionRt):
auto_schedule_run = self.execution_runtime.val.auto_schedule_run
else:
auto_schedule_run = False # pragma: no cover
return self._run(not auto_schedule_run, *args)
def _run(
self,
sync: bool,
*args: Optional[Union[Value, Tuple[Optional[Value], ...]]],
) -> Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]:
"""
Evaluate the function.
Args:
*args (Value):
argument(s) for evaluation
Returns:
Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]:
result(s) of evaluation if sync=True else future of result(s) of evaluation
"""
if self.configuration.simulate_encrypt_run_decrypt:
return self._simulate_decrypt(self._simulate_run(*args)) # type: ignore
return self.execution_runtime.val.server.run(
assert isinstance(self.execution_runtime.val, ExecutionRt)
fhe_work = lambda *args: self.execution_runtime.val.server.run(
*args,
evaluation_keys=self.execution_runtime.val.client.evaluation_keys,
function_name=self.name,
)
def args_ready(args):
return [arg.result() if isinstance(arg, Future) else arg for arg in args]
if sync:
return fhe_work(*args_ready(args))
all_args_done = all(not isinstance(arg, Future) or arg.done() for arg in args)
fhe_work_future = lambda *args: self.execution_runtime.val.fhe_executor_pool.submit(
fhe_work, *args
)
if all_args_done:
return fhe_work_future(*args_ready(args)) # type: ignore
# waiting args to be ready with async coroutines
# it only required one thread to run unlimited waits vs unlimited sync threads
async def wait_async(arg):
if not isinstance(arg, Future):
return arg # pragma: no cover
if arg.done():
return arg.result() # pragma: no cover
return await asyncio.wrap_future(arg, loop=self.execution_runtime.val.fhe_waiter_loop)
async def args_ready_and_submit(*args):
args = [await wait_async(arg) for arg in args]
return await wait_async(fhe_work_future(*args))
run_async = args_ready_and_submit(*args)
return asyncio.run_coroutine_threadsafe(
run_async, self.execution_runtime.val.fhe_waiter_loop
) # type: ignore
def decrypt(
self,
*results: Union[Value, Tuple[Value, ...]],
self, *results: Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]
) -> Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]:
"""
Decrypt result(s) of evaluation.
@@ -220,6 +348,8 @@ class FheFunction:
if self.configuration.simulate_encrypt_run_decrypt:
return tuple(results) if len(results) > 1 else results[0] # type: ignore
assert isinstance(self.execution_runtime.val, ExecutionRt)
results = [res.result() if isinstance(res, Future) else res for res in results]
return self.execution_runtime.val.client.decrypt(*results, function_name=self.name)
def encrypt_run_decrypt(self, *args: Any) -> Any:
@@ -620,7 +750,9 @@ class FheModule:
execution_client = Client(
execution_server.client_specs, keyset_cache_directory, is_simulated=False
)
return ExecutionRt(execution_client, execution_server)
return ExecutionRt(
execution_client, execution_server, self.configuration.auto_schedule_run
)
self.execution_runtime = Lazy(init_execution)
if configuration.fhe_execution:

View File

@@ -4,7 +4,9 @@ Tests of everything related to modules.
import inspect
import tempfile
from concurrent.futures import Future
from pathlib import Path
from typing import Awaitable
import numpy as np
import pytest
@@ -955,3 +957,83 @@ def test_wired_with_all_encrypted_inputs(helpers):
},
helpers.configuration().fork(),
)
class IncDec:
@fhe.module()
class Module:
@fhe.function({"x": "encrypted"})
def inc(x):
return fhe.refresh(x + 1)
@fhe.function({"x": "encrypted"})
def dec(x):
return fhe.refresh(x - 1)
precision = 4
inputset = list(range(1, 2**precision - 1))
to_compile = {"inc": inputset, "dec": inputset}
def test_run_async():
"""
Test `run_async` with `auto_schedule_run=False` configuration option.
"""
module = IncDec.Module.compile(IncDec.to_compile)
sample_x = 2
encrypted_x = module.inc.encrypt(sample_x)
a = module.inc.run_async(encrypted_x)
assert isinstance(a, Future)
b = module.dec.run(a)
assert isinstance(b, type(encrypted_x))
result = module.inc.decrypt(b)
assert result == sample_x
del module
def test_run_sync():
"""
Test `run_sync` with `auto_schedule_run=True` configuration option.
"""
conf = fhe.Configuration(auto_schedule_run=True)
module = IncDec.Module.compile(IncDec.to_compile, conf)
sample_x = 2
encrypted_x = module.inc.encrypt(sample_x)
a = module.inc.run(encrypted_x)
assert isinstance(a, Future)
b = module.dec.run_sync(a)
assert isinstance(b, type(encrypted_x))
result = module.inc.decrypt(b)
assert result == sample_x
def test_run_auto_schedule():
"""
Test `run` with `auto_schedule_run=True` configuration option.
"""
conf = fhe.Configuration(auto_schedule_run=True)
module = IncDec.Module.compile(IncDec.to_compile, conf)
sample_x = 2
encrypted_x = module.inc.encrypt(sample_x)
a = module.inc.run(encrypted_x)
assert isinstance(a, Future)
b = module.dec.run(a)
assert isinstance(b, Future)
result = module.inc.decrypt(b)
assert result == sample_x