feat(frontend): wire tracing

This commit is contained in:
Alexandre Péré
2024-07-18 15:14:06 +02:00
committed by Alexandre Péré
parent 5716912d1b
commit e675a75285
5 changed files with 263 additions and 29 deletions

View File

@@ -1,6 +1,6 @@
# Composing functions with modules
This document explains how to compile Fully Homomorphic Encryption (FHE) modules containing multiple functions using **Concrete**.
This document explains how to compile Fully Homomorphic Encryption (FHE) modules containing multiple functions using **Concrete**.
Deploying a server that contains many compatible functions is important for some use cases. With **Concrete**, you can compile FHE modules containing as many functions as needed.
@@ -115,7 +115,7 @@ Encrypting initial values
| 9 || 144 | 144 | 233 | 233 |
```
## Iterations
## Iterations
Modules support iteration with cleartext iterands to some extent, particularly for loops structured like this:
@@ -231,6 +231,8 @@ You have 3 options for the `composition` attribute:
3. **`fhe.Wired`**: This policy allows you to define custom composition rules. You can specify which outputs of a function can be forwarded to which inputs of another function.
Note that, in case of complex composition logic another option is to rely on [[composing_functions_with_modules#Automatic module tracing]] to automatically derive the composition from examples.
Here is an example:
```python
from concrete import fhe
@@ -249,9 +251,9 @@ class Collatz:
return ans, is_one
composition = Wired(
[
{
Wire(Output(collatz, 0), Input(collatz, 0)
]
}
)
```
@@ -260,19 +262,69 @@ In this case, the policy states that the first output of the `collatz` function
You can use the `fhe.Wire` between any two functions. It is also possible to define wires with `fhe.AllInputs` and `fhe.AllOutputs` ends. For instance, in the previous example:
```python
composition = Wired(
[
{
Wire(AllOutputs(collatz), AllInputs(collatz))
]
}
)
```
This policy would be equivalent to using the `fhe.AllComposable` policy.
## Current limitations
## Automatic module tracing
Depending on the functions, composition may add a significant overhead compared to a non-composable version.
When a module's composition logic is static and straightforward, declaratively defining a `Wired` policy is usually the simplest approach. However, in cases where modules have more complex or dynamic composition logic, deriving an accurate list of `Wire` components to be used in the policy can become challenging.
To be composable, a function must meet the following condition: every output that can be forwarded as input (according to the composition policy) must contain a noise-refreshing operation. Since adding a noise refresh has a noticeable impact on performance, Concrete does not automatically include it.
Another related problem is defining different function input-sets. When the composition logic is simple, these can be provided manually. But as the composition gets more convoluted, computing a consistent ensemble of inputsets for a module may become intractable.
For those advanced cases, you can derive the composition rules and the input-sets automatically from user-provided examples. Consider the following module:
```python
from concrete import fhe
from fhe import Wired
@fhe.module()
class MyModule:
@fhe.function({"x": "encrypted"})
def increment(x):
return (x + 1) % 100
@fhe.function({"x": "encrypted"})
def decrement(x):
return (x - 1) % 100
@fhe.function({"x": "encrypted"})
def decimate(x):
return (x / 10) % 100
composition = fhe.Wired()
```
You can use the `wire_pipeline` context manager to activate the module tracing functionality:
```python
# A single inputset used during tracing is defined
inputset = [np.random.randint(1, 100, size=()) for _ in range(100)]
# The inputset is passed to the `wire_pipeline` method, which itself returns an iterator over the inputset samples.
with MyModule.wire_pipeline(inputset) as samples_iter:
# The inputset is iterated over
for s in samples_iter:
# Here we provide an example of how we expect the module functions to be used at runtime in fhe.
Module.increment(Module.decimate(Module.decrement(s)))
# It is not needed to provide any inputsets to the `compile` method after tracing the wires, since those were already computed automatically during the module tracing.
module = MyModule.compile(
p_error=0.01,
)
```
Note that any dynamic branching is possible during module tracing. However, for complex runtime logic, ensure that the input set provides sufficient examples to cover all potential code paths.
## Current Limitations
Depending on the functions, composition may add a significant overhead compared to a non-composable version.
To be composable, a function must meet the following condition: every output that can be forwarded as input (according to the composition policy) must contain a noise-refreshing operation. Since adding a noise refresh has a noticeable impact on performance, Concrete does not automatically include it.
For instance, to implement a function that doubles an encrypted value, you might write:
@@ -283,18 +335,14 @@ class Doubler:
def double(counter):
return counter * 2
```
This function is valid with the `fhe.NotComposable` policy. However, if compiled with the `fhe.AllComposable` policy, it will raise a `RuntimeError: Program cannot be composed: ...`, indicating that an extra Programmable Bootstrapping (PBS) step must be added.
This function is valid with the `fhe.NotComposable` policy. However, if compiled with the `fhe.AllComposable` policy, it will raise a `RuntimeError: Program cannot be composed: ...`, indicating that an extra Programmable Bootstrapping (PBS) step must be added.
To resolve this and make the circuit valid, add a PBS at the end of the circuit:
```python
def noise_reset(x):
return fhe.univariate(lambda x: x)(x)
@fhe.module()
class Doubler:
@fhe.compiler({"counter": "encrypted"})
def double(counter):
return noise_reset(counter * 2)
return fhe.refresh(counter * 2)
```

View File

@@ -117,6 +117,9 @@ class FheFunction:
def __str__(self):
return self.graph.format()
def __repr__(self) -> str:
return f"FheFunction({self.name=})"
def simulate(self, *args: Any) -> Any:
"""
Simulate execution of the function.

View File

@@ -18,6 +18,7 @@ from typing import (
NamedTuple,
Optional,
Protocol,
Set,
Tuple,
Union,
runtime_checkable,
@@ -55,6 +56,7 @@ class FunctionDef:
graph: Optional[Graph]
_parameter_values: Dict[str, ValueDescription]
location: str
_trace_wires: Optional[Set["Wire"]]
def __init__(
self,
@@ -115,6 +117,7 @@ class FunctionDef:
self.location = (
f"{self.function.__code__.co_filename}:{self.function.__code__.co_firstlineno}"
)
self._trace_wires = None
def trace(
self,
@@ -237,20 +240,84 @@ class FunctionDef:
np.integer,
np.floating,
np.ndarray,
Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray], ...],
"TracedOutput",
Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray, "TracedOutput"], ...],
]:
if len(kwargs) != 0:
message = f"Calling function '{self.function.__name__}' with kwargs is not supported"
raise RuntimeError(message)
sample = args[0] if len(args) == 1 else args
# The actual call to the function graph object gets wrapped between two calls to methods
# that allows to trace the wiring.
#
# When activated:
# + `_trace_wire_outputs` method wraps ciphered outputs into a `TracedOutput` object,
# along with its origin information.
# + `_trace_wire_inputs` method unwraps the `TracedOutput`, records the wiring, and
# returns unwrapped values for execution.
traced_inputs = self._trace_wires_inputs(*args)
if self.graph is None:
self.trace(sample)
# Note that the tracing must be executed on the `traced_inputs` which are unwrapped
# from the potential `TracedOutput` added by wire tracing.
self.trace(traced_inputs)
assert self.graph is not None
self.inputset.append(sample)
return self.graph(*args)
self.inputset.append(traced_inputs)
if isinstance(traced_inputs, tuple):
raw_outputs = self.graph(*traced_inputs)
else:
raw_outputs = self.graph(traced_inputs)
if isinstance(raw_outputs, tuple):
traced_output = self._trace_wires_outputs(*raw_outputs)
else:
traced_output = self._trace_wires_outputs(raw_outputs)
return traced_output
def _trace_wires_inputs(
self,
*args: Any,
) -> Union[
np.bool_,
np.integer,
np.floating,
np.ndarray,
Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray], ...],
]:
# If the _trace_wires property points to a wire list, we use wire tracing.
if self._trace_wires is None:
return args[0] if len(args) == 1 else args
for i, arg in enumerate(args):
if isinstance(arg, TracedOutput):
# Wire gets added to the wire list
self._trace_wires.add(Wire(arg.output_info, Input(self, i)))
output = tuple(arg.returned_value if isinstance(arg, TracedOutput) else arg for arg in args)
return output[0] if len(output) == 1 else output
def _trace_wires_outputs(
self,
*args: Any,
) -> Union[
np.bool_,
np.integer,
np.floating,
np.ndarray,
"TracedOutput",
Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray, "TracedOutput"], ...],
]:
# If the _trace_wires property points to a wire list, we use wire tracing.
if self._trace_wires is None:
return args[0] if len(args) == 1 else args
output = tuple(TracedOutput(Output(self, i), arg) for (i, arg) in enumerate(args))
return output[0] if len(output) == 1 else output
class NotComposable:
@@ -402,7 +469,7 @@ class Wired(NamedTuple):
Composition policy which allows the forwarding of certain outputs to certain inputs.
"""
wires: List[Wire]
wires: Set[Wire] = set()
def get_rules_iter(self, _) -> Iterable[CompositionRule]:
"""
@@ -618,6 +685,39 @@ class DebugManager:
pretty(module.statistics)
class TracedOutput(NamedTuple):
"""
A wrapper type used to trace wiring.
Allows to tag an output value coming from an other module function, and binds it with
information about its origin.
"""
output_info: Output
returned_value: Any
class WireTracingContextManager:
"""
A context manager returned by the `wire_pipeline` method.
Activates wire tracing and yields an inputset that can be iterated on for tracing.
"""
def __init__(self, module, inputset):
self.module = module
self.inputset = inputset
def __enter__(self):
for func in self.module.functions.values():
func._trace_wires = self.module.composition.wires
return self.inputset
def __exit__(self, _exc_type, _exc_value, _exc_tb):
for func in self.module.functions.values():
func._trace_wires = None
class ModuleCompiler:
"""
Compiler class for multiple functions, to glue the compilation pipeline.
@@ -637,6 +737,13 @@ class ModuleCompiler:
self.compilation_context = CompilationContext.new()
self.composition = composition
def wire_pipeline(self, inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]]):
"""
Return a context manager that traces wires automatically.
"""
self.composition = Wired(set())
return WireTracingContextManager(self, inputset)
def compile(
self,
inputsets: Optional[Dict[str, Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]] = None,

View File

@@ -294,7 +294,7 @@ class LevenshsteinModule:
#
# There is a single output of equal, it goes to input 0 of mix
composition = fhe.Wired(
[
{
fhe.Wire(fhe.AllOutputs(equal), fhe.Input(mix, 0)),
fhe.Wire(fhe.AllOutputs(mix), fhe.Input(mix, 1)),
fhe.Wire(fhe.AllOutputs(mix), fhe.Input(mix, 2)),
@@ -304,7 +304,7 @@ class LevenshsteinModule:
fhe.Wire(fhe.AllOutputs(constant), fhe.Input(mix, 2)),
fhe.Wire(fhe.AllOutputs(constant), fhe.Input(mix, 3)),
fhe.Wire(fhe.AllOutputs(constant), fhe.Input(mix, 4)),
]
}
)

View File

@@ -486,10 +486,10 @@ def test_composition_policy_wires():
return (x + y), (x - y)
composition = fhe.Wired(
[
{
fhe.Wire(fhe.AllOutputs(add_sub), fhe.AllInputs(add_sub)),
fhe.Wire(fhe.AllOutputs(add_sub), fhe.Input(square, 0)),
]
}
)
assert isinstance(Module.composition, fhe.CompositionPolicy)
@@ -508,9 +508,9 @@ def test_composition_wired_enhances_complexity():
return (x * 2) % 200
composition = fhe.Wired(
[
{
fhe.Wire(fhe.Output(_1, 0), fhe.Input(_2, 0)),
]
}
)
module1 = Module1.compile(
@@ -558,10 +558,10 @@ def test_composition_wired_compilation():
return (x * 2) % 100
composition = fhe.Wired(
[
{
fhe.Wire(fhe.Output(a, 0), fhe.Input(b, 0)),
fhe.Wire(fhe.Output(b, 0), fhe.Input(c, 0)),
]
}
)
module = Module.compile(
@@ -718,3 +718,79 @@ def test_client_server_api(helpers):
output = client.decrypt(deserialized_result, function_name="inc")
assert output == 11
def test_trace_wire_single_input_output(helpers):
@fhe.module()
class Module:
@fhe.function({"x": "encrypted"})
def a(x):
return (x * 2) % 20
@fhe.function({"x": "encrypted"})
def b(x):
return (x * 2) % 50
@fhe.function({"x": "encrypted"})
def c(x):
return (x * 2) % 100
composition = fhe.Wired()
# `wire_pipeline` takes an inputset as input, activates a context in which wiring is recorded, and returns an iterator that can be used inside of the context to iterate over samples.
with Module.wire_pipeline([np.random.randint(1, 20, size=()) for _ in range(100)]) as samples:
for s in samples:
Module.c(Module.b(Module.a(s)))
assert len(Module.composition.wires) == 2
assert fhe.Wire(fhe.Output(Module.a, 0), fhe.Input(Module.b, 0)) in Module.composition.wires
assert fhe.Wire(fhe.Output(Module.b, 0), fhe.Input(Module.c, 0)) in Module.composition.wires
module = Module.compile(
p_error=0.01,
)
inp_enc = module.a.encrypt(5)
a_enc = module.a.run(inp_enc)
assert module.a.decrypt(a_enc) == 10
b_enc = module.b.run(a_enc)
assert module.b.decrypt(b_enc) == 20
c_enc = module.c.run(b_enc)
assert module.c.decrypt(c_enc) == 40
def test_trace_wires_multi_inputs_outputs(helpers):
@fhe.module()
class Module:
@fhe.function({"x": "encrypted", "y": "encrypted"})
def a(x, y):
return ((x + y) * 2) % 20, ((x - y) * 2) % 20
@fhe.function({"x": "encrypted", "y": "encrypted"})
def b(x, y):
return ((x + y) * 2) % 20, ((x - y) * 2) % 20
composition = fhe.Wired()
# `wire_pipeline` takes an inputset as input, activates a context in which wiring is recorded, and returns an iterator that can be used inside of the context to iterate over samples.
with Module.wire_pipeline(
[(np.random.randint(1, 20, size=()), np.random.randint(1, 20, size=())) for _ in range(100)]
) as samples:
for s in samples:
output = Module.a(s[0], s[1])
Module.b(*output)
assert len(Module.composition.wires) == 2
assert fhe.Wire(fhe.Output(Module.a, 0), fhe.Input(Module.b, 0)) in Module.composition.wires
assert fhe.Wire(fhe.Output(Module.a, 1), fhe.Input(Module.b, 1)) in Module.composition.wires
module = Module.compile(
p_error=0.01,
)
inp_enc = module.a.encrypt(5, 1)
a_enc = module.a.run(*inp_enc)
assert module.a.decrypt(a_enc) == (12, 8)
b_enc = module.b.run(*a_enc)
assert module.b.decrypt(b_enc) == (0, 8)