mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat: introduce circuit decorator to directly define circuits
This commit is contained in:
@@ -13,8 +13,141 @@ from .compilation import (
|
||||
DebugArtifacts,
|
||||
EncryptionStatus,
|
||||
Server,
|
||||
compiler,
|
||||
)
|
||||
from .compilation.decorators import circuit, compiler
|
||||
from .extensions import LookupTable, array, one, ones, univariate, zero, zeros
|
||||
from .mlir.utils import MAXIMUM_TLU_BIT_WIDTH
|
||||
from .representation import Graph
|
||||
from .tracing.typing import (
|
||||
f32,
|
||||
f64,
|
||||
int1,
|
||||
int2,
|
||||
int3,
|
||||
int4,
|
||||
int5,
|
||||
int6,
|
||||
int7,
|
||||
int8,
|
||||
int9,
|
||||
int10,
|
||||
int11,
|
||||
int12,
|
||||
int13,
|
||||
int14,
|
||||
int15,
|
||||
int16,
|
||||
int17,
|
||||
int18,
|
||||
int19,
|
||||
int20,
|
||||
int21,
|
||||
int22,
|
||||
int23,
|
||||
int24,
|
||||
int25,
|
||||
int26,
|
||||
int27,
|
||||
int28,
|
||||
int29,
|
||||
int30,
|
||||
int31,
|
||||
int32,
|
||||
int33,
|
||||
int34,
|
||||
int35,
|
||||
int36,
|
||||
int37,
|
||||
int38,
|
||||
int39,
|
||||
int40,
|
||||
int41,
|
||||
int42,
|
||||
int43,
|
||||
int44,
|
||||
int45,
|
||||
int46,
|
||||
int47,
|
||||
int48,
|
||||
int49,
|
||||
int50,
|
||||
int51,
|
||||
int52,
|
||||
int53,
|
||||
int54,
|
||||
int55,
|
||||
int56,
|
||||
int57,
|
||||
int58,
|
||||
int59,
|
||||
int60,
|
||||
int61,
|
||||
int62,
|
||||
int63,
|
||||
int64,
|
||||
tensor,
|
||||
uint1,
|
||||
uint2,
|
||||
uint3,
|
||||
uint4,
|
||||
uint5,
|
||||
uint6,
|
||||
uint7,
|
||||
uint8,
|
||||
uint9,
|
||||
uint10,
|
||||
uint11,
|
||||
uint12,
|
||||
uint13,
|
||||
uint14,
|
||||
uint15,
|
||||
uint16,
|
||||
uint17,
|
||||
uint18,
|
||||
uint19,
|
||||
uint20,
|
||||
uint21,
|
||||
uint22,
|
||||
uint23,
|
||||
uint24,
|
||||
uint25,
|
||||
uint26,
|
||||
uint27,
|
||||
uint28,
|
||||
uint29,
|
||||
uint30,
|
||||
uint31,
|
||||
uint32,
|
||||
uint33,
|
||||
uint34,
|
||||
uint35,
|
||||
uint36,
|
||||
uint37,
|
||||
uint38,
|
||||
uint39,
|
||||
uint40,
|
||||
uint41,
|
||||
uint42,
|
||||
uint43,
|
||||
uint44,
|
||||
uint45,
|
||||
uint46,
|
||||
uint47,
|
||||
uint48,
|
||||
uint49,
|
||||
uint50,
|
||||
uint51,
|
||||
uint52,
|
||||
uint53,
|
||||
uint54,
|
||||
uint55,
|
||||
uint56,
|
||||
uint57,
|
||||
uint58,
|
||||
uint59,
|
||||
uint60,
|
||||
uint61,
|
||||
uint62,
|
||||
uint63,
|
||||
uint64,
|
||||
)
|
||||
|
||||
@@ -7,6 +7,5 @@ from .circuit import Circuit
|
||||
from .client import Client
|
||||
from .compiler import Compiler, EncryptionStatus
|
||||
from .configuration import Configuration
|
||||
from .decorator import compiler
|
||||
from .server import Server
|
||||
from .specs import ClientSpecs
|
||||
|
||||
@@ -45,6 +45,56 @@ class Compiler:
|
||||
inputset: List[Any]
|
||||
graph: Optional[Graph]
|
||||
|
||||
_is_direct: bool
|
||||
_parameter_values: Dict[str, Value]
|
||||
|
||||
@staticmethod
|
||||
def assemble(
|
||||
function: Callable,
|
||||
parameter_values: Dict[str, Value],
|
||||
configuration: Optional[Configuration] = None,
|
||||
artifacts: Optional[DebugArtifacts] = None,
|
||||
**kwargs,
|
||||
) -> Circuit:
|
||||
"""
|
||||
Assemble a circuit from the raw parameter values, used in direct circuit definition.
|
||||
|
||||
Args:
|
||||
function (Callable):
|
||||
function to convert to a circuit
|
||||
|
||||
parameter_values (Dict[str, Value]):
|
||||
parameter values of the function
|
||||
|
||||
configuration(Optional[Configuration], default = None):
|
||||
configuration to use
|
||||
|
||||
artifacts (Optional[DebugArtifacts], default = None):
|
||||
artifacts to store information about the process
|
||||
|
||||
kwargs (Dict[str, Any]):
|
||||
configuration options to overwrite
|
||||
|
||||
Returns:
|
||||
Circuit:
|
||||
assembled circuit
|
||||
"""
|
||||
|
||||
compiler = Compiler(
|
||||
function,
|
||||
{
|
||||
name: "encrypted" if value.is_encrypted else "clear"
|
||||
for name, value in parameter_values.items()
|
||||
},
|
||||
)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
compiler._is_direct = True
|
||||
compiler._parameter_values = parameter_values
|
||||
# pylint: enable=protected-access
|
||||
|
||||
return compiler.compile(None, configuration, artifacts, **kwargs)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
function: Callable,
|
||||
@@ -102,6 +152,9 @@ class Compiler:
|
||||
self.inputset = []
|
||||
self.graph = None
|
||||
|
||||
self._is_direct = False
|
||||
self._parameter_values = {}
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -171,6 +224,18 @@ class Compiler:
|
||||
optional inputset to extend accumulated inputset before bounds measurement
|
||||
"""
|
||||
|
||||
if self._is_direct:
|
||||
|
||||
self.graph = Tracer.trace(self.function, self._parameter_values, is_direct=True)
|
||||
if self.artifacts is not None:
|
||||
self.artifacts.add_graph("initial", self.graph) # pragma: no cover
|
||||
|
||||
fuse(self.graph, self.artifacts)
|
||||
if self.artifacts is not None:
|
||||
self.artifacts.add_graph("final", self.graph) # pragma: no cover
|
||||
|
||||
return
|
||||
|
||||
if inputset is not None:
|
||||
previous_inputset_length = len(self.inputset)
|
||||
for index, sample in enumerate(iter(inputset)):
|
||||
|
||||
@@ -1,22 +1,82 @@
|
||||
"""
|
||||
Declaration of `compiler` decorator.
|
||||
Declaration of `circuit` and `compiler` decorators.
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Iterable, Mapping, Optional, Tuple, Union
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union
|
||||
|
||||
from ..representation import Graph
|
||||
from ..tracing.typing import ScalarAnnotation
|
||||
from ..values import Value
|
||||
from .artifacts import DebugArtifacts
|
||||
from .circuit import Circuit
|
||||
from .compiler import Compiler, EncryptionStatus
|
||||
from .configuration import Configuration
|
||||
|
||||
|
||||
def compiler(parameters: Mapping[str, EncryptionStatus]):
|
||||
def circuit(
|
||||
parameters: Mapping[str, Union[str, EncryptionStatus]],
|
||||
configuration: Optional[Configuration] = None,
|
||||
artifacts: Optional[DebugArtifacts] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Provide a direct interface for compilation.
|
||||
|
||||
Args:
|
||||
parameters (Mapping[str, Union[str, EncryptionStatus]]):
|
||||
encryption statuses of the parameters of the function to compile
|
||||
|
||||
configuration(Optional[Configuration], default = None):
|
||||
configuration to use
|
||||
|
||||
artifacts (Optional[DebugArtifacts], default = None):
|
||||
artifacts to store information about the process
|
||||
|
||||
kwargs (Dict[str, Any]):
|
||||
configuration options to overwrite
|
||||
"""
|
||||
|
||||
def decoration(function: Callable):
|
||||
signature = inspect.signature(function)
|
||||
|
||||
parameter_values: Dict[str, Value] = {}
|
||||
for name, details in signature.parameters.items():
|
||||
if name not in parameters:
|
||||
continue
|
||||
|
||||
annotation = details.annotation
|
||||
|
||||
is_value = isinstance(annotation, Value)
|
||||
is_scalar_annotation = isinstance(annotation, type) and issubclass(
|
||||
annotation, ScalarAnnotation
|
||||
)
|
||||
|
||||
if not (is_value or is_scalar_annotation):
|
||||
raise ValueError(
|
||||
f"Annotation {annotation} for argument '{name}' is not valid "
|
||||
f"(please use a cnp type such as "
|
||||
f"`cnp.uint4` or 'cnp.tensor[cnp.uint4, 3, 2]')"
|
||||
)
|
||||
|
||||
parameter_values[name] = (
|
||||
annotation if is_value else Value(annotation.dtype, shape=(), is_encrypted=False)
|
||||
)
|
||||
|
||||
status = EncryptionStatus(parameters[name].lower())
|
||||
parameter_values[name].is_encrypted = status == "encrypted"
|
||||
|
||||
return Compiler.assemble(function, parameter_values, configuration, artifacts, **kwargs)
|
||||
|
||||
return decoration
|
||||
|
||||
|
||||
def compiler(parameters: Mapping[str, Union[str, EncryptionStatus]]):
|
||||
"""
|
||||
Provide an easy interface for compilation.
|
||||
|
||||
Args:
|
||||
parameters (Dict[str, EncryptionStatus]):
|
||||
parameters (Mapping[str, Union[str, EncryptionStatus]]):
|
||||
encryption statuses of the parameters of the function to compile
|
||||
"""
|
||||
|
||||
@@ -2,18 +2,19 @@
|
||||
Declaration of `univariate` function.
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Union
|
||||
from typing import Any, Callable, Optional, Type, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..dtypes import Float
|
||||
from ..dtypes import BaseDataType, Float
|
||||
from ..representation import Node
|
||||
from ..tracing import Tracer
|
||||
from ..tracing import ScalarAnnotation, Tracer
|
||||
from ..values import Value
|
||||
|
||||
|
||||
def univariate(
|
||||
function: Callable[[Any], Any],
|
||||
outputs: Optional[Union[BaseDataType, Type[ScalarAnnotation]]] = None,
|
||||
) -> Callable[[Union[Tracer, Any]], Union[Tracer, Any]]:
|
||||
"""
|
||||
Wrap a univariate function so that it is traced into a single generic node.
|
||||
@@ -22,6 +23,9 @@ def univariate(
|
||||
function (Callable[[Any], Any]):
|
||||
univariate function to wrap
|
||||
|
||||
outputs (Optional[Union[BaseDataType, Type[ScalarAnnotation]]], default = None):
|
||||
data type of the result, unused during compilation, required for direct definition
|
||||
|
||||
Returns:
|
||||
Callable[[Union[Tracer, Any]], Union[Tracer, Any]]:
|
||||
another univariate function that can be called with a Tracer as well
|
||||
@@ -57,6 +61,19 @@ def univariate(
|
||||
if output_value.shape != x.output.shape:
|
||||
raise ValueError(f"Function {function.__name__} cannot be used with cnp.univariate")
|
||||
|
||||
# pylint: disable=protected-access
|
||||
is_direct = Tracer._is_direct
|
||||
# pylint: enable=protected-access
|
||||
|
||||
if is_direct:
|
||||
if outputs is None:
|
||||
raise ValueError(
|
||||
"Univariate extension requires "
|
||||
"`outputs` argument for direct circuit definition "
|
||||
"(e.g., cnp.univariate(function, outputs=cnp.uint4)(x))"
|
||||
)
|
||||
output_value.dtype = outputs if isinstance(outputs, BaseDataType) else outputs.dtype
|
||||
|
||||
computation = Node.generic(
|
||||
function.__name__,
|
||||
[x.output],
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
Provide `function` to `computation graph` functionality.
|
||||
"""
|
||||
|
||||
from .tracer import Tracer
|
||||
from .tracer import ScalarAnnotation, TensorAnnotation, Tracer
|
||||
|
||||
@@ -4,13 +4,13 @@ Declaration of `Tracer` class.
|
||||
|
||||
import inspect
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from numpy.typing import DTypeLike
|
||||
|
||||
from ..dtypes import Float, Integer
|
||||
from ..dtypes import BaseDataType, Float, Integer
|
||||
from ..internal.utils import assert_that
|
||||
from ..representation import Graph, Node, Operation
|
||||
from ..representation.utils import format_indexing_element
|
||||
@@ -29,13 +29,12 @@ class Tracer:
|
||||
# property to keep track of assignments
|
||||
last_version: Optional["Tracer"] = None
|
||||
|
||||
# variable to control the behavior of __eq__
|
||||
# so that it can be traced but still allow
|
||||
# using Tracers in dicts when not tracing
|
||||
# variables to control the behavior of certain functions
|
||||
_is_tracing: bool = False
|
||||
_is_direct: bool = False
|
||||
|
||||
@staticmethod
|
||||
def trace(function: Callable, parameters: Dict[str, Value]) -> Graph:
|
||||
def trace(function: Callable, parameters: Dict[str, Value], is_direct: bool = False) -> Graph:
|
||||
"""
|
||||
Trace `function` and create the `Graph` that represents it.
|
||||
|
||||
@@ -47,6 +46,9 @@ class Tracer:
|
||||
parameters of function to trace
|
||||
e.g. parameter x is an EncryptedScalar holding a 7-bit UnsignedInteger
|
||||
|
||||
is_direct (bool, default = False):
|
||||
whether the tracing is done on actual parameters or placeholders
|
||||
|
||||
Returns:
|
||||
Graph:
|
||||
computation graph corresponding to `function`
|
||||
@@ -67,6 +69,8 @@ class Tracer:
|
||||
arguments[param] = Tracer(node, [])
|
||||
input_indices[node] = index
|
||||
|
||||
Tracer._is_direct = is_direct
|
||||
|
||||
Tracer._is_tracing = True
|
||||
output_tracers: Any = function(**arguments)
|
||||
Tracer._is_tracing = False
|
||||
@@ -383,6 +387,13 @@ class Tracer:
|
||||
output_value = Value.of(evaluation)
|
||||
output_value.is_encrypted = any(tracer.output.is_encrypted for tracer in tracers)
|
||||
|
||||
if Tracer._is_direct and isinstance(output_value.dtype, Integer):
|
||||
resulting_bit_width = 0
|
||||
for tracer in tracers:
|
||||
assert isinstance(tracer.output.dtype, Integer)
|
||||
resulting_bit_width = max(resulting_bit_width, tracer.output.dtype.bit_width)
|
||||
output_value.dtype.bit_width = resulting_bit_width
|
||||
|
||||
computation = Node.generic(
|
||||
operation.__name__,
|
||||
[tracer.output for tracer in tracers],
|
||||
@@ -488,7 +499,12 @@ class Tracer:
|
||||
|
||||
def __round__(self, ndigits=None):
|
||||
if ndigits is None:
|
||||
return Tracer._trace_numpy_operation(np.around, self).astype(np.int64)
|
||||
result = Tracer._trace_numpy_operation(np.around, self)
|
||||
if self._is_direct:
|
||||
raise RuntimeError(
|
||||
"'round(x)' cannot be used in direct definition (you may use np.around instead)"
|
||||
)
|
||||
return result.astype(np.int64)
|
||||
|
||||
return Tracer._trace_numpy_operation(np.around, self, decimals=ndigits)
|
||||
|
||||
@@ -551,13 +567,38 @@ class Tracer:
|
||||
else Tracer._trace_numpy_operation(np.not_equal, self, self.sanitize(other))
|
||||
)
|
||||
|
||||
def astype(self, dtype: DTypeLike) -> "Tracer":
|
||||
def astype(self, dtype: Union[DTypeLike, Type["ScalarAnnotation"]]) -> "Tracer":
|
||||
"""
|
||||
Trace numpy.ndarray.astype(dtype).
|
||||
"""
|
||||
|
||||
normalized_dtype = np.dtype(dtype)
|
||||
if np.issubdtype(normalized_dtype, np.integer) and normalized_dtype != np.int64:
|
||||
if Tracer._is_direct:
|
||||
output_value = deepcopy(self.output)
|
||||
|
||||
if isinstance(dtype, type) and issubclass(dtype, ScalarAnnotation):
|
||||
output_value.dtype = dtype.dtype
|
||||
else:
|
||||
raise ValueError(
|
||||
"`astype` method must be called with a concrete.numpy type "
|
||||
"for direct circuit definition (e.g., value.astype(cnp.uint4))"
|
||||
)
|
||||
|
||||
computation = Node.generic(
|
||||
"astype",
|
||||
[self.output],
|
||||
output_value,
|
||||
lambda x: x, # unused for direct definition
|
||||
)
|
||||
return Tracer(computation, [self])
|
||||
|
||||
if isinstance(dtype, type) and issubclass(dtype, ScalarAnnotation):
|
||||
raise ValueError(
|
||||
"`astype` method must be called with a "
|
||||
"numpy type for compilation (e.g., value.astype(np.int64))"
|
||||
)
|
||||
|
||||
dtype = np.dtype(dtype).type
|
||||
if np.issubdtype(dtype, np.integer) and dtype != np.int64:
|
||||
print(
|
||||
"Warning: When using `value.astype(newtype)` "
|
||||
"with an integer newtype, "
|
||||
@@ -567,9 +608,9 @@ class Tracer:
|
||||
)
|
||||
|
||||
output_value = deepcopy(self.output)
|
||||
output_value.dtype = Value.of(normalized_dtype.type(0)).dtype
|
||||
output_value.dtype = Value.of(dtype(0)).dtype # type: ignore
|
||||
|
||||
if np.issubdtype(normalized_dtype.type, np.integer):
|
||||
if np.issubdtype(dtype, np.integer):
|
||||
|
||||
def evaluator(x, dtype):
|
||||
if np.any(np.isnan(x)):
|
||||
@@ -588,7 +629,7 @@ class Tracer:
|
||||
[self.output],
|
||||
output_value,
|
||||
evaluator,
|
||||
kwargs={"dtype": normalized_dtype.type},
|
||||
kwargs={"dtype": dtype},
|
||||
)
|
||||
return Tracer(computation, [self])
|
||||
|
||||
@@ -765,3 +806,23 @@ class Tracer:
|
||||
"""
|
||||
|
||||
return Tracer._trace_numpy_operation(np.transpose, self)
|
||||
|
||||
|
||||
class Annotation(Tracer):
|
||||
"""
|
||||
Base annotation for direct definition.
|
||||
"""
|
||||
|
||||
|
||||
class ScalarAnnotation(Annotation):
|
||||
"""
|
||||
Base scalar annotation for direct definition.
|
||||
"""
|
||||
|
||||
dtype: BaseDataType
|
||||
|
||||
|
||||
class TensorAnnotation(Annotation):
|
||||
"""
|
||||
Base tensor annotation for direct definition.
|
||||
"""
|
||||
|
||||
1222
concrete/numpy/tracing/typing.py
Normal file
1222
concrete/numpy/tracing/typing.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -7,7 +7,7 @@ from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from concrete.numpy.compilation import DebugArtifacts, compiler
|
||||
from concrete.numpy import DebugArtifacts, compiler
|
||||
|
||||
|
||||
def test_artifacts_export(helpers):
|
||||
|
||||
@@ -8,8 +8,7 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from concrete.numpy import Client, ClientSpecs, EvaluationKeys, Server
|
||||
from concrete.numpy.compilation import compiler
|
||||
from concrete.numpy import Client, ClientSpecs, EvaluationKeys, Server, compiler
|
||||
|
||||
|
||||
def test_circuit_str(helpers):
|
||||
|
||||
@@ -1,127 +0,0 @@
|
||||
"""
|
||||
Tests of `compiler` decorator.
|
||||
"""
|
||||
|
||||
from concrete.numpy.compilation import DebugArtifacts, compiler
|
||||
|
||||
|
||||
def test_call_compile(helpers):
|
||||
"""
|
||||
Test `__call__` and `compile` methods of `compiler` decorator back to back.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
for i in range(10):
|
||||
function(i)
|
||||
|
||||
circuit = function.compile(configuration=configuration)
|
||||
|
||||
sample = 5
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
|
||||
def test_compiler_verbose_trace(helpers, capsys):
|
||||
"""
|
||||
Test `trace` method of `compiler` decorator with verbose flag.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
artifacts = DebugArtifacts()
|
||||
|
||||
@compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
inputset = range(10)
|
||||
function.trace(inputset, configuration, artifacts, show_graph=True)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out.strip() == (
|
||||
f"""
|
||||
|
||||
Computation Graph
|
||||
------------------------------------------------
|
||||
{str(list(artifacts.textual_representations_of_graphs.values())[-1][-1])}
|
||||
------------------------------------------------
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def test_compiler_verbose_compile(helpers, capsys):
|
||||
"""
|
||||
Test `compile` method of `compiler` decorator with verbose flag.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
artifacts = DebugArtifacts()
|
||||
|
||||
@compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
inputset = range(10)
|
||||
function.compile(inputset, configuration, artifacts, verbose=True)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out.strip().startswith(
|
||||
f"""
|
||||
|
||||
Computation Graph
|
||||
--------------------------------------------------------------------------------
|
||||
{list(artifacts.textual_representations_of_graphs.values())[-1][-1]}
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
MLIR
|
||||
--------------------------------------------------------------------------------
|
||||
{artifacts.mlir_to_compile}
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
Optimizer
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def test_compiler_verbose_virtual_compile(helpers, capsys):
|
||||
"""
|
||||
Test `compile` method of `compiler` decorator with verbose flag.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
artifacts = DebugArtifacts()
|
||||
|
||||
@compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
inputset = range(10)
|
||||
function.compile(inputset, configuration, artifacts, verbose=True, virtual=True)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out.strip() == (
|
||||
f"""
|
||||
|
||||
Computation Graph
|
||||
------------------------------------------------
|
||||
{list(artifacts.textual_representations_of_graphs.values())[-1][-1]}
|
||||
------------------------------------------------
|
||||
|
||||
MLIR
|
||||
------------------------------------------------
|
||||
Virtual circuits don't have MLIR.
|
||||
------------------------------------------------
|
||||
|
||||
Optimizer
|
||||
------------------------------------------------
|
||||
Virtual circuits don't have optimizer output.
|
||||
------------------------------------------------
|
||||
|
||||
""".strip()
|
||||
)
|
||||
280
tests/compilation/test_decorators.py
Normal file
280
tests/compilation/test_decorators.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""
|
||||
Tests of `compiler` and `circuit` decorators.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import concrete.numpy as cnp
|
||||
|
||||
|
||||
def test_compiler_call_and_compile(helpers):
|
||||
"""
|
||||
Test `__call__` and `compile` methods of `compiler` decorator back to back.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
for i in range(10):
|
||||
function(i)
|
||||
|
||||
circuit = function.compile(configuration=configuration)
|
||||
|
||||
sample = 5
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
|
||||
def test_compiler_verbose_trace(helpers, capsys):
|
||||
"""
|
||||
Test `trace` method of `compiler` decorator with verbose flag.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
artifacts = cnp.DebugArtifacts()
|
||||
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
inputset = range(10)
|
||||
function.trace(inputset, configuration, artifacts, show_graph=True)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out.strip() == (
|
||||
f"""
|
||||
|
||||
Computation Graph
|
||||
------------------------------------------------
|
||||
{str(list(artifacts.textual_representations_of_graphs.values())[-1][-1])}
|
||||
------------------------------------------------
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def test_compiler_verbose_compile(helpers, capsys):
|
||||
"""
|
||||
Test `compile` method of `compiler` decorator with verbose flag.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
artifacts = cnp.DebugArtifacts()
|
||||
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
inputset = range(10)
|
||||
function.compile(inputset, configuration, artifacts, verbose=True)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out.strip().startswith(
|
||||
f"""
|
||||
|
||||
Computation Graph
|
||||
--------------------------------------------------------------------------------
|
||||
{list(artifacts.textual_representations_of_graphs.values())[-1][-1]}
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
MLIR
|
||||
--------------------------------------------------------------------------------
|
||||
{artifacts.mlir_to_compile}
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
Optimizer
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def test_compiler_verbose_virtual_compile(helpers, capsys):
|
||||
"""
|
||||
Test `compile` method of `compiler` decorator with verbose flag.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
artifacts = cnp.DebugArtifacts()
|
||||
|
||||
@cnp.compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
inputset = range(10)
|
||||
function.compile(inputset, configuration, artifacts, verbose=True, virtual=True)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out.strip() == (
|
||||
f"""
|
||||
|
||||
Computation Graph
|
||||
------------------------------------------------
|
||||
{list(artifacts.textual_representations_of_graphs.values())[-1][-1]}
|
||||
------------------------------------------------
|
||||
|
||||
MLIR
|
||||
------------------------------------------------
|
||||
Virtual circuits don't have MLIR.
|
||||
------------------------------------------------
|
||||
|
||||
Optimizer
|
||||
------------------------------------------------
|
||||
Virtual circuits don't have optimizer output.
|
||||
------------------------------------------------
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def test_circuit(helpers):
|
||||
"""
|
||||
Test circuit decorator.
|
||||
"""
|
||||
|
||||
@cnp.circuit({"x": "encrypted"}, helpers.configuration())
|
||||
def circuit1(x: cnp.uint2):
|
||||
return x + 42
|
||||
|
||||
helpers.check_str(
|
||||
str(circuit1),
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedScalar<uint2>
|
||||
%1 = 42 # ClearScalar<uint6>
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint6>
|
||||
return %2
|
||||
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
# ======================================================================
|
||||
|
||||
@cnp.circuit({"x": "encrypted"}, helpers.configuration())
|
||||
def circuit2(x: cnp.tensor[cnp.uint2, 3, 2]):
|
||||
return x + 42
|
||||
|
||||
helpers.check_str(
|
||||
str(circuit2),
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedTensor<uint2, shape=(3, 2)>
|
||||
%1 = 42 # ClearScalar<uint6>
|
||||
%2 = add(%0, %1) # EncryptedTensor<uint6, shape=(3, 2)>
|
||||
return %2
|
||||
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
# ======================================================================
|
||||
|
||||
@cnp.circuit({"x": "encrypted"}, helpers.configuration())
|
||||
def circuit3(x: cnp.uint3):
|
||||
def square(x):
|
||||
return x**2
|
||||
|
||||
return cnp.univariate(square, outputs=cnp.uint7)(x)
|
||||
|
||||
helpers.check_str(
|
||||
str(circuit3),
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedScalar<uint3>
|
||||
%1 = square(%0) # EncryptedScalar<uint7>
|
||||
return %1
|
||||
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
# ======================================================================
|
||||
|
||||
@cnp.circuit({"x": "encrypted"}, helpers.configuration())
|
||||
def circuit4(x: cnp.uint3):
|
||||
return ((np.sin(x) ** 2) + (np.cos(x) ** 2)).astype(cnp.uint3)
|
||||
|
||||
helpers.check_str(
|
||||
str(circuit4),
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedScalar<uint3>
|
||||
%1 = subgraph(%0) # EncryptedScalar<uint3>
|
||||
return %1
|
||||
|
||||
Subgraphs:
|
||||
|
||||
%1 = subgraph(%0):
|
||||
|
||||
%0 = 2 # ClearScalar<uint2>
|
||||
%1 = 2 # ClearScalar<uint2>
|
||||
%2 = input # EncryptedScalar<uint3>
|
||||
%3 = sin(%2) # EncryptedScalar<float64>
|
||||
%4 = cos(%2) # EncryptedScalar<float64>
|
||||
%5 = power(%3, %0) # EncryptedScalar<float64>
|
||||
%6 = power(%4, %1) # EncryptedScalar<float64>
|
||||
%7 = add(%5, %6) # EncryptedScalar<float64>
|
||||
%8 = astype(%7) # EncryptedScalar<uint3>
|
||||
return %8
|
||||
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
|
||||
def test_bad_circuit(helpers):
|
||||
"""
|
||||
Test circuit decorator with bad parameters.
|
||||
"""
|
||||
|
||||
# bad annotation
|
||||
# --------------
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
|
||||
@cnp.circuit({"x": "encrypted"}, helpers.configuration())
|
||||
def circuit1(x: int):
|
||||
return x + 42
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
f"Annotation {str(int)} for argument 'x' is not valid "
|
||||
f"(please use a cnp type such as `cnp.uint4` or 'cnp.tensor[cnp.uint4, 3, 2]')"
|
||||
)
|
||||
|
||||
# missing encryption status
|
||||
# -------------------------
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
|
||||
@cnp.circuit({}, helpers.configuration())
|
||||
def circuit2(x: cnp.uint3):
|
||||
return x + 42
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"Encryption status of parameter 'x' of function 'circuit2' is not provided"
|
||||
)
|
||||
|
||||
# bad astype
|
||||
# ----------
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
|
||||
@cnp.circuit({"x": "encrypted"}, helpers.configuration())
|
||||
def circuit3(x: cnp.uint3):
|
||||
return x.astype(np.int64)
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"`astype` method must be called with a concrete.numpy type "
|
||||
"for direct circuit definition (e.g., value.astype(cnp.uint4))"
|
||||
)
|
||||
|
||||
# round
|
||||
# -----
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
|
||||
@cnp.circuit({"x": "encrypted"}, helpers.configuration())
|
||||
def circuit4(x: cnp.uint3):
|
||||
return round(x)
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"'round(x)' cannot be used in direct definition (you may use np.around instead)"
|
||||
)
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Tests of LookupTable.
|
||||
Tests of 'array' extension.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Tests of LookupTable.
|
||||
Tests of 'LookupTable' extension.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
24
tests/extensions/test_univariate.py
Normal file
24
tests/extensions/test_univariate.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
Tests of 'univariate' extension.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
import concrete.numpy as cnp
|
||||
|
||||
|
||||
def test_bad_univariate(helpers):
|
||||
"""
|
||||
Test 'univariate' extension with bad parameters.
|
||||
"""
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
|
||||
@cnp.circuit({"x": "encrypted"}, helpers.configuration())
|
||||
def function(x: cnp.uint3):
|
||||
return cnp.univariate(lambda x: x**2)(x)
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"Univariate extension requires `outputs` argument for direct circuit definition "
|
||||
"(e.g., cnp.univariate(function, outputs=cnp.uint4)(x))"
|
||||
)
|
||||
@@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
from concrete.numpy.dtypes import UnsignedInteger
|
||||
from concrete.numpy.tracing import Tracer
|
||||
from concrete.numpy.tracing.typing import uint4
|
||||
from concrete.numpy.values import EncryptedTensor
|
||||
|
||||
|
||||
@@ -43,6 +44,13 @@ from concrete.numpy.values import EncryptedTensor
|
||||
RuntimeError,
|
||||
"Only __call__ hook is supported for numpy ufuncs",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.astype(uint4),
|
||||
{"x": EncryptedTensor(UnsignedInteger(7), shape=(4,))},
|
||||
ValueError,
|
||||
"`astype` method must be called with a "
|
||||
"numpy type for compilation (e.g., value.astype(np.int64))",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_tracer_bad_trace(function, parameters, expected_error, expected_message):
|
||||
|
||||
54
tests/tracing/test_typing.py
Normal file
54
tests/tracing/test_typing.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""
|
||||
Test type annotations.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
import concrete.numpy as cnp
|
||||
|
||||
|
||||
def test_bad_tensor():
|
||||
"""
|
||||
Test `tensor` type with bad parameters
|
||||
"""
|
||||
|
||||
# invalid dtype
|
||||
# -------------
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
|
||||
def case1(x: cnp.tensor[int]):
|
||||
return x
|
||||
|
||||
case1(None)
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"First argument to tensor annotations should be a "
|
||||
"concrete-numpy data type (e.g., cnp.uint4) not int"
|
||||
)
|
||||
|
||||
# no shape
|
||||
# --------
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
|
||||
def case2(x: cnp.tensor[cnp.uint3]):
|
||||
return x
|
||||
|
||||
case2(None)
|
||||
|
||||
assert str(excinfo.value) == (
|
||||
"Tensor annotations should have a shape (e.g., cnp.tensor[cnp.uint4, 3, 2])"
|
||||
)
|
||||
|
||||
# bad shape
|
||||
# ---------
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
|
||||
def case3(x: cnp.tensor[cnp.uint3, 1.5]):
|
||||
return x
|
||||
|
||||
case3(None)
|
||||
|
||||
assert str(excinfo.value) == "Tensor annotation shape elements must be 'int'"
|
||||
Reference in New Issue
Block a user