mirror of
https://github.com/zama-ai/concrete.git
synced 2026-01-09 12:57:55 -05:00
feat(frontend): wire tracing
This commit is contained in:
committed by
Alexandre Péré
parent
5716912d1b
commit
e675a75285
@@ -1,6 +1,6 @@
|
||||
# Composing functions with modules
|
||||
|
||||
This document explains how to compile Fully Homomorphic Encryption (FHE) modules containing multiple functions using **Concrete**.
|
||||
This document explains how to compile Fully Homomorphic Encryption (FHE) modules containing multiple functions using **Concrete**.
|
||||
|
||||
Deploying a server that contains many compatible functions is important for some use cases. With **Concrete**, you can compile FHE modules containing as many functions as needed.
|
||||
|
||||
@@ -115,7 +115,7 @@ Encrypting initial values
|
||||
| 9 || 144 | 144 | 233 | 233 |
|
||||
```
|
||||
|
||||
## Iterations
|
||||
## Iterations
|
||||
|
||||
Modules support iteration with cleartext iterands to some extent, particularly for loops structured like this:
|
||||
|
||||
@@ -231,6 +231,8 @@ You have 3 options for the `composition` attribute:
|
||||
|
||||
3. **`fhe.Wired`**: This policy allows you to define custom composition rules. You can specify which outputs of a function can be forwarded to which inputs of another function.
|
||||
|
||||
Note that, in case of complex composition logic another option is to rely on [[composing_functions_with_modules#Automatic module tracing]] to automatically derive the composition from examples.
|
||||
|
||||
Here is an example:
|
||||
```python
|
||||
from concrete import fhe
|
||||
@@ -249,9 +251,9 @@ class Collatz:
|
||||
return ans, is_one
|
||||
|
||||
composition = Wired(
|
||||
[
|
||||
{
|
||||
Wire(Output(collatz, 0), Input(collatz, 0)
|
||||
]
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
@@ -260,19 +262,69 @@ In this case, the policy states that the first output of the `collatz` function
|
||||
You can use the `fhe.Wire` between any two functions. It is also possible to define wires with `fhe.AllInputs` and `fhe.AllOutputs` ends. For instance, in the previous example:
|
||||
```python
|
||||
composition = Wired(
|
||||
[
|
||||
{
|
||||
Wire(AllOutputs(collatz), AllInputs(collatz))
|
||||
]
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
This policy would be equivalent to using the `fhe.AllComposable` policy.
|
||||
|
||||
## Current limitations
|
||||
## Automatic module tracing
|
||||
|
||||
Depending on the functions, composition may add a significant overhead compared to a non-composable version.
|
||||
When a module's composition logic is static and straightforward, declaratively defining a `Wired` policy is usually the simplest approach. However, in cases where modules have more complex or dynamic composition logic, deriving an accurate list of `Wire` components to be used in the policy can become challenging.
|
||||
|
||||
To be composable, a function must meet the following condition: every output that can be forwarded as input (according to the composition policy) must contain a noise-refreshing operation. Since adding a noise refresh has a noticeable impact on performance, Concrete does not automatically include it.
|
||||
Another related problem is defining different function input-sets. When the composition logic is simple, these can be provided manually. But as the composition gets more convoluted, computing a consistent ensemble of inputsets for a module may become intractable.
|
||||
|
||||
For those advanced cases, you can derive the composition rules and the input-sets automatically from user-provided examples. Consider the following module:
|
||||
```python
|
||||
from concrete import fhe
|
||||
from fhe import Wired
|
||||
|
||||
@fhe.module()
|
||||
class MyModule:
|
||||
@fhe.function({"x": "encrypted"})
|
||||
def increment(x):
|
||||
return (x + 1) % 100
|
||||
|
||||
@fhe.function({"x": "encrypted"})
|
||||
def decrement(x):
|
||||
return (x - 1) % 100
|
||||
|
||||
@fhe.function({"x": "encrypted"})
|
||||
def decimate(x):
|
||||
return (x / 10) % 100
|
||||
|
||||
composition = fhe.Wired()
|
||||
```
|
||||
|
||||
You can use the `wire_pipeline` context manager to activate the module tracing functionality:
|
||||
```python
|
||||
# A single inputset used during tracing is defined
|
||||
inputset = [np.random.randint(1, 100, size=()) for _ in range(100)]
|
||||
|
||||
# The inputset is passed to the `wire_pipeline` method, which itself returns an iterator over the inputset samples.
|
||||
with MyModule.wire_pipeline(inputset) as samples_iter:
|
||||
|
||||
# The inputset is iterated over
|
||||
for s in samples_iter:
|
||||
|
||||
# Here we provide an example of how we expect the module functions to be used at runtime in fhe.
|
||||
Module.increment(Module.decimate(Module.decrement(s)))
|
||||
|
||||
# It is not needed to provide any inputsets to the `compile` method after tracing the wires, since those were already computed automatically during the module tracing.
|
||||
module = MyModule.compile(
|
||||
p_error=0.01,
|
||||
)
|
||||
```
|
||||
|
||||
Note that any dynamic branching is possible during module tracing. However, for complex runtime logic, ensure that the input set provides sufficient examples to cover all potential code paths.
|
||||
|
||||
## Current Limitations
|
||||
|
||||
Depending on the functions, composition may add a significant overhead compared to a non-composable version.
|
||||
|
||||
To be composable, a function must meet the following condition: every output that can be forwarded as input (according to the composition policy) must contain a noise-refreshing operation. Since adding a noise refresh has a noticeable impact on performance, Concrete does not automatically include it.
|
||||
|
||||
For instance, to implement a function that doubles an encrypted value, you might write:
|
||||
|
||||
@@ -283,18 +335,14 @@ class Doubler:
|
||||
def double(counter):
|
||||
return counter * 2
|
||||
```
|
||||
This function is valid with the `fhe.NotComposable` policy. However, if compiled with the `fhe.AllComposable` policy, it will raise a `RuntimeError: Program cannot be composed: ...`, indicating that an extra Programmable Bootstrapping (PBS) step must be added.
|
||||
This function is valid with the `fhe.NotComposable` policy. However, if compiled with the `fhe.AllComposable` policy, it will raise a `RuntimeError: Program cannot be composed: ...`, indicating that an extra Programmable Bootstrapping (PBS) step must be added.
|
||||
|
||||
To resolve this and make the circuit valid, add a PBS at the end of the circuit:
|
||||
|
||||
```python
|
||||
def noise_reset(x):
|
||||
return fhe.univariate(lambda x: x)(x)
|
||||
|
||||
@fhe.module()
|
||||
class Doubler:
|
||||
@fhe.compiler({"counter": "encrypted"})
|
||||
def double(counter):
|
||||
return noise_reset(counter * 2)
|
||||
return fhe.refresh(counter * 2)
|
||||
```
|
||||
|
||||
|
||||
@@ -117,6 +117,9 @@ class FheFunction:
|
||||
def __str__(self):
|
||||
return self.graph.format()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"FheFunction({self.name=})"
|
||||
|
||||
def simulate(self, *args: Any) -> Any:
|
||||
"""
|
||||
Simulate execution of the function.
|
||||
|
||||
@@ -18,6 +18,7 @@ from typing import (
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Protocol,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
runtime_checkable,
|
||||
@@ -55,6 +56,7 @@ class FunctionDef:
|
||||
graph: Optional[Graph]
|
||||
_parameter_values: Dict[str, ValueDescription]
|
||||
location: str
|
||||
_trace_wires: Optional[Set["Wire"]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -115,6 +117,7 @@ class FunctionDef:
|
||||
self.location = (
|
||||
f"{self.function.__code__.co_filename}:{self.function.__code__.co_firstlineno}"
|
||||
)
|
||||
self._trace_wires = None
|
||||
|
||||
def trace(
|
||||
self,
|
||||
@@ -237,20 +240,84 @@ class FunctionDef:
|
||||
np.integer,
|
||||
np.floating,
|
||||
np.ndarray,
|
||||
Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray], ...],
|
||||
"TracedOutput",
|
||||
Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray, "TracedOutput"], ...],
|
||||
]:
|
||||
if len(kwargs) != 0:
|
||||
message = f"Calling function '{self.function.__name__}' with kwargs is not supported"
|
||||
raise RuntimeError(message)
|
||||
|
||||
sample = args[0] if len(args) == 1 else args
|
||||
# The actual call to the function graph object gets wrapped between two calls to methods
|
||||
# that allows to trace the wiring.
|
||||
#
|
||||
# When activated:
|
||||
# + `_trace_wire_outputs` method wraps ciphered outputs into a `TracedOutput` object,
|
||||
# along with its origin information.
|
||||
# + `_trace_wire_inputs` method unwraps the `TracedOutput`, records the wiring, and
|
||||
# returns unwrapped values for execution.
|
||||
traced_inputs = self._trace_wires_inputs(*args)
|
||||
|
||||
if self.graph is None:
|
||||
self.trace(sample)
|
||||
# Note that the tracing must be executed on the `traced_inputs` which are unwrapped
|
||||
# from the potential `TracedOutput` added by wire tracing.
|
||||
self.trace(traced_inputs)
|
||||
assert self.graph is not None
|
||||
|
||||
self.inputset.append(sample)
|
||||
return self.graph(*args)
|
||||
self.inputset.append(traced_inputs)
|
||||
|
||||
if isinstance(traced_inputs, tuple):
|
||||
raw_outputs = self.graph(*traced_inputs)
|
||||
else:
|
||||
raw_outputs = self.graph(traced_inputs)
|
||||
|
||||
if isinstance(raw_outputs, tuple):
|
||||
traced_output = self._trace_wires_outputs(*raw_outputs)
|
||||
else:
|
||||
traced_output = self._trace_wires_outputs(raw_outputs)
|
||||
|
||||
return traced_output
|
||||
|
||||
def _trace_wires_inputs(
|
||||
self,
|
||||
*args: Any,
|
||||
) -> Union[
|
||||
np.bool_,
|
||||
np.integer,
|
||||
np.floating,
|
||||
np.ndarray,
|
||||
Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray], ...],
|
||||
]:
|
||||
# If the _trace_wires property points to a wire list, we use wire tracing.
|
||||
if self._trace_wires is None:
|
||||
return args[0] if len(args) == 1 else args
|
||||
|
||||
for i, arg in enumerate(args):
|
||||
if isinstance(arg, TracedOutput):
|
||||
# Wire gets added to the wire list
|
||||
self._trace_wires.add(Wire(arg.output_info, Input(self, i)))
|
||||
|
||||
output = tuple(arg.returned_value if isinstance(arg, TracedOutput) else arg for arg in args)
|
||||
|
||||
return output[0] if len(output) == 1 else output
|
||||
|
||||
def _trace_wires_outputs(
|
||||
self,
|
||||
*args: Any,
|
||||
) -> Union[
|
||||
np.bool_,
|
||||
np.integer,
|
||||
np.floating,
|
||||
np.ndarray,
|
||||
"TracedOutput",
|
||||
Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray, "TracedOutput"], ...],
|
||||
]:
|
||||
# If the _trace_wires property points to a wire list, we use wire tracing.
|
||||
if self._trace_wires is None:
|
||||
return args[0] if len(args) == 1 else args
|
||||
|
||||
output = tuple(TracedOutput(Output(self, i), arg) for (i, arg) in enumerate(args))
|
||||
|
||||
return output[0] if len(output) == 1 else output
|
||||
|
||||
|
||||
class NotComposable:
|
||||
@@ -402,7 +469,7 @@ class Wired(NamedTuple):
|
||||
Composition policy which allows the forwarding of certain outputs to certain inputs.
|
||||
"""
|
||||
|
||||
wires: List[Wire]
|
||||
wires: Set[Wire] = set()
|
||||
|
||||
def get_rules_iter(self, _) -> Iterable[CompositionRule]:
|
||||
"""
|
||||
@@ -618,6 +685,39 @@ class DebugManager:
|
||||
pretty(module.statistics)
|
||||
|
||||
|
||||
class TracedOutput(NamedTuple):
|
||||
"""
|
||||
A wrapper type used to trace wiring.
|
||||
|
||||
Allows to tag an output value coming from an other module function, and binds it with
|
||||
information about its origin.
|
||||
"""
|
||||
|
||||
output_info: Output
|
||||
returned_value: Any
|
||||
|
||||
|
||||
class WireTracingContextManager:
|
||||
"""
|
||||
A context manager returned by the `wire_pipeline` method.
|
||||
|
||||
Activates wire tracing and yields an inputset that can be iterated on for tracing.
|
||||
"""
|
||||
|
||||
def __init__(self, module, inputset):
|
||||
self.module = module
|
||||
self.inputset = inputset
|
||||
|
||||
def __enter__(self):
|
||||
for func in self.module.functions.values():
|
||||
func._trace_wires = self.module.composition.wires
|
||||
return self.inputset
|
||||
|
||||
def __exit__(self, _exc_type, _exc_value, _exc_tb):
|
||||
for func in self.module.functions.values():
|
||||
func._trace_wires = None
|
||||
|
||||
|
||||
class ModuleCompiler:
|
||||
"""
|
||||
Compiler class for multiple functions, to glue the compilation pipeline.
|
||||
@@ -637,6 +737,13 @@ class ModuleCompiler:
|
||||
self.compilation_context = CompilationContext.new()
|
||||
self.composition = composition
|
||||
|
||||
def wire_pipeline(self, inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]]):
|
||||
"""
|
||||
Return a context manager that traces wires automatically.
|
||||
"""
|
||||
self.composition = Wired(set())
|
||||
return WireTracingContextManager(self, inputset)
|
||||
|
||||
def compile(
|
||||
self,
|
||||
inputsets: Optional[Dict[str, Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]] = None,
|
||||
|
||||
@@ -294,7 +294,7 @@ class LevenshsteinModule:
|
||||
#
|
||||
# There is a single output of equal, it goes to input 0 of mix
|
||||
composition = fhe.Wired(
|
||||
[
|
||||
{
|
||||
fhe.Wire(fhe.AllOutputs(equal), fhe.Input(mix, 0)),
|
||||
fhe.Wire(fhe.AllOutputs(mix), fhe.Input(mix, 1)),
|
||||
fhe.Wire(fhe.AllOutputs(mix), fhe.Input(mix, 2)),
|
||||
@@ -304,7 +304,7 @@ class LevenshsteinModule:
|
||||
fhe.Wire(fhe.AllOutputs(constant), fhe.Input(mix, 2)),
|
||||
fhe.Wire(fhe.AllOutputs(constant), fhe.Input(mix, 3)),
|
||||
fhe.Wire(fhe.AllOutputs(constant), fhe.Input(mix, 4)),
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -486,10 +486,10 @@ def test_composition_policy_wires():
|
||||
return (x + y), (x - y)
|
||||
|
||||
composition = fhe.Wired(
|
||||
[
|
||||
{
|
||||
fhe.Wire(fhe.AllOutputs(add_sub), fhe.AllInputs(add_sub)),
|
||||
fhe.Wire(fhe.AllOutputs(add_sub), fhe.Input(square, 0)),
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
assert isinstance(Module.composition, fhe.CompositionPolicy)
|
||||
@@ -508,9 +508,9 @@ def test_composition_wired_enhances_complexity():
|
||||
return (x * 2) % 200
|
||||
|
||||
composition = fhe.Wired(
|
||||
[
|
||||
{
|
||||
fhe.Wire(fhe.Output(_1, 0), fhe.Input(_2, 0)),
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
module1 = Module1.compile(
|
||||
@@ -558,10 +558,10 @@ def test_composition_wired_compilation():
|
||||
return (x * 2) % 100
|
||||
|
||||
composition = fhe.Wired(
|
||||
[
|
||||
{
|
||||
fhe.Wire(fhe.Output(a, 0), fhe.Input(b, 0)),
|
||||
fhe.Wire(fhe.Output(b, 0), fhe.Input(c, 0)),
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
module = Module.compile(
|
||||
@@ -718,3 +718,79 @@ def test_client_server_api(helpers):
|
||||
output = client.decrypt(deserialized_result, function_name="inc")
|
||||
|
||||
assert output == 11
|
||||
|
||||
|
||||
def test_trace_wire_single_input_output(helpers):
|
||||
@fhe.module()
|
||||
class Module:
|
||||
@fhe.function({"x": "encrypted"})
|
||||
def a(x):
|
||||
return (x * 2) % 20
|
||||
|
||||
@fhe.function({"x": "encrypted"})
|
||||
def b(x):
|
||||
return (x * 2) % 50
|
||||
|
||||
@fhe.function({"x": "encrypted"})
|
||||
def c(x):
|
||||
return (x * 2) % 100
|
||||
|
||||
composition = fhe.Wired()
|
||||
|
||||
# `wire_pipeline` takes an inputset as input, activates a context in which wiring is recorded, and returns an iterator that can be used inside of the context to iterate over samples.
|
||||
with Module.wire_pipeline([np.random.randint(1, 20, size=()) for _ in range(100)]) as samples:
|
||||
for s in samples:
|
||||
Module.c(Module.b(Module.a(s)))
|
||||
|
||||
assert len(Module.composition.wires) == 2
|
||||
assert fhe.Wire(fhe.Output(Module.a, 0), fhe.Input(Module.b, 0)) in Module.composition.wires
|
||||
assert fhe.Wire(fhe.Output(Module.b, 0), fhe.Input(Module.c, 0)) in Module.composition.wires
|
||||
|
||||
module = Module.compile(
|
||||
p_error=0.01,
|
||||
)
|
||||
|
||||
inp_enc = module.a.encrypt(5)
|
||||
a_enc = module.a.run(inp_enc)
|
||||
assert module.a.decrypt(a_enc) == 10
|
||||
b_enc = module.b.run(a_enc)
|
||||
assert module.b.decrypt(b_enc) == 20
|
||||
c_enc = module.c.run(b_enc)
|
||||
assert module.c.decrypt(c_enc) == 40
|
||||
|
||||
|
||||
def test_trace_wires_multi_inputs_outputs(helpers):
|
||||
@fhe.module()
|
||||
class Module:
|
||||
|
||||
@fhe.function({"x": "encrypted", "y": "encrypted"})
|
||||
def a(x, y):
|
||||
return ((x + y) * 2) % 20, ((x - y) * 2) % 20
|
||||
|
||||
@fhe.function({"x": "encrypted", "y": "encrypted"})
|
||||
def b(x, y):
|
||||
return ((x + y) * 2) % 20, ((x - y) * 2) % 20
|
||||
|
||||
composition = fhe.Wired()
|
||||
|
||||
# `wire_pipeline` takes an inputset as input, activates a context in which wiring is recorded, and returns an iterator that can be used inside of the context to iterate over samples.
|
||||
with Module.wire_pipeline(
|
||||
[(np.random.randint(1, 20, size=()), np.random.randint(1, 20, size=())) for _ in range(100)]
|
||||
) as samples:
|
||||
for s in samples:
|
||||
output = Module.a(s[0], s[1])
|
||||
Module.b(*output)
|
||||
|
||||
assert len(Module.composition.wires) == 2
|
||||
assert fhe.Wire(fhe.Output(Module.a, 0), fhe.Input(Module.b, 0)) in Module.composition.wires
|
||||
assert fhe.Wire(fhe.Output(Module.a, 1), fhe.Input(Module.b, 1)) in Module.composition.wires
|
||||
|
||||
module = Module.compile(
|
||||
p_error=0.01,
|
||||
)
|
||||
|
||||
inp_enc = module.a.encrypt(5, 1)
|
||||
a_enc = module.a.run(*inp_enc)
|
||||
assert module.a.decrypt(a_enc) == (12, 8)
|
||||
b_enc = module.b.run(*a_enc)
|
||||
assert module.b.decrypt(b_enc) == (0, 8)
|
||||
|
||||
Reference in New Issue
Block a user