mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
refactor: move configuration and artifacts to compile and trace methods
This commit is contained in:
@@ -50,14 +50,13 @@ class Circuit:
|
||||
graph: Graph,
|
||||
mlir: str,
|
||||
configuration: Optional[Configuration] = None,
|
||||
virtual: bool = False,
|
||||
):
|
||||
configuration = configuration if configuration is not None else Configuration()
|
||||
|
||||
self.graph = graph
|
||||
self.mlir = mlir
|
||||
|
||||
self.virtual = virtual
|
||||
self.virtual = configuration.virtual
|
||||
if self.virtual:
|
||||
return
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ Declaration of `Compiler` class.
|
||||
import inspect
|
||||
import os
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from enum import Enum, unique
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
@@ -48,8 +49,6 @@ class Compiler:
|
||||
self,
|
||||
function: Callable,
|
||||
parameter_encryption_statuses: Dict[str, Union[str, EncryptionStatus]],
|
||||
configuration: Optional[Configuration] = None,
|
||||
artifacts: Optional[DebugArtifacts] = None,
|
||||
):
|
||||
signature = inspect.signature(function)
|
||||
|
||||
@@ -82,16 +81,12 @@ class Compiler:
|
||||
for param, status in parameter_encryption_statuses.items()
|
||||
}
|
||||
|
||||
self.configuration = configuration if configuration is not None else Configuration()
|
||||
self.artifacts = artifacts if artifacts is not None else DebugArtifacts()
|
||||
self.configuration = Configuration()
|
||||
self.artifacts = DebugArtifacts()
|
||||
|
||||
self.inputset = []
|
||||
self.graph = None
|
||||
|
||||
self.artifacts.add_source_code(function)
|
||||
for param, encryption_status in parameter_encryption_statuses.items():
|
||||
self.artifacts.add_parameter_encryption_status(param, encryption_status)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -126,6 +121,10 @@ class Compiler:
|
||||
sample to use for tracing
|
||||
"""
|
||||
|
||||
self.artifacts.add_source_code(self.function)
|
||||
for param, encryption_status in self.parameter_encryption_statuses.items():
|
||||
self.artifacts.add_parameter_encryption_status(param, encryption_status)
|
||||
|
||||
parameters = {
|
||||
param: Value.of(arg, is_encrypted=(status == EncryptionStatus.ENCRYPTED))
|
||||
for arg, (param, status) in zip(
|
||||
@@ -180,7 +179,9 @@ class Compiler:
|
||||
def trace(
|
||||
self,
|
||||
inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
|
||||
show_graph: bool = False,
|
||||
configuration: Optional[Configuration] = None,
|
||||
artifacts: Optional[DebugArtifacts] = None,
|
||||
**kwargs,
|
||||
) -> Graph:
|
||||
"""
|
||||
Trace the function using an inputset.
|
||||
@@ -189,19 +190,37 @@ class Compiler:
|
||||
inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
|
||||
optional inputset to extend accumulated inputset before bounds measurement
|
||||
|
||||
show_graph (bool, default = False):
|
||||
whether to print the computation graph
|
||||
configuration(Optional[Configuration], default = None):
|
||||
configuration to use
|
||||
|
||||
artifacts (Optional[DebugArtifacts], default = None):
|
||||
artifacts to store information about the process
|
||||
|
||||
kwargs (Dict[str, Any]):
|
||||
configuration options to overwrite
|
||||
|
||||
Returns:
|
||||
Graph:
|
||||
computation graph representing the function prior to MLIR conversion
|
||||
"""
|
||||
|
||||
old_configuration = deepcopy(self.configuration)
|
||||
old_artifacts = deepcopy(self.artifacts)
|
||||
|
||||
if configuration is not None:
|
||||
self.configuration = configuration
|
||||
if artifacts is not None:
|
||||
self.artifacts = artifacts
|
||||
|
||||
if len(kwargs) != 0:
|
||||
self.configuration = self.configuration.fork(**kwargs)
|
||||
|
||||
try:
|
||||
|
||||
self._evaluate("Tracing", inputset)
|
||||
assert self.graph is not None
|
||||
|
||||
if show_graph:
|
||||
if self.configuration.verbose or self.configuration.show_graph:
|
||||
graph = self.graph.format()
|
||||
longest_line = max([len(line) for line in graph.split("\n")])
|
||||
|
||||
@@ -247,12 +266,19 @@ class Compiler:
|
||||
|
||||
raise
|
||||
|
||||
finally:
|
||||
|
||||
self.configuration = old_configuration
|
||||
self.artifacts = old_artifacts
|
||||
|
||||
# pylint: disable=too-many-branches,too-many-statements
|
||||
|
||||
def compile(
|
||||
self,
|
||||
inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
|
||||
show_graph: bool = False,
|
||||
show_mlir: bool = False,
|
||||
virtual: bool = False,
|
||||
configuration: Optional[Configuration] = None,
|
||||
artifacts: Optional[DebugArtifacts] = None,
|
||||
**kwargs,
|
||||
) -> Circuit:
|
||||
"""
|
||||
Compile the function using an inputset.
|
||||
@@ -261,34 +287,50 @@ class Compiler:
|
||||
inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
|
||||
optional inputset to extend accumulated inputset before bounds measurement
|
||||
|
||||
show_graph (bool, default = False):
|
||||
whether to print the computation graph
|
||||
configuration(Optional[Configuration], default = None):
|
||||
configuration to use
|
||||
|
||||
show_mlir (bool, default = False):
|
||||
whether to print the compiled mlir
|
||||
artifacts (Optional[DebugArtifacts], default = None):
|
||||
artifacts to store information about the process
|
||||
|
||||
virtual (bool, default = False):
|
||||
whether to simulate the computation to allow large bit-widths
|
||||
kwargs (Dict[str, Any]):
|
||||
configuration options to overwrite
|
||||
|
||||
Returns:
|
||||
Circuit:
|
||||
compiled circuit
|
||||
"""
|
||||
|
||||
old_configuration = deepcopy(self.configuration)
|
||||
old_artifacts = deepcopy(self.artifacts)
|
||||
|
||||
if configuration is not None:
|
||||
self.configuration = configuration
|
||||
if artifacts is not None:
|
||||
self.artifacts = artifacts
|
||||
|
||||
if len(kwargs) != 0:
|
||||
self.configuration = self.configuration.fork(**kwargs)
|
||||
|
||||
try:
|
||||
if virtual and not self.configuration.enable_unsafe_features:
|
||||
raise RuntimeError(
|
||||
"Virtual compilation is not allowed without enabling unsafe features"
|
||||
)
|
||||
|
||||
self._evaluate("Compiling", inputset)
|
||||
assert self.graph is not None
|
||||
|
||||
mlir = GraphConverter.convert(self.graph, virtual=virtual)
|
||||
mlir = GraphConverter.convert(self.graph, virtual=self.configuration.virtual)
|
||||
self.artifacts.add_mlir_to_compile(mlir)
|
||||
|
||||
if show_graph or show_mlir:
|
||||
graph = self.graph.format() if show_graph else ""
|
||||
if (
|
||||
self.configuration.verbose
|
||||
or self.configuration.show_graph
|
||||
or self.configuration.show_mlir
|
||||
):
|
||||
|
||||
graph = (
|
||||
self.graph.format()
|
||||
if self.configuration.verbose or self.configuration.show_graph
|
||||
else ""
|
||||
)
|
||||
|
||||
longest_graph_line = max([len(line) for line in graph.split("\n")])
|
||||
longest_mlir_line = max([len(line) for line in mlir.split("\n")])
|
||||
@@ -308,7 +350,7 @@ class Compiler:
|
||||
except OSError: # pragma: no cover
|
||||
columns = min(longest_line, 80)
|
||||
|
||||
if show_graph:
|
||||
if self.configuration.verbose or self.configuration.show_graph:
|
||||
print()
|
||||
|
||||
print("Computation Graph")
|
||||
@@ -318,8 +360,13 @@ class Compiler:
|
||||
|
||||
print()
|
||||
|
||||
if show_mlir:
|
||||
print("\n" if not show_graph else "", end="")
|
||||
if self.configuration.verbose or self.configuration.show_mlir:
|
||||
print(
|
||||
"\n"
|
||||
if not (self.configuration.verbose or self.configuration.show_graph)
|
||||
else "",
|
||||
end="",
|
||||
)
|
||||
|
||||
print("MLIR")
|
||||
print("-" * columns)
|
||||
@@ -328,7 +375,7 @@ class Compiler:
|
||||
|
||||
print()
|
||||
|
||||
return Circuit(self.graph, mlir, self.configuration, virtual=virtual)
|
||||
return Circuit(self.graph, mlir, self.configuration)
|
||||
|
||||
except Exception: # pragma: no cover
|
||||
|
||||
@@ -346,3 +393,10 @@ class Compiler:
|
||||
f.write(traceback.format_exc())
|
||||
|
||||
raise
|
||||
|
||||
finally:
|
||||
|
||||
self.configuration = old_configuration
|
||||
self.artifacts = old_artifacts
|
||||
|
||||
# pylint: enable=too-many-branches,too-many-statements
|
||||
|
||||
@@ -13,31 +13,63 @@ class Configuration:
|
||||
Configuration class, to allow the compilation process to be customized.
|
||||
"""
|
||||
|
||||
verbose: bool
|
||||
show_graph: bool
|
||||
show_mlir: bool
|
||||
dump_artifacts_on_unexpected_failures: bool
|
||||
enable_unsafe_features: bool
|
||||
virtual: bool
|
||||
use_insecure_key_cache: bool
|
||||
loop_parallelize: bool
|
||||
dataflow_parallelize: bool
|
||||
auto_parallelize: bool
|
||||
|
||||
def _validate(self):
|
||||
"""
|
||||
Validate configuration.
|
||||
"""
|
||||
|
||||
if not self.enable_unsafe_features:
|
||||
|
||||
if self.use_insecure_key_cache:
|
||||
raise RuntimeError(
|
||||
"Insecure key cache cannot be used without enabling unsafe features"
|
||||
)
|
||||
|
||||
if self.virtual:
|
||||
raise RuntimeError(
|
||||
"Virtual compilation is not allowed without enabling unsafe features"
|
||||
)
|
||||
|
||||
# pylint: disable=too-many-arguments
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
verbose: bool = False,
|
||||
show_graph: bool = False,
|
||||
show_mlir: bool = False,
|
||||
dump_artifacts_on_unexpected_failures: bool = True,
|
||||
enable_unsafe_features: bool = False,
|
||||
virtual: bool = False,
|
||||
use_insecure_key_cache: bool = False,
|
||||
loop_parallelize: bool = True,
|
||||
dataflow_parallelize: bool = False,
|
||||
auto_parallelize: bool = False,
|
||||
):
|
||||
self.verbose = verbose
|
||||
self.show_graph = show_graph
|
||||
self.show_mlir = show_mlir
|
||||
self.dump_artifacts_on_unexpected_failures = dump_artifacts_on_unexpected_failures
|
||||
self.enable_unsafe_features = enable_unsafe_features
|
||||
self.virtual = virtual
|
||||
self.use_insecure_key_cache = use_insecure_key_cache
|
||||
self.loop_parallelize = loop_parallelize
|
||||
self.dataflow_parallelize = dataflow_parallelize
|
||||
self.auto_parallelize = auto_parallelize
|
||||
|
||||
if not enable_unsafe_features and use_insecure_key_cache:
|
||||
raise RuntimeError("Insecure key cache cannot be used without enabling unsafe features")
|
||||
self._validate()
|
||||
|
||||
# pylint: enable=too-many-arguments
|
||||
|
||||
@staticmethod
|
||||
def insecure_key_cache_location() -> Optional[str]:
|
||||
@@ -80,4 +112,8 @@ class Configuration:
|
||||
|
||||
setattr(result, name, value)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
result._validate()
|
||||
# pylint: enable=protected-access
|
||||
|
||||
return result
|
||||
|
||||
@@ -11,23 +11,13 @@ from .compiler import Compiler, EncryptionStatus
|
||||
from .configuration import Configuration
|
||||
|
||||
|
||||
def compiler(
|
||||
parameters: Mapping[str, EncryptionStatus],
|
||||
configuration: Optional[Configuration] = None,
|
||||
artifacts: Optional[DebugArtifacts] = None,
|
||||
):
|
||||
def compiler(parameters: Mapping[str, EncryptionStatus]):
|
||||
"""
|
||||
Provide an easy interface for compilation.
|
||||
|
||||
Args:
|
||||
parameters (Dict[str, EncryptionStatus]):
|
||||
encryption statuses of the parameters of the function to compile
|
||||
|
||||
configuration(Optional[Configuration], default = None):
|
||||
configuration to use for compilation
|
||||
|
||||
artifacts (Optional[DebugArtifacts], default = None):
|
||||
artifacts to store information about compilation
|
||||
"""
|
||||
|
||||
def decoration(function: Callable):
|
||||
@@ -41,12 +31,7 @@ def compiler(
|
||||
|
||||
def __init__(self, function: Callable):
|
||||
self.function = function # type: ignore
|
||||
self.compiler = Compiler(
|
||||
self.function,
|
||||
dict(parameters),
|
||||
configuration,
|
||||
artifacts,
|
||||
)
|
||||
self.compiler = Compiler(self.function, dict(parameters))
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
self.compiler(*args, **kwargs)
|
||||
@@ -55,7 +40,9 @@ def compiler(
|
||||
def trace(
|
||||
self,
|
||||
inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
|
||||
show_graph: bool = False,
|
||||
configuration: Optional[Configuration] = None,
|
||||
artifacts: Optional[DebugArtifacts] = None,
|
||||
**kwargs,
|
||||
) -> Graph:
|
||||
"""
|
||||
Trace the function into computation graph.
|
||||
@@ -64,22 +51,28 @@ def compiler(
|
||||
inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
|
||||
optional inputset to extend accumulated inputset before bounds measurement
|
||||
|
||||
show_graph (bool, default = False):
|
||||
whether to print the computation graph
|
||||
configuration(Optional[Configuration], default = None):
|
||||
configuration to use
|
||||
|
||||
artifacts (Optional[DebugArtifacts], default = None):
|
||||
artifacts to store information about the process
|
||||
|
||||
kwargs (Dict[str, Any]):
|
||||
configuration options to overwrite
|
||||
|
||||
Returns:
|
||||
Graph:
|
||||
computation graph representing the function prior to MLIR conversion
|
||||
"""
|
||||
|
||||
return self.compiler.trace(inputset, show_graph)
|
||||
return self.compiler.trace(inputset, configuration, artifacts, **kwargs)
|
||||
|
||||
def compile(
|
||||
self,
|
||||
inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
|
||||
show_graph: bool = False,
|
||||
show_mlir: bool = False,
|
||||
virtual: bool = False,
|
||||
configuration: Optional[Configuration] = None,
|
||||
artifacts: Optional[DebugArtifacts] = None,
|
||||
**kwargs,
|
||||
) -> Circuit:
|
||||
"""
|
||||
Compile the function into a circuit.
|
||||
@@ -88,21 +81,21 @@ def compiler(
|
||||
inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
|
||||
optional inputset to extend accumulated inputset before bounds measurement
|
||||
|
||||
show_graph (bool, default = False):
|
||||
whether to print the computation graph
|
||||
configuration(Optional[Configuration], default = None):
|
||||
configuration to use
|
||||
|
||||
show_mlir (bool, default = False):
|
||||
whether to print the compiled mlir
|
||||
artifacts (Optional[DebugArtifacts], default = None):
|
||||
artifacts to store information about the process
|
||||
|
||||
virtual (bool, default = False):
|
||||
whether to simulate the computation to allow large bit-widths
|
||||
kwargs (Dict[str, Any]):
|
||||
configuration options to overwrite
|
||||
|
||||
Returns:
|
||||
Circuit:
|
||||
compiled circuit
|
||||
"""
|
||||
|
||||
return self.compiler.compile(inputset, show_graph, show_mlir, virtual)
|
||||
return self.compiler.compile(inputset, configuration, artifacts, **kwargs)
|
||||
|
||||
return Compilable(function)
|
||||
|
||||
|
||||
@@ -121,11 +121,11 @@ import pathlib
|
||||
|
||||
artifacts = cnp.DebugArtifacts("/tmp/custom/export/path")
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, artifacts=artifacts)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def f(x):
|
||||
return 127 - (50 * (np.sin(x) + 1)).astype(np.int64)
|
||||
|
||||
f.compile(range(2 ** 3))
|
||||
f.compile(range(2 ** 3), artifacts=artifacts)
|
||||
|
||||
artifacts.export()
|
||||
```
|
||||
|
||||
@@ -19,12 +19,12 @@ def test_artifacts_export(helpers):
|
||||
configuration = helpers.configuration()
|
||||
artifacts = DebugArtifacts(tmpdir)
|
||||
|
||||
@compiler({"x": "encrypted"}, configuration=configuration, artifacts=artifacts)
|
||||
@compiler({"x": "encrypted"})
|
||||
def f(x):
|
||||
return x + 10
|
||||
|
||||
inputset = range(100)
|
||||
f.compile(inputset)
|
||||
f.compile(inputset, configuration, artifacts)
|
||||
|
||||
artifacts.export()
|
||||
|
||||
|
||||
@@ -18,12 +18,12 @@ def test_circuit_str(helpers):
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted", "y": "encrypted"}, configuration=configuration)
|
||||
@compiler({"x": "encrypted", "y": "encrypted"})
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
inputset = [(np.random.randint(0, 2 ** 4), np.random.randint(0, 2 ** 5)) for _ in range(100)]
|
||||
circuit = f.compile(inputset)
|
||||
circuit = f.compile(inputset, configuration)
|
||||
|
||||
assert str(circuit) == (
|
||||
"""
|
||||
@@ -44,12 +44,12 @@ def test_circuit_draw(helpers):
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted", "y": "encrypted"}, configuration=configuration)
|
||||
@compiler({"x": "encrypted", "y": "encrypted"})
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
inputset = [(np.random.randint(0, 2 ** 4), np.random.randint(0, 2 ** 5)) for _ in range(100)]
|
||||
circuit = f.compile(inputset)
|
||||
circuit = f.compile(inputset, configuration)
|
||||
|
||||
with tempfile.TemporaryDirectory() as path:
|
||||
tmpdir = Path(path)
|
||||
@@ -67,12 +67,12 @@ def test_circuit_bad_run(helpers):
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted", "y": "encrypted"}, configuration=configuration)
|
||||
@compiler({"x": "encrypted", "y": "encrypted"})
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
inputset = [(np.random.randint(0, 2 ** 4), np.random.randint(0, 2 ** 5)) for _ in range(100)]
|
||||
circuit = f.compile(inputset)
|
||||
circuit = f.compile(inputset, configuration)
|
||||
|
||||
# with 1 argument
|
||||
# ---------------
|
||||
@@ -138,12 +138,12 @@ def test_circuit_virtual_explicit_api(helpers):
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted", "y": "encrypted"}, configuration=configuration)
|
||||
@compiler({"x": "encrypted", "y": "encrypted"})
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
inputset = [(np.random.randint(0, 2 ** 4), np.random.randint(0, 2 ** 5)) for _ in range(100)]
|
||||
circuit = f.compile(inputset, virtual=True)
|
||||
circuit = f.compile(inputset, configuration, virtual=True)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
circuit.keygen()
|
||||
|
||||
@@ -7,13 +7,11 @@ import pytest
|
||||
from concrete.numpy.compilation import Compiler
|
||||
|
||||
|
||||
def test_compiler_bad_init(helpers):
|
||||
def test_compiler_bad_init():
|
||||
"""
|
||||
Test `__init__` method of `Compiler` class with bad parameters.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
def f(x, y, z):
|
||||
return x + y + z
|
||||
|
||||
@@ -21,7 +19,7 @@ def test_compiler_bad_init(helpers):
|
||||
# -----------
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
Compiler(f, {}, configuration=configuration)
|
||||
Compiler(f, {})
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"Encryption statuses of parameters 'x', 'y' and 'z' of function 'f' are not provided"
|
||||
@@ -31,7 +29,7 @@ def test_compiler_bad_init(helpers):
|
||||
# ---------------
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
Compiler(f, {"z": "clear"}, configuration=configuration)
|
||||
Compiler(f, {"z": "clear"})
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"Encryption statuses of parameters 'x' and 'y' of function 'f' are not provided"
|
||||
@@ -41,7 +39,7 @@ def test_compiler_bad_init(helpers):
|
||||
# ---------
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
Compiler(f, {"y": "encrypted", "z": "clear"}, configuration=configuration)
|
||||
Compiler(f, {"y": "encrypted", "z": "clear"})
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"Encryption status of parameter 'x' of function 'f' is not provided"
|
||||
@@ -52,29 +50,19 @@ def test_compiler_bad_init(helpers):
|
||||
|
||||
# this is fine and `p` is just ignored
|
||||
|
||||
Compiler(
|
||||
f,
|
||||
{"x": "encrypted", "y": "encrypted", "z": "clear", "p": "clear"},
|
||||
configuration=configuration,
|
||||
)
|
||||
Compiler(f, {"x": "encrypted", "y": "encrypted", "z": "clear", "p": "clear"})
|
||||
|
||||
|
||||
def test_compiler_bad_call(helpers):
|
||||
def test_compiler_bad_call():
|
||||
"""
|
||||
Test `__call__` method of `Compiler` class with bad parameters.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
def f(x, y, z):
|
||||
return x + y + z
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
compiler = Compiler(
|
||||
f,
|
||||
{"x": "encrypted", "y": "encrypted", "z": "clear"},
|
||||
configuration=configuration,
|
||||
)
|
||||
compiler = Compiler(f, {"x": "encrypted", "y": "encrypted", "z": "clear"})
|
||||
compiler(1, 2, 3, invalid=4)
|
||||
|
||||
assert str(excinfo.value) == "Calling function 'f' with kwargs is not supported"
|
||||
@@ -94,9 +82,8 @@ def test_compiler_bad_trace(helpers):
|
||||
compiler = Compiler(
|
||||
f,
|
||||
{"x": "encrypted", "y": "encrypted", "z": "clear"},
|
||||
configuration=configuration,
|
||||
)
|
||||
compiler.trace()
|
||||
compiler.trace(configuration=configuration)
|
||||
|
||||
assert str(excinfo.value) == "Tracing function 'f' without an inputset is not supported"
|
||||
|
||||
@@ -115,17 +102,18 @@ def test_compiler_bad_compile(helpers):
|
||||
compiler = Compiler(
|
||||
f,
|
||||
{"x": "encrypted", "y": "encrypted", "z": "clear"},
|
||||
configuration=configuration,
|
||||
)
|
||||
compiler.compile()
|
||||
compiler.compile(configuration=configuration)
|
||||
|
||||
assert str(excinfo.value) == "Compiling function 'f' without an inputset is not supported"
|
||||
|
||||
configuration.enable_unsafe_features = False
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
compiler = Compiler(lambda x: x, {"x": "encrypted"}, configuration=configuration)
|
||||
compiler.compile(virtual=True)
|
||||
compiler = Compiler(lambda x: x, {"x": "encrypted"})
|
||||
compiler.compile(
|
||||
range(10),
|
||||
configuration.fork(enable_unsafe_features=False, use_insecure_key_cache=False),
|
||||
virtual=True,
|
||||
)
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"Virtual compilation is not allowed without enabling unsafe features"
|
||||
@@ -142,7 +130,7 @@ def test_compiler_virtual_compile(helpers):
|
||||
def f(x):
|
||||
return x + 400
|
||||
|
||||
compiler = Compiler(f, {"x": "encrypted"}, configuration=configuration)
|
||||
circuit = compiler.compile(inputset=range(400), virtual=True)
|
||||
compiler = Compiler(f, {"x": "encrypted"})
|
||||
circuit = compiler.compile(inputset=range(400), configuration=configuration, virtual=True)
|
||||
|
||||
assert circuit.encrypt_run_decrypt(200) == 600
|
||||
|
||||
@@ -12,14 +12,14 @@ def test_call_compile(helpers):
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
for i in range(10):
|
||||
function(i)
|
||||
|
||||
circuit = function.compile()
|
||||
circuit = function.compile(configuration=configuration)
|
||||
|
||||
sample = 5
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
@@ -33,12 +33,12 @@ def test_compiler_verbose_trace(helpers, capsys):
|
||||
configuration = helpers.configuration()
|
||||
artifacts = DebugArtifacts()
|
||||
|
||||
@compiler({"x": "encrypted"}, configuration=configuration, artifacts=artifacts)
|
||||
@compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
inputset = range(10)
|
||||
function.trace(inputset, show_graph=True)
|
||||
function.trace(inputset, configuration, artifacts, show_graph=True)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out.strip() == (
|
||||
@@ -61,12 +61,12 @@ def test_compiler_verbose_compile(helpers, capsys):
|
||||
configuration = helpers.configuration()
|
||||
artifacts = DebugArtifacts()
|
||||
|
||||
@compiler({"x": "encrypted"}, configuration=configuration, artifacts=artifacts)
|
||||
@compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
inputset = range(10)
|
||||
function.compile(inputset, show_graph=True, show_mlir=True)
|
||||
function.compile(inputset, configuration, artifacts, show_graph=True, show_mlir=True)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out.strip() == (
|
||||
|
||||
@@ -59,10 +59,10 @@ def test_constant_add(function, parameters, helpers):
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
@@ -150,10 +150,10 @@ def test_add(function, parameters, helpers):
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
@@ -167,10 +167,10 @@ def test_concatenate(function, parameters, helpers):
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
@@ -57,12 +57,12 @@ def test_conv2d(input_shape, weight_shape, strides, dilations, has_bias, helpers
|
||||
else:
|
||||
bias = None
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return connx.conv(x, weight, bias, strides=strides, dilations=dilations)
|
||||
|
||||
inputset = [np.random.randint(0, 4, size=input_shape) for i in range(100)]
|
||||
circuit = function.compile(inputset)
|
||||
circuit = function.compile(inputset, configuration)
|
||||
|
||||
sample = np.random.randint(0, 4, size=input_shape, dtype=np.uint8)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
@@ -373,7 +373,7 @@ def test_bad_conv_compilation(
|
||||
else:
|
||||
bias = None
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return connx.conv(
|
||||
x,
|
||||
@@ -389,7 +389,7 @@ def test_bad_conv_compilation(
|
||||
|
||||
inputset = [np.random.randint(0, 4, size=input_shape) for i in range(100)]
|
||||
with pytest.raises(expected_error) as excinfo:
|
||||
function.compile(inputset)
|
||||
function.compile(inputset, configuration)
|
||||
|
||||
assert str(excinfo.value) == expected_message
|
||||
|
||||
|
||||
@@ -166,10 +166,10 @@ def test_direct_table_lookup(bits, function, helpers):
|
||||
# scalar
|
||||
# ------
|
||||
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"}, configuration)
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"})
|
||||
|
||||
inputset = range(2 ** bits)
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = int(np.random.randint(0, 2 ** bits))
|
||||
helpers.check_execution(circuit, function, sample, retries=10)
|
||||
@@ -177,10 +177,10 @@ def test_direct_table_lookup(bits, function, helpers):
|
||||
# tensor
|
||||
# ------
|
||||
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"}, configuration)
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"})
|
||||
|
||||
inputset = [np.random.randint(0, 2 ** bits, size=(3, 2), dtype=np.uint8) for _ in range(100)]
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = np.random.randint(0, 2 ** bits, size=(3, 2), dtype=np.uint8)
|
||||
helpers.check_execution(circuit, function, sample, retries=10)
|
||||
@@ -207,10 +207,10 @@ def test_direct_multi_table_lookup(helpers):
|
||||
def function(x):
|
||||
return table[x]
|
||||
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"}, configuration)
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"})
|
||||
|
||||
inputset = [np.random.randint(0, 2 ** 2, size=(3, 2), dtype=np.uint8) for _ in range(100)]
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = np.random.randint(0, 2 ** 2, size=(3, 2), dtype=np.uint8)
|
||||
helpers.check_execution(circuit, function, sample, retries=10)
|
||||
@@ -277,22 +277,22 @@ def test_bad_direct_table_lookup(helpers):
|
||||
# compilation with float value
|
||||
# ----------------------------
|
||||
|
||||
compiler = cnp.Compiler(random_table_lookup_3b, {"x": "encrypted"}, configuration)
|
||||
compiler = cnp.Compiler(random_table_lookup_3b, {"x": "encrypted"})
|
||||
|
||||
inputset = [1.5]
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
compiler.compile(inputset)
|
||||
compiler.compile(inputset, configuration)
|
||||
|
||||
assert str(excinfo.value) == "LookupTable cannot be looked up with EncryptedScalar<float64>"
|
||||
|
||||
# compilation with invalid shape
|
||||
# ------------------------------
|
||||
|
||||
compiler = cnp.Compiler(lambda x: table[x], {"x": "encrypted"}, configuration)
|
||||
compiler = cnp.Compiler(lambda x: table[x], {"x": "encrypted"})
|
||||
|
||||
inputset = [10, 5, 6, 2]
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
compiler.compile(inputset)
|
||||
compiler.compile(inputset, configuration)
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"LookupTable of shape (3, 2) cannot be looked up with EncryptedScalar<uint4>"
|
||||
|
||||
@@ -22,23 +22,23 @@ def test_dot(size, helpers):
|
||||
bound = int(np.floor(np.sqrt(127 / size)))
|
||||
cst = np.random.randint(0, bound, size=(size,))
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def left_function(x):
|
||||
return np.dot(x, cst)
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def right_function(x):
|
||||
return np.dot(cst, x)
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def method(x):
|
||||
return x.dot(cst)
|
||||
|
||||
inputset = [np.random.randint(0, bound, size=(size,)) for i in range(100)]
|
||||
|
||||
left_function_circuit = left_function.compile(inputset)
|
||||
right_function_circuit = right_function.compile(inputset)
|
||||
method_circuit = method.compile(inputset)
|
||||
left_function_circuit = left_function.compile(inputset, configuration)
|
||||
right_function_circuit = right_function.compile(inputset, configuration)
|
||||
method_circuit = method.compile(inputset, configuration)
|
||||
|
||||
sample = np.random.randint(0, bound, size=(size,), dtype=np.uint8)
|
||||
|
||||
|
||||
@@ -125,29 +125,29 @@ def test_matmul(lhs_shape, rhs_shape, bounds, helpers):
|
||||
lhs_cst = list(np.random.randint(minimum, maximum, size=lhs_shape))
|
||||
rhs_cst = list(np.random.randint(minimum, maximum, size=rhs_shape))
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def lhs_operator(x):
|
||||
return x @ rhs_cst
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def rhs_operator(x):
|
||||
return lhs_cst @ x
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def lhs_function(x):
|
||||
return np.matmul(x, rhs_cst)
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def rhs_function(x):
|
||||
return np.matmul(lhs_cst, x)
|
||||
|
||||
lhs_inputset = [np.random.randint(minimum, maximum, size=lhs_shape) for i in range(100)]
|
||||
rhs_inputset = [np.random.randint(minimum, maximum, size=rhs_shape) for i in range(100)]
|
||||
|
||||
lhs_operator_circuit = lhs_operator.compile(lhs_inputset)
|
||||
rhs_operator_circuit = rhs_operator.compile(rhs_inputset)
|
||||
lhs_function_circuit = lhs_function.compile(lhs_inputset)
|
||||
rhs_function_circuit = rhs_function.compile(rhs_inputset)
|
||||
lhs_operator_circuit = lhs_operator.compile(lhs_inputset, configuration)
|
||||
rhs_operator_circuit = rhs_operator.compile(rhs_inputset, configuration)
|
||||
lhs_function_circuit = lhs_function.compile(lhs_inputset, configuration)
|
||||
rhs_function_circuit = rhs_function.compile(rhs_inputset, configuration)
|
||||
|
||||
lhs_sample = np.random.randint(minimum, maximum, size=lhs_shape, dtype=np.uint8)
|
||||
rhs_sample = np.random.randint(minimum, maximum, size=rhs_shape, dtype=np.uint8)
|
||||
|
||||
@@ -67,10 +67,10 @@ def test_constant_mul(function, parameters, helpers):
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
@@ -27,18 +27,18 @@ def test_neg(parameters, helpers):
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@cnp.compiler(parameter_encryption_statuses, configuration=configuration)
|
||||
@cnp.compiler(parameter_encryption_statuses)
|
||||
def operator(x):
|
||||
return -x
|
||||
|
||||
@cnp.compiler(parameter_encryption_statuses, configuration=configuration)
|
||||
@cnp.compiler(parameter_encryption_statuses)
|
||||
def function(x):
|
||||
return np.negative(x)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
|
||||
operator_circuit = operator.compile(inputset)
|
||||
function_circuit = function.compile(inputset)
|
||||
operator_circuit = operator.compile(inputset, configuration)
|
||||
function_circuit = function.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
|
||||
|
||||
@@ -447,10 +447,10 @@ def test_others(function, parameters, helpers):
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample, retries=10)
|
||||
@@ -464,10 +464,10 @@ def test_others(function, parameters, helpers):
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample, retries=10)
|
||||
@@ -483,13 +483,13 @@ def test_others_bad_fusing(helpers):
|
||||
# two variable inputs
|
||||
# -------------------
|
||||
|
||||
@cnp.compiler({"x": "encrypted", "y": "clear"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted", "y": "clear"})
|
||||
def function1(x, y):
|
||||
return (10 * (np.sin(x) ** 2) + 10 * (np.cos(y) ** 2)).astype(np.int64)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
inputset = [(i, i) for i in range(100)]
|
||||
function1.compile(inputset)
|
||||
function1.compile(inputset, configuration)
|
||||
|
||||
helpers.check_str(
|
||||
# pylint: disable=line-too-long
|
||||
@@ -528,13 +528,13 @@ return %13
|
||||
# big intermediate constants
|
||||
# --------------------------
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function2(x):
|
||||
return (np.sin(x) * [[1, 2], [3, 4]]).astype(np.int64)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
inputset = range(100)
|
||||
function2.compile(inputset)
|
||||
function2.compile(inputset, configuration)
|
||||
|
||||
helpers.check_str(
|
||||
# pylint: disable=line-too-long
|
||||
@@ -559,13 +559,13 @@ return %4
|
||||
# intermediates with different shape
|
||||
# ----------------------------------
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function3(x):
|
||||
return np.abs(np.sin(x)).reshape((2, 3)).astype(np.int64)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
inputset = [np.random.randint(0, 2 ** 7, size=(3, 2)) for _ in range(100)]
|
||||
function3.compile(inputset)
|
||||
function3.compile(inputset, configuration)
|
||||
|
||||
helpers.check_str(
|
||||
# pylint: disable=line-too-long
|
||||
|
||||
@@ -116,18 +116,18 @@ def test_reshape(shape, newshape, helpers):
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return np.reshape(x, newshape)
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def method(x):
|
||||
return x.reshape(newshape)
|
||||
|
||||
inputset = [np.random.randint(0, 2 ** 5, size=shape) for i in range(100)]
|
||||
|
||||
function_circuit = function.compile(inputset)
|
||||
method_circuit = method.compile(inputset)
|
||||
function_circuit = function.compile(inputset, configuration)
|
||||
method_circuit = method.compile(inputset, configuration)
|
||||
|
||||
sample = np.random.randint(0, 2 ** 5, size=shape, dtype=np.uint8)
|
||||
|
||||
@@ -159,12 +159,12 @@ def test_flatten(shape, helpers):
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@cnp.compiler({"x": "encrypted"}, configuration=configuration)
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x.flatten()
|
||||
|
||||
inputset = [np.random.randint(0, 2 ** 5, size=shape) for i in range(100)]
|
||||
circuit = function.compile(inputset)
|
||||
circuit = function.compile(inputset, configuration)
|
||||
|
||||
sample = np.random.randint(0, 2 ** 5, size=shape, dtype=np.uint8)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
@@ -154,10 +154,10 @@ def test_static_indexing(shape, function, helpers):
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"}, configuration)
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"})
|
||||
|
||||
inputset = [np.random.randint(0, 2 ** 5, size=shape) for _ in range(100)]
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = np.random.randint(0, 2 ** 5, size=shape, dtype=np.uint8)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
@@ -173,21 +173,21 @@ def test_bad_static_indexing(helpers):
|
||||
# with float
|
||||
# ----------
|
||||
|
||||
compiler = cnp.Compiler(lambda x: x[1.5], {"x": "encrypted"}, configuration)
|
||||
compiler = cnp.Compiler(lambda x: x[1.5], {"x": "encrypted"})
|
||||
|
||||
inputset = [np.random.randint(0, 2 ** 3, size=(3,)) for _ in range(100)]
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
compiler.compile(inputset)
|
||||
compiler.compile(inputset, configuration)
|
||||
|
||||
assert str(excinfo.value) == "Indexing with '1.5' is not supported"
|
||||
|
||||
# with bad slice
|
||||
# --------------
|
||||
|
||||
compiler = cnp.Compiler(lambda x: x[slice(1.5, 2.5, None)], {"x": "encrypted"}, configuration)
|
||||
compiler = cnp.Compiler(lambda x: x[slice(1.5, 2.5, None)], {"x": "encrypted"})
|
||||
|
||||
inputset = [np.random.randint(0, 2 ** 3, size=(3,)) for _ in range(100)]
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
compiler.compile(inputset)
|
||||
compiler.compile(inputset, configuration)
|
||||
|
||||
assert str(excinfo.value) == "Indexing with '1.5:2.5' is not supported"
|
||||
|
||||
@@ -47,10 +47,10 @@ def test_constant_sub(function, parameters, helpers):
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
@@ -105,10 +105,10 @@ def test_sum(function, parameters, helpers):
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
@@ -39,10 +39,10 @@ def test_transpose(function, parameters, helpers):
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration)
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
@@ -389,10 +389,10 @@ def test_graph_converter_bad_convert(
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
compiler = cnp.Compiler(function, encryption_statuses, configuration)
|
||||
compiler = cnp.Compiler(function, encryption_statuses)
|
||||
|
||||
with pytest.raises(expected_error) as excinfo:
|
||||
compiler.compile(inputset)
|
||||
compiler.compile(inputset, configuration)
|
||||
|
||||
helpers.check_str(expected_message, str(excinfo.value))
|
||||
|
||||
|
||||
@@ -39,8 +39,8 @@ def test_graph_maximum_integer_bit_width(function, inputset, expected_result, he
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"}, configuration=configuration)
|
||||
graph = compiler.trace(inputset)
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"})
|
||||
graph = compiler.trace(inputset, configuration)
|
||||
|
||||
print(graph.format())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user