mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat(fhe_circuit): create FHECircuit class to combine operation graph and compiler engine
This commit is contained in:
57
concrete/common/fhe_circuit.py
Normal file
57
concrete/common/fhe_circuit.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Module to hold the result of compilation."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from zamalang import CompilerEngine
|
||||
|
||||
from .debugging import draw_graph, get_printable_graph
|
||||
from .operator_graph import OPGraph
|
||||
|
||||
|
||||
class FHECircuit:
|
||||
"""Class which is the result of compilation."""
|
||||
|
||||
opgraph: OPGraph
|
||||
engine: CompilerEngine
|
||||
|
||||
def __init__(self, opgraph: OPGraph, engine: CompilerEngine):
|
||||
self.opgraph = opgraph
|
||||
self.engine = engine
|
||||
|
||||
def __str__(self):
|
||||
return get_printable_graph(self.opgraph, show_data_types=True)
|
||||
|
||||
def draw(
|
||||
self,
|
||||
show: bool = False,
|
||||
vertical: bool = True,
|
||||
save_to: Optional[Path] = None,
|
||||
) -> str:
|
||||
"""Draw operation graph of the circuit and optionally save/show the drawing.
|
||||
|
||||
Args:
|
||||
show (bool): if set to True, the drawing will be shown using matplotlib
|
||||
vertical (bool): if set to True, the orientation will be vertical
|
||||
save_to (Optional[Path]): if specified, the drawn graph will be saved to this path;
|
||||
otherwise it will be saved to a temporary file
|
||||
|
||||
Returns:
|
||||
str: path of the file where the drawn graph is saved
|
||||
|
||||
"""
|
||||
|
||||
return draw_graph(self.opgraph, show, vertical, save_to)
|
||||
|
||||
def run(self, *args: List[Union[int, List[int]]]) -> int:
|
||||
"""Encrypt, evaluate, and decrypt the inputs on the circuit.
|
||||
|
||||
Args:
|
||||
*args (List[Union[int, List[int]]]): inputs to the circuit
|
||||
|
||||
Returns:
|
||||
int: homomorphic evaluation result
|
||||
|
||||
"""
|
||||
|
||||
return self.engine.run(*args)
|
||||
@@ -11,6 +11,7 @@ from ..common.bounds_measurement.inputset_eval import eval_op_graph_bounds_on_in
|
||||
from ..common.common_helpers import check_op_graph_is_integer_program
|
||||
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
|
||||
from ..common.data_types import Integer
|
||||
from ..common.fhe_circuit import FHECircuit
|
||||
from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
|
||||
from ..common.mlir.utils import (
|
||||
extend_direct_lookup_tables,
|
||||
@@ -237,7 +238,7 @@ def _compile_numpy_function_internal(
|
||||
compilation_configuration: CompilationConfiguration,
|
||||
compilation_artifacts: CompilationArtifacts,
|
||||
show_mlir: bool,
|
||||
) -> CompilerEngine:
|
||||
) -> FHECircuit:
|
||||
"""Compile an homomorphic program (internal part of the API).
|
||||
|
||||
Args:
|
||||
@@ -282,7 +283,7 @@ def _compile_numpy_function_internal(
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(mlir_result)
|
||||
|
||||
return engine
|
||||
return FHECircuit(op_graph, engine)
|
||||
|
||||
|
||||
def compile_numpy_function(
|
||||
@@ -292,7 +293,7 @@ def compile_numpy_function(
|
||||
compilation_configuration: Optional[CompilationConfiguration] = None,
|
||||
compilation_artifacts: Optional[CompilationArtifacts] = None,
|
||||
show_mlir: bool = False,
|
||||
) -> CompilerEngine:
|
||||
) -> FHECircuit:
|
||||
"""Compile an homomorphic program (main API).
|
||||
|
||||
Args:
|
||||
|
||||
@@ -22,13 +22,13 @@ x = hnp.EncryptedScalar(hnp.UnsignedInteger(2))
|
||||
y = hnp.EncryptedScalar(hnp.UnsignedInteger(1))
|
||||
|
||||
# Compile the function to its homomorphic equivalent
|
||||
engine = hnp.compile_numpy_function(
|
||||
circuit = hnp.compile_numpy_function(
|
||||
f, {"x": x, "y": y},
|
||||
[(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (3, 0), (3, 1)],
|
||||
)
|
||||
|
||||
# Make homomorphic inference
|
||||
engine.run(1, 0)
|
||||
circuit.run(1, 0)
|
||||
```
|
||||
|
||||
## Overview
|
||||
|
||||
@@ -12,6 +12,10 @@ In this section we will go over some terms that we use throughout the project.
|
||||
- bounds
|
||||
- before intermediate representation is sent to the compiler, we need to know which node will output which type (e.g., uint3 vs uint5)
|
||||
- there are several ways to do this but the simplest one is to evaluate the intermediate representation with all combinations of inputs and remember the maximum and the minimum values for each node, which is what we call bounds, and bounds can be used to determine the appropriate type for each node
|
||||
- fhe circuit
|
||||
- it is the result of compilation
|
||||
- it contains the operation graph and the compiler engine in it
|
||||
- it has methods for printing, visualizing, and evaluating
|
||||
|
||||
## Module structure
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -41,7 +41,7 @@ Finally, we can compile our function to its homomorphic equivalent.
|
||||
|
||||
<!--python-test:cont-->
|
||||
```python
|
||||
engine = hnp.compile_numpy_function(
|
||||
circuit = hnp.compile_numpy_function(
|
||||
f, {"x": x, "y": y},
|
||||
inputset=inputset,
|
||||
)
|
||||
@@ -49,17 +49,17 @@ engine = hnp.compile_numpy_function(
|
||||
|
||||
## Performing homomorphic evaluation
|
||||
|
||||
You can use `.run(...)` method of `engine` returned by `hnp.compile_numpy_function(...)` to perform fully homomorphic evaluation. Here are some examples:
|
||||
You can use `.run(...)` method of `FHECircuit` returned by `hnp.compile_numpy_function(...)` to perform fully homomorphic evaluation. Here are some examples:
|
||||
|
||||
<!--python-test:cont-->
|
||||
```python
|
||||
engine.run(3, 4)
|
||||
circuit.run(3, 4)
|
||||
# 7
|
||||
engine.run(1, 2)
|
||||
circuit.run(1, 2)
|
||||
# 3
|
||||
engine.run(7, 7)
|
||||
circuit.run(7, 7)
|
||||
# 14
|
||||
engine.run(0, 0)
|
||||
circuit.run(0, 0)
|
||||
# 0
|
||||
```
|
||||
|
||||
|
||||
@@ -28,8 +28,8 @@ results in
|
||||
|
||||
<!--python-test:skip-->
|
||||
```python
|
||||
engine.run(3) == 45
|
||||
engine.run(0) == 42
|
||||
circuit.run(3) == 45
|
||||
circuit.run(0) == 42
|
||||
```
|
||||
|
||||
### Dynamic ClearScalar and EncryptedScalar
|
||||
@@ -52,8 +52,8 @@ results in
|
||||
|
||||
<!--python-test:skip-->
|
||||
```python
|
||||
engine.run(6, 4) == 10
|
||||
engine.run(1, 1) == 2
|
||||
circuit.run(6, 4) == 10
|
||||
circuit.run(1, 1) == 2
|
||||
```
|
||||
|
||||
where
|
||||
@@ -78,8 +78,8 @@ results in
|
||||
|
||||
<!--python-test:skip-->
|
||||
```python
|
||||
engine.run(7, 7) == 14
|
||||
engine.run(3, 4) == 7
|
||||
circuit.run(7, 7) == 14
|
||||
circuit.run(3, 4) == 7
|
||||
```
|
||||
|
||||
## Subtraction
|
||||
@@ -100,8 +100,8 @@ results in
|
||||
|
||||
<!--python-test:skip-->
|
||||
```python
|
||||
engine.run(2) == 1
|
||||
engine.run(3) == 0
|
||||
circuit.run(2) == 1
|
||||
circuit.run(3) == 0
|
||||
```
|
||||
|
||||
### Dynamic ClearScalar and EncryptedScalar
|
||||
@@ -121,8 +121,8 @@ results in
|
||||
|
||||
<!--python-test:skip-->
|
||||
```python
|
||||
engine.run(2, 4) == 2
|
||||
engine.run(1, 7) == 6
|
||||
circuit.run(2, 4) == 2
|
||||
circuit.run(1, 7) == 6
|
||||
```
|
||||
|
||||
## Multiplication
|
||||
@@ -151,8 +151,8 @@ results in
|
||||
|
||||
<!--python-test:skip-->
|
||||
```python
|
||||
engine.run(2) == 4
|
||||
engine.run(5) == 10
|
||||
circuit.run(2) == 4
|
||||
circuit.run(5) == 10
|
||||
```
|
||||
|
||||
### Dynamic ClearScalar and EncryptedScalar
|
||||
@@ -180,8 +180,8 @@ results in
|
||||
|
||||
<!--python-test:skip-->
|
||||
```python
|
||||
engine.run(2, 3) == 6
|
||||
engine.run(1, 7) == 7
|
||||
circuit.run(2, 3) == 6
|
||||
circuit.run(1, 7) == 7
|
||||
```
|
||||
|
||||
## Dot Product
|
||||
@@ -211,8 +211,8 @@ results in
|
||||
|
||||
<!--python-test:skip-->
|
||||
```python
|
||||
engine.run([1, 1], [2, 3]) == 5
|
||||
engine.run([2, 3], [2, 3]) == 13
|
||||
circuit.run([1, 1], [2, 3]) == 5
|
||||
circuit.run([2, 3], [2, 3]) == 13
|
||||
```
|
||||
|
||||
## Combining all together
|
||||
@@ -233,6 +233,6 @@ results in
|
||||
|
||||
<!--python-test:skip-->
|
||||
```python
|
||||
engine.run([1, 2], [4, 3], 10) == 60
|
||||
engine.run([2, 3], [3, 2], 5) == 66
|
||||
circuit.run([1, 2], [4, 3], 10) == 60
|
||||
circuit.run([2, 3], [3, 2], 5) == 66
|
||||
```
|
||||
|
||||
@@ -23,10 +23,10 @@ results in
|
||||
|
||||
<!--python-test:skip-->
|
||||
```python
|
||||
engine.run(0) == 2
|
||||
engine.run(1) == 1
|
||||
engine.run(2) == 3
|
||||
engine.run(3) == 0
|
||||
circuit.run(0) == 2
|
||||
circuit.run(1) == 1
|
||||
circuit.run(2) == 3
|
||||
circuit.run(3) == 0
|
||||
```
|
||||
|
||||
## Fused table lookup
|
||||
@@ -49,14 +49,14 @@ results in
|
||||
|
||||
<!--python-test:skip-->
|
||||
```python
|
||||
engine.run(0) == 77
|
||||
engine.run(1) == 35
|
||||
engine.run(2) == 32
|
||||
engine.run(3) == 70
|
||||
engine.run(4) == 115
|
||||
engine.run(5) == 125
|
||||
engine.run(6) == 91
|
||||
engine.run(7) == 45
|
||||
circuit.run(0) == 77
|
||||
circuit.run(1) == 35
|
||||
circuit.run(2) == 32
|
||||
circuit.run(3) == 70
|
||||
circuit.run(4) == 115
|
||||
circuit.run(5) == 125
|
||||
circuit.run(6) == 91
|
||||
circuit.run(7) == 45
|
||||
```
|
||||
|
||||
Initially, the function is converted to this operation graph
|
||||
|
||||
@@ -16,11 +16,11 @@ results in
|
||||
|
||||
<!--python-test:skip-->
|
||||
```python
|
||||
engine.run(3) == 27
|
||||
engine.run(0) == 0
|
||||
engine.run(1) == 90
|
||||
engine.run(10) == 91
|
||||
engine.run(60) == 58
|
||||
circuit.run(3) == 27
|
||||
circuit.run(0) == 0
|
||||
circuit.run(1) == 90
|
||||
circuit.run(10) == 91
|
||||
circuit.run(60) == 58
|
||||
```
|
||||
|
||||
## Supported operations
|
||||
|
||||
50
tests/common/test_fhe_circuit.py
Normal file
50
tests/common/test_fhe_circuit.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Test module for Circuit class"""
|
||||
|
||||
import filecmp
|
||||
|
||||
import concrete.numpy as hnp
|
||||
from concrete.common.debugging import draw_graph, get_printable_graph
|
||||
|
||||
|
||||
def test_circuit_str():
|
||||
"""Test function for `__str__` method of `Circuit`"""
|
||||
|
||||
def f(x):
|
||||
return x + 42
|
||||
|
||||
x = hnp.EncryptedScalar(hnp.UnsignedInteger(3))
|
||||
|
||||
inputset = [(i,) for i in range(2 ** 3)]
|
||||
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset)
|
||||
|
||||
assert str(circuit) == get_printable_graph(circuit.opgraph, show_data_types=True)
|
||||
|
||||
|
||||
def test_circuit_draw():
|
||||
"""Test function for `draw` method of `Circuit`"""
|
||||
|
||||
def f(x):
|
||||
return x + 42
|
||||
|
||||
x = hnp.EncryptedScalar(hnp.UnsignedInteger(3))
|
||||
|
||||
inputset = [(i,) for i in range(2 ** 3)]
|
||||
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset)
|
||||
|
||||
assert filecmp.cmp(circuit.draw(), draw_graph(circuit.opgraph))
|
||||
assert filecmp.cmp(circuit.draw(vertical=False), draw_graph(circuit.opgraph, vertical=False))
|
||||
|
||||
|
||||
def test_circuit_run():
|
||||
"""Test function for `run` method of `Circuit`"""
|
||||
|
||||
def f(x):
|
||||
return x + 42
|
||||
|
||||
x = hnp.EncryptedScalar(hnp.UnsignedInteger(3))
|
||||
|
||||
inputset = [(i,) for i in range(2 ** 3)]
|
||||
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset)
|
||||
|
||||
for x in inputset:
|
||||
assert circuit.run(*x) == circuit.engine.run(*x)
|
||||
Reference in New Issue
Block a user