mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: overhaul virtual circuits
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
Declaration of `Circuit` class.
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Optional, Tuple, Union, cast
|
||||
|
||||
import numpy as np
|
||||
@@ -9,6 +10,7 @@ from concrete.compiler import PublicArguments, PublicResult
|
||||
|
||||
from ..dtypes import Integer
|
||||
from ..internal.utils import assert_that
|
||||
from ..mlir import GraphConverter
|
||||
from ..representation import Graph
|
||||
from .client import Client
|
||||
from .configuration import Configuration
|
||||
@@ -35,24 +37,26 @@ class Circuit:
|
||||
self.mlir = mlir
|
||||
|
||||
if self.configuration.virtual:
|
||||
assert_that(self.configuration.enable_unsafe_features)
|
||||
return
|
||||
|
||||
self._initialize_client_and_server()
|
||||
|
||||
def _initialize_client_and_server(self):
|
||||
input_signs = []
|
||||
for i in range(len(graph.input_nodes)): # pylint: disable=consider-using-enumerate
|
||||
input_value = graph.input_nodes[i].output
|
||||
for i in range(len(self.graph.input_nodes)): # pylint: disable=consider-using-enumerate
|
||||
input_value = self.graph.input_nodes[i].output
|
||||
assert_that(isinstance(input_value.dtype, Integer))
|
||||
input_dtype = cast(Integer, input_value.dtype)
|
||||
input_signs.append(input_dtype.is_signed)
|
||||
|
||||
output_signs = []
|
||||
for i in range(len(graph.output_nodes)): # pylint: disable=consider-using-enumerate
|
||||
output_value = graph.output_nodes[i].output
|
||||
for i in range(len(self.graph.output_nodes)): # pylint: disable=consider-using-enumerate
|
||||
output_value = self.graph.output_nodes[i].output
|
||||
assert_that(isinstance(output_value.dtype, Integer))
|
||||
output_dtype = cast(Integer, output_value.dtype)
|
||||
output_signs.append(output_dtype.is_signed)
|
||||
|
||||
self.server = Server.create(mlir, input_signs, output_signs, self.configuration)
|
||||
self.server = Server.create(self.mlir, input_signs, output_signs, self.configuration)
|
||||
|
||||
keyset_cache_directory = None
|
||||
if self.configuration.use_insecure_key_cache:
|
||||
@@ -65,6 +69,49 @@ class Circuit:
|
||||
def __str__(self):
|
||||
return self.graph.format()
|
||||
|
||||
def simulate(self, *args: Any) -> Any:
|
||||
"""
|
||||
Simulate execution of the circuit.
|
||||
|
||||
Args:
|
||||
*args (Any):
|
||||
inputs to the circuit
|
||||
|
||||
Returns:
|
||||
Any:
|
||||
result of the simulation
|
||||
"""
|
||||
|
||||
p_error = self.p_error if not self.configuration.virtual else self.configuration.p_error
|
||||
return self.graph(*args, p_error=p_error)
|
||||
|
||||
def enable_fhe(self):
|
||||
"""
|
||||
Enable fully homomorphic encryption features.
|
||||
|
||||
When called on a virtual circuit, it'll enable access to the following methods:
|
||||
- encrypt
|
||||
- run
|
||||
- decrypt
|
||||
- encrypt_run_decrypt
|
||||
|
||||
When called on a normal circuit, it'll do nothing.
|
||||
|
||||
Raises:
|
||||
RuntimeError:
|
||||
if the circuit is not supported in fhe
|
||||
"""
|
||||
|
||||
if not self.configuration.virtual:
|
||||
return
|
||||
|
||||
new_configuration = deepcopy(self.configuration)
|
||||
new_configuration.virtual = False
|
||||
self.configuration = new_configuration
|
||||
|
||||
self.mlir = GraphConverter.convert(self.graph)
|
||||
self._initialize_client_and_server()
|
||||
|
||||
def keygen(self, force: bool = False):
|
||||
"""
|
||||
Generate keys required for homomorphic evaluation.
|
||||
@@ -154,9 +201,6 @@ class Circuit:
|
||||
clear result of homomorphic evaluation
|
||||
"""
|
||||
|
||||
if self.configuration.virtual:
|
||||
return self.graph(*args, p_error=self.configuration.p_error)
|
||||
|
||||
return self.decrypt(self.run(self.encrypt(*args)))
|
||||
|
||||
def cleanup(self):
|
||||
|
||||
@@ -47,10 +47,6 @@ class Configuration:
|
||||
message = "Insecure key cache cannot be used without enabling unsafe features"
|
||||
raise RuntimeError(message)
|
||||
|
||||
if self.virtual:
|
||||
message = "Virtual compilation is not allowed without enabling unsafe features"
|
||||
raise RuntimeError(message)
|
||||
|
||||
if self.use_insecure_key_cache and self.insecure_key_cache_location is None:
|
||||
message = "Insecure key cache cannot be enabled without specifying its location"
|
||||
raise RuntimeError(message)
|
||||
|
||||
@@ -321,7 +321,7 @@ def test_virtual_p_error(p_error, bit_width, sample_size, tolerance, helpers):
|
||||
circuit = function.compile(inputset, configuration=configuration, virtual=True, p_error=p_error)
|
||||
|
||||
sample = np.random.randint(0, 2**bit_width, size=(sample_size,))
|
||||
output = circuit.encrypt_run_decrypt(sample)
|
||||
output = circuit.simulate(sample)
|
||||
|
||||
errors = 0
|
||||
for i in range(sample_size):
|
||||
@@ -358,3 +358,30 @@ def test_circuit_run_with_unused_arg(helpers):
|
||||
assert circuit.encrypt_run_decrypt(10, 0) == 20
|
||||
assert circuit.encrypt_run_decrypt(10, 10) == 20
|
||||
assert circuit.encrypt_run_decrypt(10, 20) == 20
|
||||
|
||||
|
||||
def test_circuit_virtual_then_fhe(helpers):
|
||||
"""
|
||||
Test compiling to virtual and then fhe.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted", "y": "encrypted"})
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(100)]
|
||||
circuit = f.compile(inputset, configuration, virtual=True)
|
||||
|
||||
assert circuit.simulate(3, 5) == 8
|
||||
|
||||
circuit.enable_fhe()
|
||||
|
||||
assert circuit.simulate(3, 5) == 8
|
||||
assert circuit.encrypt_run_decrypt(3, 5) == 8
|
||||
|
||||
circuit.enable_fhe()
|
||||
|
||||
assert circuit.simulate(3, 5) == 8
|
||||
assert circuit.encrypt_run_decrypt(3, 5) == 8
|
||||
|
||||
@@ -209,21 +209,6 @@ def test_compiler_bad_compile(helpers):
|
||||
"(expected a tuple of 3 values got a tuple of 2 values)"
|
||||
)
|
||||
|
||||
# with bad configuration
|
||||
# ----------------------
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
compiler = Compiler(lambda x: x, {"x": "encrypted"})
|
||||
compiler.compile(
|
||||
range(10),
|
||||
configuration.fork(enable_unsafe_features=False, use_insecure_key_cache=False),
|
||||
virtual=True,
|
||||
)
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"Virtual compilation is not allowed without enabling unsafe features"
|
||||
)
|
||||
|
||||
|
||||
def test_compiler_virtual_compile(helpers):
|
||||
"""
|
||||
@@ -240,7 +225,7 @@ def test_compiler_virtual_compile(helpers):
|
||||
inputset = [(100_000, 1_000_000)]
|
||||
circuit = compiler.compile(inputset, configuration=configuration, virtual=True)
|
||||
|
||||
assert circuit.encrypt_run_decrypt(100_000, 1_000_000) == 100_000_000_000
|
||||
assert circuit.simulate(100_000, 1_000_000) == 100_000_000_000
|
||||
|
||||
|
||||
def test_compiler_compile_bad_inputset(helpers):
|
||||
|
||||
@@ -261,7 +261,11 @@ class Helpers:
|
||||
|
||||
for i in range(retries):
|
||||
expected = sanitize(function(*sample))
|
||||
actual = sanitize(circuit.encrypt_run_decrypt(*sample))
|
||||
actual = sanitize(
|
||||
circuit.simulate(*sample)
|
||||
if circuit.configuration.virtual
|
||||
else circuit.encrypt_run_decrypt(*sample)
|
||||
)
|
||||
|
||||
if all(np.array_equal(e, a) for e, a in zip(expected, actual)):
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user