feat: overhaul virtual circuits

This commit is contained in:
Umut
2023-02-15 16:25:18 +01:00
parent 656761346a
commit d595e9e50f
5 changed files with 87 additions and 31 deletions

View File

@@ -2,6 +2,7 @@
Declaration of `Circuit` class.
"""
from copy import deepcopy
from typing import Any, Optional, Tuple, Union, cast
import numpy as np
@@ -9,6 +10,7 @@ from concrete.compiler import PublicArguments, PublicResult
from ..dtypes import Integer
from ..internal.utils import assert_that
from ..mlir import GraphConverter
from ..representation import Graph
from .client import Client
from .configuration import Configuration
@@ -35,24 +37,26 @@ class Circuit:
self.mlir = mlir
if self.configuration.virtual:
assert_that(self.configuration.enable_unsafe_features)
return
self._initialize_client_and_server()
def _initialize_client_and_server(self):
input_signs = []
for i in range(len(graph.input_nodes)): # pylint: disable=consider-using-enumerate
input_value = graph.input_nodes[i].output
for i in range(len(self.graph.input_nodes)): # pylint: disable=consider-using-enumerate
input_value = self.graph.input_nodes[i].output
assert_that(isinstance(input_value.dtype, Integer))
input_dtype = cast(Integer, input_value.dtype)
input_signs.append(input_dtype.is_signed)
output_signs = []
for i in range(len(graph.output_nodes)): # pylint: disable=consider-using-enumerate
output_value = graph.output_nodes[i].output
for i in range(len(self.graph.output_nodes)): # pylint: disable=consider-using-enumerate
output_value = self.graph.output_nodes[i].output
assert_that(isinstance(output_value.dtype, Integer))
output_dtype = cast(Integer, output_value.dtype)
output_signs.append(output_dtype.is_signed)
self.server = Server.create(mlir, input_signs, output_signs, self.configuration)
self.server = Server.create(self.mlir, input_signs, output_signs, self.configuration)
keyset_cache_directory = None
if self.configuration.use_insecure_key_cache:
@@ -65,6 +69,49 @@ class Circuit:
def __str__(self):
return self.graph.format()
def simulate(self, *args: Any) -> Any:
"""
Simulate execution of the circuit.
Args:
*args (Any):
inputs to the circuit
Returns:
Any:
result of the simulation
"""
p_error = self.p_error if not self.configuration.virtual else self.configuration.p_error
return self.graph(*args, p_error=p_error)
def enable_fhe(self):
"""
Enable fully homomorphic encryption features.
When called on a virtual circuit, it'll enable access to the following methods:
- encrypt
- run
- decrypt
- encrypt_run_decrypt
When called on a normal circuit, it'll do nothing.
Raises:
RuntimeError:
if the circuit is not supported in fhe
"""
if not self.configuration.virtual:
return
new_configuration = deepcopy(self.configuration)
new_configuration.virtual = False
self.configuration = new_configuration
self.mlir = GraphConverter.convert(self.graph)
self._initialize_client_and_server()
def keygen(self, force: bool = False):
"""
Generate keys required for homomorphic evaluation.
@@ -154,9 +201,6 @@ class Circuit:
clear result of homomorphic evaluation
"""
if self.configuration.virtual:
return self.graph(*args, p_error=self.configuration.p_error)
return self.decrypt(self.run(self.encrypt(*args)))
def cleanup(self):

View File

@@ -47,10 +47,6 @@ class Configuration:
message = "Insecure key cache cannot be used without enabling unsafe features"
raise RuntimeError(message)
if self.virtual:
message = "Virtual compilation is not allowed without enabling unsafe features"
raise RuntimeError(message)
if self.use_insecure_key_cache and self.insecure_key_cache_location is None:
message = "Insecure key cache cannot be enabled without specifying its location"
raise RuntimeError(message)

View File

@@ -321,7 +321,7 @@ def test_virtual_p_error(p_error, bit_width, sample_size, tolerance, helpers):
circuit = function.compile(inputset, configuration=configuration, virtual=True, p_error=p_error)
sample = np.random.randint(0, 2**bit_width, size=(sample_size,))
output = circuit.encrypt_run_decrypt(sample)
output = circuit.simulate(sample)
errors = 0
for i in range(sample_size):
@@ -358,3 +358,30 @@ def test_circuit_run_with_unused_arg(helpers):
assert circuit.encrypt_run_decrypt(10, 0) == 20
assert circuit.encrypt_run_decrypt(10, 10) == 20
assert circuit.encrypt_run_decrypt(10, 20) == 20
def test_circuit_virtual_then_fhe(helpers):
"""
Test compiling to virtual and then fhe.
"""
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return x + y
inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(100)]
circuit = f.compile(inputset, configuration, virtual=True)
assert circuit.simulate(3, 5) == 8
circuit.enable_fhe()
assert circuit.simulate(3, 5) == 8
assert circuit.encrypt_run_decrypt(3, 5) == 8
circuit.enable_fhe()
assert circuit.simulate(3, 5) == 8
assert circuit.encrypt_run_decrypt(3, 5) == 8

View File

@@ -209,21 +209,6 @@ def test_compiler_bad_compile(helpers):
"(expected a tuple of 3 values got a tuple of 2 values)"
)
# with bad configuration
# ----------------------
with pytest.raises(RuntimeError) as excinfo:
compiler = Compiler(lambda x: x, {"x": "encrypted"})
compiler.compile(
range(10),
configuration.fork(enable_unsafe_features=False, use_insecure_key_cache=False),
virtual=True,
)
assert str(excinfo.value) == (
"Virtual compilation is not allowed without enabling unsafe features"
)
def test_compiler_virtual_compile(helpers):
"""
@@ -240,7 +225,7 @@ def test_compiler_virtual_compile(helpers):
inputset = [(100_000, 1_000_000)]
circuit = compiler.compile(inputset, configuration=configuration, virtual=True)
assert circuit.encrypt_run_decrypt(100_000, 1_000_000) == 100_000_000_000
assert circuit.simulate(100_000, 1_000_000) == 100_000_000_000
def test_compiler_compile_bad_inputset(helpers):

View File

@@ -261,7 +261,11 @@ class Helpers:
for i in range(retries):
expected = sanitize(function(*sample))
actual = sanitize(circuit.encrypt_run_decrypt(*sample))
actual = sanitize(
circuit.simulate(*sample)
if circuit.configuration.virtual
else circuit.encrypt_run_decrypt(*sample)
)
if all(np.array_equal(e, a) for e, a in zip(expected, actual)):
break