mirror of
https://github.com/zama-ai/concrete.git
synced 2026-01-09 12:57:55 -05:00
Merge pull request #1144 from zama-ai/module-auto-schedule-rebased
feat(frontend-python): module run are scheduled and parallelized in a…
This commit is contained in:
@@ -210,3 +210,11 @@ When options are specified both in the `configuration` and as kwargs in the `com
|
||||
#### verbose: bool = False
|
||||
- Print details related to compilation.
|
||||
|
||||
#### auto_schedule_run: bool = False
|
||||
- Enable automatic scheduling of `run` method calls. When enabled, fhe function are computated in parallel in a background threads pool. When several `run` are composed, they are automatically synchronized.
|
||||
- For now, it only works for the `run` method of a `FheModule`, in that case you obtain a `Future[Value]` immediately instead of a `Value` when computation is finished.
|
||||
- E.g. `my_module.f3.run( my_module.f1.run(a), my_module.f1.run(b) )` will runs `f1` and `f2` in parallel in the background and `f3` in background when both `f1` and `f2` intermediate results are available.
|
||||
- If you want to manually synchronize on the termination of a full computation, e.g. you want to return the encrypted result, you can call explicitely `value.result()` to wait for the result. To simplify testing, decryption does it automatically.
|
||||
- Automatic scheduling behavior can be override locally by calling directly a variant of `run`:
|
||||
- `run_sync`: forces the fhe function to occur in the current thread, not in the background,
|
||||
- `run_async`: forces the fhe function to occur in a background thread, returning immediately a `Future[Value]`
|
||||
|
||||
@@ -195,7 +195,7 @@ class Circuit:
|
||||
result(s) of evaluation
|
||||
"""
|
||||
|
||||
return self._function.run(*args)
|
||||
return self._function.run_sync(*args)
|
||||
|
||||
def decrypt(
|
||||
self,
|
||||
|
||||
@@ -997,6 +997,7 @@ class Configuration:
|
||||
composable: bool
|
||||
range_restriction: Optional[RangeRestriction]
|
||||
keyset_restriction: Optional[KeysetRestriction]
|
||||
auto_schedule_run: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1068,6 +1069,7 @@ class Configuration:
|
||||
simulate_encrypt_run_decrypt: bool = False,
|
||||
range_restriction: Optional[RangeRestriction] = None,
|
||||
keyset_restriction: Optional[KeysetRestriction] = None,
|
||||
auto_schedule_run: bool = False,
|
||||
):
|
||||
self.verbose = verbose
|
||||
self.compiler_debug_mode = compiler_debug_mode
|
||||
@@ -1177,6 +1179,8 @@ class Configuration:
|
||||
self.range_restriction = range_restriction
|
||||
self.keyset_restriction = keyset_restriction
|
||||
|
||||
self.auto_schedule_run = auto_schedule_run
|
||||
|
||||
self._validate()
|
||||
|
||||
class Keep:
|
||||
@@ -1254,6 +1258,7 @@ class Configuration:
|
||||
simulate_encrypt_run_decrypt: Union[Keep, bool] = KEEP,
|
||||
range_restriction: Union[Keep, Optional[RangeRestriction]] = KEEP,
|
||||
keyset_restriction: Union[Keep, Optional[KeysetRestriction]] = KEEP,
|
||||
auto_schedule_run: Union[Keep, bool] = KEEP,
|
||||
) -> "Configuration":
|
||||
"""
|
||||
Get a new configuration from another one specified changes.
|
||||
|
||||
@@ -4,8 +4,11 @@ Declaration of `FheModule` classes.
|
||||
|
||||
# pylint: disable=import-error,no-member,no-name-in-module
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
|
||||
from threading import Thread
|
||||
from typing import Any, Awaitable, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from concrete.compiler import CompilationContext, LweSecretKey, Parameter
|
||||
@@ -24,13 +27,40 @@ from .value import Value
|
||||
# pylint: enable=import-error,no-member,no-name-in-module
|
||||
|
||||
|
||||
class ExecutionRt(NamedTuple):
|
||||
class ExecutionRt:
|
||||
"""
|
||||
Runtime object class for execution.
|
||||
"""
|
||||
|
||||
client: Client
|
||||
server: Server
|
||||
auto_schedule_run: bool
|
||||
fhe_executor_pool: ThreadPoolExecutor
|
||||
fhe_waiter_loop: asyncio.BaseEventLoop
|
||||
fhe_waiter_thread: Thread # daemon thread
|
||||
|
||||
def __init__(self, client, server, auto_schedule_run):
|
||||
self.client = client
|
||||
self.server = server
|
||||
self.auto_schedule_run = auto_schedule_run
|
||||
if auto_schedule_run:
|
||||
self.fhe_executor_pool = ThreadPoolExecutor()
|
||||
self.fhe_waiter_loop = asyncio.new_event_loop()
|
||||
|
||||
def loop_thread():
|
||||
asyncio.set_event_loop(self.fhe_waiter_loop)
|
||||
self.fhe_waiter_loop.run_forever()
|
||||
|
||||
self.fhe_waiter_thread = Thread(target=loop_thread, args=(), daemon=True)
|
||||
self.fhe_waiter_thread.start()
|
||||
else:
|
||||
self.fhe_executor_pool = None
|
||||
self.fhe_waiter_loop = None
|
||||
self.fhe_waiter_thread = None
|
||||
|
||||
def __del__(self):
|
||||
if self.fhe_waiter_loop:
|
||||
self.fhe_waiter_loop.stop() # pragma: no cover
|
||||
|
||||
|
||||
class SimulationRt(NamedTuple):
|
||||
@@ -177,12 +207,12 @@ class FheFunction:
|
||||
return tuple(args) if len(args) > 1 else args[0] # type: ignore
|
||||
return self.execution_runtime.val.client.encrypt(*args, function_name=self.name)
|
||||
|
||||
def run(
|
||||
def run_sync(
|
||||
self,
|
||||
*args: Optional[Union[Value, Tuple[Optional[Value], ...]]],
|
||||
) -> Union[Value, Tuple[Value, ...]]:
|
||||
) -> Any:
|
||||
"""
|
||||
Evaluate the function.
|
||||
Evaluate the function synchronuously.
|
||||
|
||||
Args:
|
||||
*args (Value):
|
||||
@@ -193,17 +223,115 @@ class FheFunction:
|
||||
result(s) of evaluation
|
||||
"""
|
||||
|
||||
return self._run(True, *args)
|
||||
|
||||
def run_async(
|
||||
self, *args: Optional[Union[Value, Tuple[Optional[Value], ...]]]
|
||||
) -> Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]:
|
||||
"""
|
||||
Evaluate the function asynchronuously.
|
||||
|
||||
Args:
|
||||
*args (Value):
|
||||
argument(s) for evaluation
|
||||
|
||||
Returns:
|
||||
Union[Awaitable[Value], Awaitable[Tuple[Value, ...]]]:
|
||||
result(s) a future of the evaluation
|
||||
"""
|
||||
if (
|
||||
isinstance(self.execution_runtime.val, ExecutionRt)
|
||||
and not self.execution_runtime.val.fhe_executor_pool
|
||||
):
|
||||
client = self.execution_runtime.val.client
|
||||
server = self.execution_runtime.val.server
|
||||
self.execution_runtime = Lazy(lambda: ExecutionRt(client, server, True))
|
||||
self.execution_runtime.val.auto_schedule_run = False
|
||||
|
||||
return self._run(False, *args)
|
||||
|
||||
def run(
|
||||
self,
|
||||
*args: Optional[Union[Value, Tuple[Optional[Value], ...]]],
|
||||
) -> Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]:
|
||||
"""
|
||||
Evaluate the function.
|
||||
|
||||
Args:
|
||||
*args (Value):
|
||||
argument(s) for evaluation
|
||||
|
||||
Returns:
|
||||
Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]:
|
||||
result(s) of evaluation or future of result(s) of evaluation if configured with async_run=True
|
||||
"""
|
||||
if isinstance(self.execution_runtime.val, ExecutionRt):
|
||||
auto_schedule_run = self.execution_runtime.val.auto_schedule_run
|
||||
else:
|
||||
auto_schedule_run = False # pragma: no cover
|
||||
return self._run(not auto_schedule_run, *args)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
sync: bool,
|
||||
*args: Optional[Union[Value, Tuple[Optional[Value], ...]]],
|
||||
) -> Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]:
|
||||
"""
|
||||
Evaluate the function.
|
||||
|
||||
Args:
|
||||
*args (Value):
|
||||
argument(s) for evaluation
|
||||
|
||||
Returns:
|
||||
Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]:
|
||||
result(s) of evaluation if sync=True else future of result(s) of evaluation
|
||||
"""
|
||||
if self.configuration.simulate_encrypt_run_decrypt:
|
||||
return self._simulate_decrypt(self._simulate_run(*args)) # type: ignore
|
||||
return self.execution_runtime.val.server.run(
|
||||
|
||||
assert isinstance(self.execution_runtime.val, ExecutionRt)
|
||||
|
||||
fhe_work = lambda *args: self.execution_runtime.val.server.run(
|
||||
*args,
|
||||
evaluation_keys=self.execution_runtime.val.client.evaluation_keys,
|
||||
function_name=self.name,
|
||||
)
|
||||
|
||||
def args_ready(args):
|
||||
return [arg.result() if isinstance(arg, Future) else arg for arg in args]
|
||||
|
||||
if sync:
|
||||
return fhe_work(*args_ready(args))
|
||||
|
||||
all_args_done = all(not isinstance(arg, Future) or arg.done() for arg in args)
|
||||
|
||||
fhe_work_future = lambda *args: self.execution_runtime.val.fhe_executor_pool.submit(
|
||||
fhe_work, *args
|
||||
)
|
||||
if all_args_done:
|
||||
return fhe_work_future(*args_ready(args)) # type: ignore
|
||||
|
||||
# waiting args to be ready with async coroutines
|
||||
# it only required one thread to run unlimited waits vs unlimited sync threads
|
||||
async def wait_async(arg):
|
||||
if not isinstance(arg, Future):
|
||||
return arg # pragma: no cover
|
||||
if arg.done():
|
||||
return arg.result() # pragma: no cover
|
||||
return await asyncio.wrap_future(arg, loop=self.execution_runtime.val.fhe_waiter_loop)
|
||||
|
||||
async def args_ready_and_submit(*args):
|
||||
args = [await wait_async(arg) for arg in args]
|
||||
return await wait_async(fhe_work_future(*args))
|
||||
|
||||
run_async = args_ready_and_submit(*args)
|
||||
return asyncio.run_coroutine_threadsafe(
|
||||
run_async, self.execution_runtime.val.fhe_waiter_loop
|
||||
) # type: ignore
|
||||
|
||||
def decrypt(
|
||||
self,
|
||||
*results: Union[Value, Tuple[Value, ...]],
|
||||
self, *results: Union[Value, Tuple[Value, ...], Awaitable[Union[Value, Tuple[Value, ...]]]]
|
||||
) -> Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]:
|
||||
"""
|
||||
Decrypt result(s) of evaluation.
|
||||
@@ -220,6 +348,8 @@ class FheFunction:
|
||||
if self.configuration.simulate_encrypt_run_decrypt:
|
||||
return tuple(results) if len(results) > 1 else results[0] # type: ignore
|
||||
|
||||
assert isinstance(self.execution_runtime.val, ExecutionRt)
|
||||
results = [res.result() if isinstance(res, Future) else res for res in results]
|
||||
return self.execution_runtime.val.client.decrypt(*results, function_name=self.name)
|
||||
|
||||
def encrypt_run_decrypt(self, *args: Any) -> Any:
|
||||
@@ -620,7 +750,9 @@ class FheModule:
|
||||
execution_client = Client(
|
||||
execution_server.client_specs, keyset_cache_directory, is_simulated=False
|
||||
)
|
||||
return ExecutionRt(execution_client, execution_server)
|
||||
return ExecutionRt(
|
||||
execution_client, execution_server, self.configuration.auto_schedule_run
|
||||
)
|
||||
|
||||
self.execution_runtime = Lazy(init_execution)
|
||||
if configuration.fhe_execution:
|
||||
|
||||
@@ -4,7 +4,9 @@ Tests of everything related to modules.
|
||||
|
||||
import inspect
|
||||
import tempfile
|
||||
from concurrent.futures import Future
|
||||
from pathlib import Path
|
||||
from typing import Awaitable
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -955,3 +957,83 @@ def test_wired_with_all_encrypted_inputs(helpers):
|
||||
},
|
||||
helpers.configuration().fork(),
|
||||
)
|
||||
|
||||
|
||||
class IncDec:
|
||||
@fhe.module()
|
||||
class Module:
|
||||
@fhe.function({"x": "encrypted"})
|
||||
def inc(x):
|
||||
return fhe.refresh(x + 1)
|
||||
|
||||
@fhe.function({"x": "encrypted"})
|
||||
def dec(x):
|
||||
return fhe.refresh(x - 1)
|
||||
|
||||
precision = 4
|
||||
|
||||
inputset = list(range(1, 2**precision - 1))
|
||||
to_compile = {"inc": inputset, "dec": inputset}
|
||||
|
||||
|
||||
def test_run_async():
|
||||
"""
|
||||
Test `run_async` with `auto_schedule_run=False` configuration option.
|
||||
"""
|
||||
|
||||
module = IncDec.Module.compile(IncDec.to_compile)
|
||||
|
||||
sample_x = 2
|
||||
encrypted_x = module.inc.encrypt(sample_x)
|
||||
|
||||
a = module.inc.run_async(encrypted_x)
|
||||
assert isinstance(a, Future)
|
||||
|
||||
b = module.dec.run(a)
|
||||
assert isinstance(b, type(encrypted_x))
|
||||
|
||||
result = module.inc.decrypt(b)
|
||||
assert result == sample_x
|
||||
del module
|
||||
|
||||
|
||||
def test_run_sync():
|
||||
"""
|
||||
Test `run_sync` with `auto_schedule_run=True` configuration option.
|
||||
"""
|
||||
|
||||
conf = fhe.Configuration(auto_schedule_run=True)
|
||||
module = IncDec.Module.compile(IncDec.to_compile, conf)
|
||||
|
||||
sample_x = 2
|
||||
encrypted_x = module.inc.encrypt(sample_x)
|
||||
|
||||
a = module.inc.run(encrypted_x)
|
||||
assert isinstance(a, Future)
|
||||
|
||||
b = module.dec.run_sync(a)
|
||||
assert isinstance(b, type(encrypted_x))
|
||||
|
||||
result = module.inc.decrypt(b)
|
||||
assert result == sample_x
|
||||
|
||||
|
||||
def test_run_auto_schedule():
|
||||
"""
|
||||
Test `run` with `auto_schedule_run=True` configuration option.
|
||||
"""
|
||||
|
||||
conf = fhe.Configuration(auto_schedule_run=True)
|
||||
module = IncDec.Module.compile(IncDec.to_compile, conf)
|
||||
|
||||
sample_x = 2
|
||||
encrypted_x = module.inc.encrypt(sample_x)
|
||||
|
||||
a = module.inc.run(encrypted_x)
|
||||
assert isinstance(a, Future)
|
||||
|
||||
b = module.dec.run(a)
|
||||
assert isinstance(b, Future)
|
||||
|
||||
result = module.inc.decrypt(b)
|
||||
assert result == sample_x
|
||||
|
||||
Reference in New Issue
Block a user