refactor: move configuration and artifacts to compile and trace methods

This commit is contained in:
Umut
2022-04-28 13:01:59 +02:00
parent cc726154b6
commit ffe26aadcb
25 changed files with 256 additions and 186 deletions

View File

@@ -50,14 +50,13 @@ class Circuit:
graph: Graph,
mlir: str,
configuration: Optional[Configuration] = None,
virtual: bool = False,
):
configuration = configuration if configuration is not None else Configuration()
self.graph = graph
self.mlir = mlir
self.virtual = virtual
self.virtual = configuration.virtual
if self.virtual:
return

View File

@@ -5,6 +5,7 @@ Declaration of `Compiler` class.
import inspect
import os
import traceback
from copy import deepcopy
from enum import Enum, unique
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
@@ -48,8 +49,6 @@ class Compiler:
self,
function: Callable,
parameter_encryption_statuses: Dict[str, Union[str, EncryptionStatus]],
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
):
signature = inspect.signature(function)
@@ -82,16 +81,12 @@ class Compiler:
for param, status in parameter_encryption_statuses.items()
}
self.configuration = configuration if configuration is not None else Configuration()
self.artifacts = artifacts if artifacts is not None else DebugArtifacts()
self.configuration = Configuration()
self.artifacts = DebugArtifacts()
self.inputset = []
self.graph = None
self.artifacts.add_source_code(function)
for param, encryption_status in parameter_encryption_statuses.items():
self.artifacts.add_parameter_encryption_status(param, encryption_status)
def __call__(
self,
*args: Any,
@@ -126,6 +121,10 @@ class Compiler:
sample to use for tracing
"""
self.artifacts.add_source_code(self.function)
for param, encryption_status in self.parameter_encryption_statuses.items():
self.artifacts.add_parameter_encryption_status(param, encryption_status)
parameters = {
param: Value.of(arg, is_encrypted=(status == EncryptionStatus.ENCRYPTED))
for arg, (param, status) in zip(
@@ -180,7 +179,9 @@ class Compiler:
def trace(
self,
inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
show_graph: bool = False,
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
**kwargs,
) -> Graph:
"""
Trace the function using an inputset.
@@ -189,19 +190,37 @@ class Compiler:
inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
optional inputset to extend accumulated inputset before bounds measurement
show_graph (bool, default = False):
whether to print the computation graph
configuration(Optional[Configuration], default = None):
configuration to use
artifacts (Optional[DebugArtifacts], default = None):
artifacts to store information about the process
kwargs (Dict[str, Any]):
configuration options to overwrite
Returns:
Graph:
computation graph representing the function prior to MLIR conversion
"""
old_configuration = deepcopy(self.configuration)
old_artifacts = deepcopy(self.artifacts)
if configuration is not None:
self.configuration = configuration
if artifacts is not None:
self.artifacts = artifacts
if len(kwargs) != 0:
self.configuration = self.configuration.fork(**kwargs)
try:
self._evaluate("Tracing", inputset)
assert self.graph is not None
if show_graph:
if self.configuration.verbose or self.configuration.show_graph:
graph = self.graph.format()
longest_line = max([len(line) for line in graph.split("\n")])
@@ -247,12 +266,19 @@ class Compiler:
raise
finally:
self.configuration = old_configuration
self.artifacts = old_artifacts
# pylint: disable=too-many-branches,too-many-statements
def compile(
self,
inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
show_graph: bool = False,
show_mlir: bool = False,
virtual: bool = False,
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
**kwargs,
) -> Circuit:
"""
Compile the function using an inputset.
@@ -261,34 +287,50 @@ class Compiler:
inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
optional inputset to extend accumulated inputset before bounds measurement
show_graph (bool, default = False):
whether to print the computation graph
configuration(Optional[Configuration], default = None):
configuration to use
show_mlir (bool, default = False):
whether to print the compiled mlir
artifacts (Optional[DebugArtifacts], default = None):
artifacts to store information about the process
virtual (bool, default = False):
whether to simulate the computation to allow large bit-widths
kwargs (Dict[str, Any]):
configuration options to overwrite
Returns:
Circuit:
compiled circuit
"""
old_configuration = deepcopy(self.configuration)
old_artifacts = deepcopy(self.artifacts)
if configuration is not None:
self.configuration = configuration
if artifacts is not None:
self.artifacts = artifacts
if len(kwargs) != 0:
self.configuration = self.configuration.fork(**kwargs)
try:
if virtual and not self.configuration.enable_unsafe_features:
raise RuntimeError(
"Virtual compilation is not allowed without enabling unsafe features"
)
self._evaluate("Compiling", inputset)
assert self.graph is not None
mlir = GraphConverter.convert(self.graph, virtual=virtual)
mlir = GraphConverter.convert(self.graph, virtual=self.configuration.virtual)
self.artifacts.add_mlir_to_compile(mlir)
if show_graph or show_mlir:
graph = self.graph.format() if show_graph else ""
if (
self.configuration.verbose
or self.configuration.show_graph
or self.configuration.show_mlir
):
graph = (
self.graph.format()
if self.configuration.verbose or self.configuration.show_graph
else ""
)
longest_graph_line = max([len(line) for line in graph.split("\n")])
longest_mlir_line = max([len(line) for line in mlir.split("\n")])
@@ -308,7 +350,7 @@ class Compiler:
except OSError: # pragma: no cover
columns = min(longest_line, 80)
if show_graph:
if self.configuration.verbose or self.configuration.show_graph:
print()
print("Computation Graph")
@@ -318,8 +360,13 @@ class Compiler:
print()
if show_mlir:
print("\n" if not show_graph else "", end="")
if self.configuration.verbose or self.configuration.show_mlir:
print(
"\n"
if not (self.configuration.verbose or self.configuration.show_graph)
else "",
end="",
)
print("MLIR")
print("-" * columns)
@@ -328,7 +375,7 @@ class Compiler:
print()
return Circuit(self.graph, mlir, self.configuration, virtual=virtual)
return Circuit(self.graph, mlir, self.configuration)
except Exception: # pragma: no cover
@@ -346,3 +393,10 @@ class Compiler:
f.write(traceback.format_exc())
raise
finally:
self.configuration = old_configuration
self.artifacts = old_artifacts
# pylint: enable=too-many-branches,too-many-statements

View File

@@ -13,31 +13,63 @@ class Configuration:
Configuration class, to allow the compilation process to be customized.
"""
verbose: bool
show_graph: bool
show_mlir: bool
dump_artifacts_on_unexpected_failures: bool
enable_unsafe_features: bool
virtual: bool
use_insecure_key_cache: bool
loop_parallelize: bool
dataflow_parallelize: bool
auto_parallelize: bool
def _validate(self):
"""
Validate configuration.
"""
if not self.enable_unsafe_features:
if self.use_insecure_key_cache:
raise RuntimeError(
"Insecure key cache cannot be used without enabling unsafe features"
)
if self.virtual:
raise RuntimeError(
"Virtual compilation is not allowed without enabling unsafe features"
)
# pylint: disable=too-many-arguments
def __init__(
self,
verbose: bool = False,
show_graph: bool = False,
show_mlir: bool = False,
dump_artifacts_on_unexpected_failures: bool = True,
enable_unsafe_features: bool = False,
virtual: bool = False,
use_insecure_key_cache: bool = False,
loop_parallelize: bool = True,
dataflow_parallelize: bool = False,
auto_parallelize: bool = False,
):
self.verbose = verbose
self.show_graph = show_graph
self.show_mlir = show_mlir
self.dump_artifacts_on_unexpected_failures = dump_artifacts_on_unexpected_failures
self.enable_unsafe_features = enable_unsafe_features
self.virtual = virtual
self.use_insecure_key_cache = use_insecure_key_cache
self.loop_parallelize = loop_parallelize
self.dataflow_parallelize = dataflow_parallelize
self.auto_parallelize = auto_parallelize
if not enable_unsafe_features and use_insecure_key_cache:
raise RuntimeError("Insecure key cache cannot be used without enabling unsafe features")
self._validate()
# pylint: enable=too-many-arguments
@staticmethod
def insecure_key_cache_location() -> Optional[str]:
@@ -80,4 +112,8 @@ class Configuration:
setattr(result, name, value)
# pylint: disable=protected-access
result._validate()
# pylint: enable=protected-access
return result

View File

@@ -11,23 +11,13 @@ from .compiler import Compiler, EncryptionStatus
from .configuration import Configuration
def compiler(
parameters: Mapping[str, EncryptionStatus],
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
):
def compiler(parameters: Mapping[str, EncryptionStatus]):
"""
Provide an easy interface for compilation.
Args:
parameters (Dict[str, EncryptionStatus]):
encryption statuses of the parameters of the function to compile
configuration(Optional[Configuration], default = None):
configuration to use for compilation
artifacts (Optional[DebugArtifacts], default = None):
artifacts to store information about compilation
"""
def decoration(function: Callable):
@@ -41,12 +31,7 @@ def compiler(
def __init__(self, function: Callable):
self.function = function # type: ignore
self.compiler = Compiler(
self.function,
dict(parameters),
configuration,
artifacts,
)
self.compiler = Compiler(self.function, dict(parameters))
def __call__(self, *args, **kwargs) -> Any:
self.compiler(*args, **kwargs)
@@ -55,7 +40,9 @@ def compiler(
def trace(
self,
inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
show_graph: bool = False,
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
**kwargs,
) -> Graph:
"""
Trace the function into computation graph.
@@ -64,22 +51,28 @@ def compiler(
inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
optional inputset to extend accumulated inputset before bounds measurement
show_graph (bool, default = False):
whether to print the computation graph
configuration(Optional[Configuration], default = None):
configuration to use
artifacts (Optional[DebugArtifacts], default = None):
artifacts to store information about the process
kwargs (Dict[str, Any]):
configuration options to overwrite
Returns:
Graph:
computation graph representing the function prior to MLIR conversion
"""
return self.compiler.trace(inputset, show_graph)
return self.compiler.trace(inputset, configuration, artifacts, **kwargs)
def compile(
self,
inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
show_graph: bool = False,
show_mlir: bool = False,
virtual: bool = False,
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
**kwargs,
) -> Circuit:
"""
Compile the function into a circuit.
@@ -88,21 +81,21 @@ def compiler(
inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
optional inputset to extend accumulated inputset before bounds measurement
show_graph (bool, default = False):
whether to print the computation graph
configuration(Optional[Configuration], default = None):
configuration to use
show_mlir (bool, default = False):
whether to print the compiled mlir
artifacts (Optional[DebugArtifacts], default = None):
artifacts to store information about the process
virtual (bool, default = False):
whether to simulate the computation to allow large bit-widths
kwargs (Dict[str, Any]):
configuration options to overwrite
Returns:
Circuit:
compiled circuit
"""
return self.compiler.compile(inputset, show_graph, show_mlir, virtual)
return self.compiler.compile(inputset, configuration, artifacts, **kwargs)
return Compilable(function)

View File

@@ -121,11 +121,11 @@ import pathlib
artifacts = cnp.DebugArtifacts("/tmp/custom/export/path")
@cnp.compiler({"x": "encrypted"}, artifacts=artifacts)
@cnp.compiler({"x": "encrypted"})
def f(x):
return 127 - (50 * (np.sin(x) + 1)).astype(np.int64)
f.compile(range(2 ** 3))
f.compile(range(2 ** 3), artifacts=artifacts)
artifacts.export()
```

View File

@@ -19,12 +19,12 @@ def test_artifacts_export(helpers):
configuration = helpers.configuration()
artifacts = DebugArtifacts(tmpdir)
@compiler({"x": "encrypted"}, configuration=configuration, artifacts=artifacts)
@compiler({"x": "encrypted"})
def f(x):
return x + 10
inputset = range(100)
f.compile(inputset)
f.compile(inputset, configuration, artifacts)
artifacts.export()

View File

@@ -18,12 +18,12 @@ def test_circuit_str(helpers):
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"}, configuration=configuration)
@compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return x + y
inputset = [(np.random.randint(0, 2 ** 4), np.random.randint(0, 2 ** 5)) for _ in range(100)]
circuit = f.compile(inputset)
circuit = f.compile(inputset, configuration)
assert str(circuit) == (
"""
@@ -44,12 +44,12 @@ def test_circuit_draw(helpers):
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"}, configuration=configuration)
@compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return x + y
inputset = [(np.random.randint(0, 2 ** 4), np.random.randint(0, 2 ** 5)) for _ in range(100)]
circuit = f.compile(inputset)
circuit = f.compile(inputset, configuration)
with tempfile.TemporaryDirectory() as path:
tmpdir = Path(path)
@@ -67,12 +67,12 @@ def test_circuit_bad_run(helpers):
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"}, configuration=configuration)
@compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return x + y
inputset = [(np.random.randint(0, 2 ** 4), np.random.randint(0, 2 ** 5)) for _ in range(100)]
circuit = f.compile(inputset)
circuit = f.compile(inputset, configuration)
# with 1 argument
# ---------------
@@ -138,12 +138,12 @@ def test_circuit_virtual_explicit_api(helpers):
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"}, configuration=configuration)
@compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return x + y
inputset = [(np.random.randint(0, 2 ** 4), np.random.randint(0, 2 ** 5)) for _ in range(100)]
circuit = f.compile(inputset, virtual=True)
circuit = f.compile(inputset, configuration, virtual=True)
with pytest.raises(RuntimeError) as excinfo:
circuit.keygen()

View File

@@ -7,13 +7,11 @@ import pytest
from concrete.numpy.compilation import Compiler
def test_compiler_bad_init(helpers):
def test_compiler_bad_init():
"""
Test `__init__` method of `Compiler` class with bad parameters.
"""
configuration = helpers.configuration()
def f(x, y, z):
return x + y + z
@@ -21,7 +19,7 @@ def test_compiler_bad_init(helpers):
# -----------
with pytest.raises(ValueError) as excinfo:
Compiler(f, {}, configuration=configuration)
Compiler(f, {})
assert str(excinfo.value) == (
"Encryption statuses of parameters 'x', 'y' and 'z' of function 'f' are not provided"
@@ -31,7 +29,7 @@ def test_compiler_bad_init(helpers):
# ---------------
with pytest.raises(ValueError) as excinfo:
Compiler(f, {"z": "clear"}, configuration=configuration)
Compiler(f, {"z": "clear"})
assert str(excinfo.value) == (
"Encryption statuses of parameters 'x' and 'y' of function 'f' are not provided"
@@ -41,7 +39,7 @@ def test_compiler_bad_init(helpers):
# ---------
with pytest.raises(ValueError) as excinfo:
Compiler(f, {"y": "encrypted", "z": "clear"}, configuration=configuration)
Compiler(f, {"y": "encrypted", "z": "clear"})
assert str(excinfo.value) == (
"Encryption status of parameter 'x' of function 'f' is not provided"
@@ -52,29 +50,19 @@ def test_compiler_bad_init(helpers):
# this is fine and `p` is just ignored
Compiler(
f,
{"x": "encrypted", "y": "encrypted", "z": "clear", "p": "clear"},
configuration=configuration,
)
Compiler(f, {"x": "encrypted", "y": "encrypted", "z": "clear", "p": "clear"})
def test_compiler_bad_call(helpers):
def test_compiler_bad_call():
"""
Test `__call__` method of `Compiler` class with bad parameters.
"""
configuration = helpers.configuration()
def f(x, y, z):
return x + y + z
with pytest.raises(RuntimeError) as excinfo:
compiler = Compiler(
f,
{"x": "encrypted", "y": "encrypted", "z": "clear"},
configuration=configuration,
)
compiler = Compiler(f, {"x": "encrypted", "y": "encrypted", "z": "clear"})
compiler(1, 2, 3, invalid=4)
assert str(excinfo.value) == "Calling function 'f' with kwargs is not supported"
@@ -94,9 +82,8 @@ def test_compiler_bad_trace(helpers):
compiler = Compiler(
f,
{"x": "encrypted", "y": "encrypted", "z": "clear"},
configuration=configuration,
)
compiler.trace()
compiler.trace(configuration=configuration)
assert str(excinfo.value) == "Tracing function 'f' without an inputset is not supported"
@@ -115,17 +102,18 @@ def test_compiler_bad_compile(helpers):
compiler = Compiler(
f,
{"x": "encrypted", "y": "encrypted", "z": "clear"},
configuration=configuration,
)
compiler.compile()
compiler.compile(configuration=configuration)
assert str(excinfo.value) == "Compiling function 'f' without an inputset is not supported"
configuration.enable_unsafe_features = False
with pytest.raises(RuntimeError) as excinfo:
compiler = Compiler(lambda x: x, {"x": "encrypted"}, configuration=configuration)
compiler.compile(virtual=True)
compiler = Compiler(lambda x: x, {"x": "encrypted"})
compiler.compile(
range(10),
configuration.fork(enable_unsafe_features=False, use_insecure_key_cache=False),
virtual=True,
)
assert str(excinfo.value) == (
"Virtual compilation is not allowed without enabling unsafe features"
@@ -142,7 +130,7 @@ def test_compiler_virtual_compile(helpers):
def f(x):
return x + 400
compiler = Compiler(f, {"x": "encrypted"}, configuration=configuration)
circuit = compiler.compile(inputset=range(400), virtual=True)
compiler = Compiler(f, {"x": "encrypted"})
circuit = compiler.compile(inputset=range(400), configuration=configuration, virtual=True)
assert circuit.encrypt_run_decrypt(200) == 600

View File

@@ -12,14 +12,14 @@ def test_call_compile(helpers):
configuration = helpers.configuration()
@compiler({"x": "encrypted"}, configuration=configuration)
@compiler({"x": "encrypted"})
def function(x):
return x + 42
for i in range(10):
function(i)
circuit = function.compile()
circuit = function.compile(configuration=configuration)
sample = 5
helpers.check_execution(circuit, function, sample)
@@ -33,12 +33,12 @@ def test_compiler_verbose_trace(helpers, capsys):
configuration = helpers.configuration()
artifacts = DebugArtifacts()
@compiler({"x": "encrypted"}, configuration=configuration, artifacts=artifacts)
@compiler({"x": "encrypted"})
def function(x):
return x + 42
inputset = range(10)
function.trace(inputset, show_graph=True)
function.trace(inputset, configuration, artifacts, show_graph=True)
captured = capsys.readouterr()
assert captured.out.strip() == (
@@ -61,12 +61,12 @@ def test_compiler_verbose_compile(helpers, capsys):
configuration = helpers.configuration()
artifacts = DebugArtifacts()
@compiler({"x": "encrypted"}, configuration=configuration, artifacts=artifacts)
@compiler({"x": "encrypted"})
def function(x):
return x + 42
inputset = range(10)
function.compile(inputset, show_graph=True, show_mlir=True)
function.compile(inputset, configuration, artifacts, show_graph=True, show_mlir=True)
captured = capsys.readouterr()
assert captured.out.strip() == (

View File

@@ -59,10 +59,10 @@ def test_constant_add(function, parameters, helpers):
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
compiler = cnp.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset)
circuit = compiler.compile(inputset, configuration)
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample)
@@ -150,10 +150,10 @@ def test_add(function, parameters, helpers):
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
compiler = cnp.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset)
circuit = compiler.compile(inputset, configuration)
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample)

View File

@@ -167,10 +167,10 @@ def test_concatenate(function, parameters, helpers):
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
compiler = cnp.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset)
circuit = compiler.compile(inputset, configuration)
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample)

View File

@@ -57,12 +57,12 @@ def test_conv2d(input_shape, weight_shape, strides, dilations, has_bias, helpers
else:
bias = None
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
@cnp.compiler({"x": "encrypted"})
def function(x):
return connx.conv(x, weight, bias, strides=strides, dilations=dilations)
inputset = [np.random.randint(0, 4, size=input_shape) for i in range(100)]
circuit = function.compile(inputset)
circuit = function.compile(inputset, configuration)
sample = np.random.randint(0, 4, size=input_shape, dtype=np.uint8)
helpers.check_execution(circuit, function, sample)
@@ -373,7 +373,7 @@ def test_bad_conv_compilation(
else:
bias = None
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
@cnp.compiler({"x": "encrypted"})
def function(x):
return connx.conv(
x,
@@ -389,7 +389,7 @@ def test_bad_conv_compilation(
inputset = [np.random.randint(0, 4, size=input_shape) for i in range(100)]
with pytest.raises(expected_error) as excinfo:
function.compile(inputset)
function.compile(inputset, configuration)
assert str(excinfo.value) == expected_message

View File

@@ -166,10 +166,10 @@ def test_direct_table_lookup(bits, function, helpers):
# scalar
# ------
compiler = cnp.Compiler(function, {"x": "encrypted"}, configuration)
compiler = cnp.Compiler(function, {"x": "encrypted"})
inputset = range(2 ** bits)
circuit = compiler.compile(inputset)
circuit = compiler.compile(inputset, configuration)
sample = int(np.random.randint(0, 2 ** bits))
helpers.check_execution(circuit, function, sample, retries=10)
@@ -177,10 +177,10 @@ def test_direct_table_lookup(bits, function, helpers):
# tensor
# ------
compiler = cnp.Compiler(function, {"x": "encrypted"}, configuration)
compiler = cnp.Compiler(function, {"x": "encrypted"})
inputset = [np.random.randint(0, 2 ** bits, size=(3, 2), dtype=np.uint8) for _ in range(100)]
circuit = compiler.compile(inputset)
circuit = compiler.compile(inputset, configuration)
sample = np.random.randint(0, 2 ** bits, size=(3, 2), dtype=np.uint8)
helpers.check_execution(circuit, function, sample, retries=10)
@@ -207,10 +207,10 @@ def test_direct_multi_table_lookup(helpers):
def function(x):
return table[x]
compiler = cnp.Compiler(function, {"x": "encrypted"}, configuration)
compiler = cnp.Compiler(function, {"x": "encrypted"})
inputset = [np.random.randint(0, 2 ** 2, size=(3, 2), dtype=np.uint8) for _ in range(100)]
circuit = compiler.compile(inputset)
circuit = compiler.compile(inputset, configuration)
sample = np.random.randint(0, 2 ** 2, size=(3, 2), dtype=np.uint8)
helpers.check_execution(circuit, function, sample, retries=10)
@@ -277,22 +277,22 @@ def test_bad_direct_table_lookup(helpers):
# compilation with float value
# ----------------------------
compiler = cnp.Compiler(random_table_lookup_3b, {"x": "encrypted"}, configuration)
compiler = cnp.Compiler(random_table_lookup_3b, {"x": "encrypted"})
inputset = [1.5]
with pytest.raises(ValueError) as excinfo:
compiler.compile(inputset)
compiler.compile(inputset, configuration)
assert str(excinfo.value) == "LookupTable cannot be looked up with EncryptedScalar<float64>"
# compilation with invalid shape
# ------------------------------
compiler = cnp.Compiler(lambda x: table[x], {"x": "encrypted"}, configuration)
compiler = cnp.Compiler(lambda x: table[x], {"x": "encrypted"})
inputset = [10, 5, 6, 2]
with pytest.raises(ValueError) as excinfo:
compiler.compile(inputset)
compiler.compile(inputset, configuration)
assert str(excinfo.value) == (
"LookupTable of shape (3, 2) cannot be looked up with EncryptedScalar<uint4>"

View File

@@ -22,23 +22,23 @@ def test_dot(size, helpers):
bound = int(np.floor(np.sqrt(127 / size)))
cst = np.random.randint(0, bound, size=(size,))
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
@cnp.compiler({"x": "encrypted"})
def left_function(x):
return np.dot(x, cst)
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
@cnp.compiler({"x": "encrypted"})
def right_function(x):
return np.dot(cst, x)
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
@cnp.compiler({"x": "encrypted"})
def method(x):
return x.dot(cst)
inputset = [np.random.randint(0, bound, size=(size,)) for i in range(100)]
left_function_circuit = left_function.compile(inputset)
right_function_circuit = right_function.compile(inputset)
method_circuit = method.compile(inputset)
left_function_circuit = left_function.compile(inputset, configuration)
right_function_circuit = right_function.compile(inputset, configuration)
method_circuit = method.compile(inputset, configuration)
sample = np.random.randint(0, bound, size=(size,), dtype=np.uint8)

View File

@@ -125,29 +125,29 @@ def test_matmul(lhs_shape, rhs_shape, bounds, helpers):
lhs_cst = list(np.random.randint(minimum, maximum, size=lhs_shape))
rhs_cst = list(np.random.randint(minimum, maximum, size=rhs_shape))
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
@cnp.compiler({"x": "encrypted"})
def lhs_operator(x):
return x @ rhs_cst
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
@cnp.compiler({"x": "encrypted"})
def rhs_operator(x):
return lhs_cst @ x
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
@cnp.compiler({"x": "encrypted"})
def lhs_function(x):
return np.matmul(x, rhs_cst)
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
@cnp.compiler({"x": "encrypted"})
def rhs_function(x):
return np.matmul(lhs_cst, x)
lhs_inputset = [np.random.randint(minimum, maximum, size=lhs_shape) for i in range(100)]
rhs_inputset = [np.random.randint(minimum, maximum, size=rhs_shape) for i in range(100)]
lhs_operator_circuit = lhs_operator.compile(lhs_inputset)
rhs_operator_circuit = rhs_operator.compile(rhs_inputset)
lhs_function_circuit = lhs_function.compile(lhs_inputset)
rhs_function_circuit = rhs_function.compile(rhs_inputset)
lhs_operator_circuit = lhs_operator.compile(lhs_inputset, configuration)
rhs_operator_circuit = rhs_operator.compile(rhs_inputset, configuration)
lhs_function_circuit = lhs_function.compile(lhs_inputset, configuration)
rhs_function_circuit = rhs_function.compile(rhs_inputset, configuration)
lhs_sample = np.random.randint(minimum, maximum, size=lhs_shape, dtype=np.uint8)
rhs_sample = np.random.randint(minimum, maximum, size=rhs_shape, dtype=np.uint8)

View File

@@ -67,10 +67,10 @@ def test_constant_mul(function, parameters, helpers):
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
compiler = cnp.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset)
circuit = compiler.compile(inputset, configuration)
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample)

View File

@@ -27,18 +27,18 @@ def test_neg(parameters, helpers):
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
@cnp.compiler(parameter_encryption_statuses, configuration=configuration)
@cnp.compiler(parameter_encryption_statuses)
def operator(x):
return -x
@cnp.compiler(parameter_encryption_statuses, configuration=configuration)
@cnp.compiler(parameter_encryption_statuses)
def function(x):
return np.negative(x)
inputset = helpers.generate_inputset(parameters)
operator_circuit = operator.compile(inputset)
function_circuit = function.compile(inputset)
operator_circuit = operator.compile(inputset, configuration)
function_circuit = function.compile(inputset, configuration)
sample = helpers.generate_sample(parameters)

View File

@@ -447,10 +447,10 @@ def test_others(function, parameters, helpers):
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
compiler = cnp.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset)
circuit = compiler.compile(inputset, configuration)
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample, retries=10)
@@ -464,10 +464,10 @@ def test_others(function, parameters, helpers):
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
compiler = cnp.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset)
circuit = compiler.compile(inputset, configuration)
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample, retries=10)
@@ -483,13 +483,13 @@ def test_others_bad_fusing(helpers):
# two variable inputs
# -------------------
@cnp.compiler({"x": "encrypted", "y": "clear"}, configuration=configuration)
@cnp.compiler({"x": "encrypted", "y": "clear"})
def function1(x, y):
return (10 * (np.sin(x) ** 2) + 10 * (np.cos(y) ** 2)).astype(np.int64)
with pytest.raises(RuntimeError) as excinfo:
inputset = [(i, i) for i in range(100)]
function1.compile(inputset)
function1.compile(inputset, configuration)
helpers.check_str(
# pylint: disable=line-too-long
@@ -528,13 +528,13 @@ return %13
# big intermediate constants
# --------------------------
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
@cnp.compiler({"x": "encrypted"})
def function2(x):
return (np.sin(x) * [[1, 2], [3, 4]]).astype(np.int64)
with pytest.raises(RuntimeError) as excinfo:
inputset = range(100)
function2.compile(inputset)
function2.compile(inputset, configuration)
helpers.check_str(
# pylint: disable=line-too-long
@@ -559,13 +559,13 @@ return %4
# intermediates with different shape
# ----------------------------------
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
@cnp.compiler({"x": "encrypted"})
def function3(x):
return np.abs(np.sin(x)).reshape((2, 3)).astype(np.int64)
with pytest.raises(RuntimeError) as excinfo:
inputset = [np.random.randint(0, 2 ** 7, size=(3, 2)) for _ in range(100)]
function3.compile(inputset)
function3.compile(inputset, configuration)
helpers.check_str(
# pylint: disable=line-too-long

View File

@@ -116,18 +116,18 @@ def test_reshape(shape, newshape, helpers):
configuration = helpers.configuration()
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
@cnp.compiler({"x": "encrypted"})
def function(x):
return np.reshape(x, newshape)
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
@cnp.compiler({"x": "encrypted"})
def method(x):
return x.reshape(newshape)
inputset = [np.random.randint(0, 2 ** 5, size=shape) for i in range(100)]
function_circuit = function.compile(inputset)
method_circuit = method.compile(inputset)
function_circuit = function.compile(inputset, configuration)
method_circuit = method.compile(inputset, configuration)
sample = np.random.randint(0, 2 ** 5, size=shape, dtype=np.uint8)
@@ -159,12 +159,12 @@ def test_flatten(shape, helpers):
configuration = helpers.configuration()
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
@cnp.compiler({"x": "encrypted"})
def function(x):
return x.flatten()
inputset = [np.random.randint(0, 2 ** 5, size=shape) for i in range(100)]
circuit = function.compile(inputset)
circuit = function.compile(inputset, configuration)
sample = np.random.randint(0, 2 ** 5, size=shape, dtype=np.uint8)
helpers.check_execution(circuit, function, sample)

View File

@@ -154,10 +154,10 @@ def test_static_indexing(shape, function, helpers):
"""
configuration = helpers.configuration()
compiler = cnp.Compiler(function, {"x": "encrypted"}, configuration)
compiler = cnp.Compiler(function, {"x": "encrypted"})
inputset = [np.random.randint(0, 2 ** 5, size=shape) for _ in range(100)]
circuit = compiler.compile(inputset)
circuit = compiler.compile(inputset, configuration)
sample = np.random.randint(0, 2 ** 5, size=shape, dtype=np.uint8)
helpers.check_execution(circuit, function, sample)
@@ -173,21 +173,21 @@ def test_bad_static_indexing(helpers):
# with float
# ----------
compiler = cnp.Compiler(lambda x: x[1.5], {"x": "encrypted"}, configuration)
compiler = cnp.Compiler(lambda x: x[1.5], {"x": "encrypted"})
inputset = [np.random.randint(0, 2 ** 3, size=(3,)) for _ in range(100)]
with pytest.raises(ValueError) as excinfo:
compiler.compile(inputset)
compiler.compile(inputset, configuration)
assert str(excinfo.value) == "Indexing with '1.5' is not supported"
# with bad slice
# --------------
compiler = cnp.Compiler(lambda x: x[slice(1.5, 2.5, None)], {"x": "encrypted"}, configuration)
compiler = cnp.Compiler(lambda x: x[slice(1.5, 2.5, None)], {"x": "encrypted"})
inputset = [np.random.randint(0, 2 ** 3, size=(3,)) for _ in range(100)]
with pytest.raises(ValueError) as excinfo:
compiler.compile(inputset)
compiler.compile(inputset, configuration)
assert str(excinfo.value) == "Indexing with '1.5:2.5' is not supported"

View File

@@ -47,10 +47,10 @@ def test_constant_sub(function, parameters, helpers):
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
compiler = cnp.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset)
circuit = compiler.compile(inputset, configuration)
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample)

View File

@@ -105,10 +105,10 @@ def test_sum(function, parameters, helpers):
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
compiler = cnp.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset)
circuit = compiler.compile(inputset, configuration)
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample)

View File

@@ -39,10 +39,10 @@ def test_transpose(function, parameters, helpers):
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
compiler = cnp.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset)
circuit = compiler.compile(inputset, configuration)
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample)

View File

@@ -389,10 +389,10 @@ def test_graph_converter_bad_convert(
"""
configuration = helpers.configuration()
compiler = cnp.Compiler(function, encryption_statuses, configuration)
compiler = cnp.Compiler(function, encryption_statuses)
with pytest.raises(expected_error) as excinfo:
compiler.compile(inputset)
compiler.compile(inputset, configuration)
helpers.check_str(expected_message, str(excinfo.value))

View File

@@ -39,8 +39,8 @@ def test_graph_maximum_integer_bit_width(function, inputset, expected_result, he
configuration = helpers.configuration()
compiler = cnp.Compiler(function, {"x": "encrypted"}, configuration=configuration)
graph = compiler.trace(inputset)
compiler = cnp.Compiler(function, {"x": "encrypted"})
graph = compiler.trace(inputset, configuration)
print(graph.format())