From 66c707cd697c3b1fe8a879466514905f66cb0de9 Mon Sep 17 00:00:00 2001 From: Umut Date: Fri, 7 Oct 2022 10:55:03 +0200 Subject: [PATCH] feat: introduce circuit decorator to directly define circuits --- concrete/numpy/__init__.py | 135 +- concrete/numpy/compilation/__init__.py | 1 - concrete/numpy/compilation/compiler.py | 65 + .../{decorator.py => decorators.py} | 68 +- concrete/numpy/extensions/univariate.py | 23 +- concrete/numpy/tracing/__init__.py | 2 +- concrete/numpy/tracing/tracer.py | 87 +- concrete/numpy/tracing/typing.py | 1222 +++++++++++++++++ tests/compilation/test_artifacts.py | 2 +- tests/compilation/test_circuit.py | 3 +- tests/compilation/test_decorator.py | 127 -- tests/compilation/test_decorators.py | 280 ++++ tests/extensions/test_array.py | 2 +- tests/extensions/test_table.py | 2 +- tests/extensions/test_univariate.py | 24 + tests/tracing/test_tracer.py | 8 + tests/tracing/test_typing.py | 54 + 17 files changed, 1950 insertions(+), 155 deletions(-) rename concrete/numpy/compilation/{decorator.py => decorators.py} (59%) create mode 100644 concrete/numpy/tracing/typing.py delete mode 100644 tests/compilation/test_decorator.py create mode 100644 tests/compilation/test_decorators.py create mode 100644 tests/extensions/test_univariate.py create mode 100644 tests/tracing/test_typing.py diff --git a/concrete/numpy/__init__.py b/concrete/numpy/__init__.py index db447abeb..2da8eb445 100644 --- a/concrete/numpy/__init__.py +++ b/concrete/numpy/__init__.py @@ -13,8 +13,141 @@ from .compilation import ( DebugArtifacts, EncryptionStatus, Server, - compiler, ) +from .compilation.decorators import circuit, compiler from .extensions import LookupTable, array, one, ones, univariate, zero, zeros from .mlir.utils import MAXIMUM_TLU_BIT_WIDTH from .representation import Graph +from .tracing.typing import ( + f32, + f64, + int1, + int2, + int3, + int4, + int5, + int6, + int7, + int8, + int9, + int10, + int11, + int12, + int13, + int14, + int15, + int16, + int17, + int18, + int19, + int20, + int21, + int22, + int23, + int24, + int25, + int26, + int27, + int28, + int29, + int30, + int31, + int32, + int33, + int34, + int35, + int36, + int37, + int38, + int39, + int40, + int41, + int42, + int43, + int44, + int45, + int46, + int47, + int48, + int49, + int50, + int51, + int52, + int53, + int54, + int55, + int56, + int57, + int58, + int59, + int60, + int61, + int62, + int63, + int64, + tensor, + uint1, + uint2, + uint3, + uint4, + uint5, + uint6, + uint7, + uint8, + uint9, + uint10, + uint11, + uint12, + uint13, + uint14, + uint15, + uint16, + uint17, + uint18, + uint19, + uint20, + uint21, + uint22, + uint23, + uint24, + uint25, + uint26, + uint27, + uint28, + uint29, + uint30, + uint31, + uint32, + uint33, + uint34, + uint35, + uint36, + uint37, + uint38, + uint39, + uint40, + uint41, + uint42, + uint43, + uint44, + uint45, + uint46, + uint47, + uint48, + uint49, + uint50, + uint51, + uint52, + uint53, + uint54, + uint55, + uint56, + uint57, + uint58, + uint59, + uint60, + uint61, + uint62, + uint63, + uint64, +) diff --git a/concrete/numpy/compilation/__init__.py b/concrete/numpy/compilation/__init__.py index 147137419..05b4f1770 100644 --- a/concrete/numpy/compilation/__init__.py +++ b/concrete/numpy/compilation/__init__.py @@ -7,6 +7,5 @@ from .circuit import Circuit from .client import Client from .compiler import Compiler, EncryptionStatus from .configuration import Configuration -from .decorator import compiler from .server import Server from .specs import ClientSpecs diff --git a/concrete/numpy/compilation/compiler.py b/concrete/numpy/compilation/compiler.py index 4a140c696..e73132c77 100644 --- a/concrete/numpy/compilation/compiler.py +++ b/concrete/numpy/compilation/compiler.py @@ -45,6 +45,56 @@ class Compiler: inputset: List[Any] graph: Optional[Graph] + _is_direct: bool + _parameter_values: Dict[str, Value] + + @staticmethod + def assemble( + function: Callable, + parameter_values: Dict[str, Value], + configuration: Optional[Configuration] = None, + artifacts: Optional[DebugArtifacts] = None, + **kwargs, + ) -> Circuit: + """ + Assemble a circuit from the raw parameter values, used in direct circuit definition. + + Args: + function (Callable): + function to convert to a circuit + + parameter_values (Dict[str, Value]): + parameter values of the function + + configuration(Optional[Configuration], default = None): + configuration to use + + artifacts (Optional[DebugArtifacts], default = None): + artifacts to store information about the process + + kwargs (Dict[str, Any]): + configuration options to overwrite + + Returns: + Circuit: + assembled circuit + """ + + compiler = Compiler( + function, + { + name: "encrypted" if value.is_encrypted else "clear" + for name, value in parameter_values.items() + }, + ) + + # pylint: disable=protected-access + compiler._is_direct = True + compiler._parameter_values = parameter_values + # pylint: enable=protected-access + + return compiler.compile(None, configuration, artifacts, **kwargs) + def __init__( self, function: Callable, @@ -102,6 +152,9 @@ class Compiler: self.inputset = [] self.graph = None + self._is_direct = False + self._parameter_values = {} + def __call__( self, *args: Any, @@ -171,6 +224,18 @@ class Compiler: optional inputset to extend accumulated inputset before bounds measurement """ + if self._is_direct: + + self.graph = Tracer.trace(self.function, self._parameter_values, is_direct=True) + if self.artifacts is not None: + self.artifacts.add_graph("initial", self.graph) # pragma: no cover + + fuse(self.graph, self.artifacts) + if self.artifacts is not None: + self.artifacts.add_graph("final", self.graph) # pragma: no cover + + return + if inputset is not None: previous_inputset_length = len(self.inputset) for index, sample in enumerate(iter(inputset)): diff --git a/concrete/numpy/compilation/decorator.py b/concrete/numpy/compilation/decorators.py similarity index 59% rename from concrete/numpy/compilation/decorator.py rename to concrete/numpy/compilation/decorators.py index 8bdbffdc0..0f6a9473e 100644 --- a/concrete/numpy/compilation/decorator.py +++ b/concrete/numpy/compilation/decorators.py @@ -1,22 +1,82 @@ """ -Declaration of `compiler` decorator. +Declaration of `circuit` and `compiler` decorators. """ -from typing import Any, Callable, Iterable, Mapping, Optional, Tuple, Union +import inspect +from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union from ..representation import Graph +from ..tracing.typing import ScalarAnnotation +from ..values import Value from .artifacts import DebugArtifacts from .circuit import Circuit from .compiler import Compiler, EncryptionStatus from .configuration import Configuration -def compiler(parameters: Mapping[str, EncryptionStatus]): +def circuit( + parameters: Mapping[str, Union[str, EncryptionStatus]], + configuration: Optional[Configuration] = None, + artifacts: Optional[DebugArtifacts] = None, + **kwargs, +): + """ + Provide a direct interface for compilation. + + Args: + parameters (Mapping[str, Union[str, EncryptionStatus]]): + encryption statuses of the parameters of the function to compile + + configuration(Optional[Configuration], default = None): + configuration to use + + artifacts (Optional[DebugArtifacts], default = None): + artifacts to store information about the process + + kwargs (Dict[str, Any]): + configuration options to overwrite + """ + + def decoration(function: Callable): + signature = inspect.signature(function) + + parameter_values: Dict[str, Value] = {} + for name, details in signature.parameters.items(): + if name not in parameters: + continue + + annotation = details.annotation + + is_value = isinstance(annotation, Value) + is_scalar_annotation = isinstance(annotation, type) and issubclass( + annotation, ScalarAnnotation + ) + + if not (is_value or is_scalar_annotation): + raise ValueError( + f"Annotation {annotation} for argument '{name}' is not valid " + f"(please use a cnp type such as " + f"`cnp.uint4` or 'cnp.tensor[cnp.uint4, 3, 2]')" + ) + + parameter_values[name] = ( + annotation if is_value else Value(annotation.dtype, shape=(), is_encrypted=False) + ) + + status = EncryptionStatus(parameters[name].lower()) + parameter_values[name].is_encrypted = status == "encrypted" + + return Compiler.assemble(function, parameter_values, configuration, artifacts, **kwargs) + + return decoration + + +def compiler(parameters: Mapping[str, Union[str, EncryptionStatus]]): """ Provide an easy interface for compilation. Args: - parameters (Dict[str, EncryptionStatus]): + parameters (Mapping[str, Union[str, EncryptionStatus]]): encryption statuses of the parameters of the function to compile """ diff --git a/concrete/numpy/extensions/univariate.py b/concrete/numpy/extensions/univariate.py index acc431545..86fb1f8c1 100644 --- a/concrete/numpy/extensions/univariate.py +++ b/concrete/numpy/extensions/univariate.py @@ -2,18 +2,19 @@ Declaration of `univariate` function. """ -from typing import Any, Callable, Union +from typing import Any, Callable, Optional, Type, Union import numpy as np -from ..dtypes import Float +from ..dtypes import BaseDataType, Float from ..representation import Node -from ..tracing import Tracer +from ..tracing import ScalarAnnotation, Tracer from ..values import Value def univariate( function: Callable[[Any], Any], + outputs: Optional[Union[BaseDataType, Type[ScalarAnnotation]]] = None, ) -> Callable[[Union[Tracer, Any]], Union[Tracer, Any]]: """ Wrap a univariate function so that it is traced into a single generic node. @@ -22,6 +23,9 @@ def univariate( function (Callable[[Any], Any]): univariate function to wrap + outputs (Optional[Union[BaseDataType, Type[ScalarAnnotation]]], default = None): + data type of the result, unused during compilation, required for direct definition + Returns: Callable[[Union[Tracer, Any]], Union[Tracer, Any]]: another univariate function that can be called with a Tracer as well @@ -57,6 +61,19 @@ def univariate( if output_value.shape != x.output.shape: raise ValueError(f"Function {function.__name__} cannot be used with cnp.univariate") + # pylint: disable=protected-access + is_direct = Tracer._is_direct + # pylint: enable=protected-access + + if is_direct: + if outputs is None: + raise ValueError( + "Univariate extension requires " + "`outputs` argument for direct circuit definition " + "(e.g., cnp.univariate(function, outputs=cnp.uint4)(x))" + ) + output_value.dtype = outputs if isinstance(outputs, BaseDataType) else outputs.dtype + computation = Node.generic( function.__name__, [x.output], diff --git a/concrete/numpy/tracing/__init__.py b/concrete/numpy/tracing/__init__.py index 02ed788a5..e0f30edfa 100644 --- a/concrete/numpy/tracing/__init__.py +++ b/concrete/numpy/tracing/__init__.py @@ -2,4 +2,4 @@ Provide `function` to `computation graph` functionality. """ -from .tracer import Tracer +from .tracer import ScalarAnnotation, TensorAnnotation, Tracer diff --git a/concrete/numpy/tracing/tracer.py b/concrete/numpy/tracing/tracer.py index b98aa0987..31a6a44a5 100644 --- a/concrete/numpy/tracing/tracer.py +++ b/concrete/numpy/tracing/tracer.py @@ -4,13 +4,13 @@ Declaration of `Tracer` class. import inspect from copy import deepcopy -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union import networkx as nx import numpy as np from numpy.typing import DTypeLike -from ..dtypes import Float, Integer +from ..dtypes import BaseDataType, Float, Integer from ..internal.utils import assert_that from ..representation import Graph, Node, Operation from ..representation.utils import format_indexing_element @@ -29,13 +29,12 @@ class Tracer: # property to keep track of assignments last_version: Optional["Tracer"] = None - # variable to control the behavior of __eq__ - # so that it can be traced but still allow - # using Tracers in dicts when not tracing + # variables to control the behavior of certain functions _is_tracing: bool = False + _is_direct: bool = False @staticmethod - def trace(function: Callable, parameters: Dict[str, Value]) -> Graph: + def trace(function: Callable, parameters: Dict[str, Value], is_direct: bool = False) -> Graph: """ Trace `function` and create the `Graph` that represents it. @@ -47,6 +46,9 @@ class Tracer: parameters of function to trace e.g. parameter x is an EncryptedScalar holding a 7-bit UnsignedInteger + is_direct (bool, default = False): + whether the tracing is done on actual parameters or placeholders + Returns: Graph: computation graph corresponding to `function` @@ -67,6 +69,8 @@ class Tracer: arguments[param] = Tracer(node, []) input_indices[node] = index + Tracer._is_direct = is_direct + Tracer._is_tracing = True output_tracers: Any = function(**arguments) Tracer._is_tracing = False @@ -383,6 +387,13 @@ class Tracer: output_value = Value.of(evaluation) output_value.is_encrypted = any(tracer.output.is_encrypted for tracer in tracers) + if Tracer._is_direct and isinstance(output_value.dtype, Integer): + resulting_bit_width = 0 + for tracer in tracers: + assert isinstance(tracer.output.dtype, Integer) + resulting_bit_width = max(resulting_bit_width, tracer.output.dtype.bit_width) + output_value.dtype.bit_width = resulting_bit_width + computation = Node.generic( operation.__name__, [tracer.output for tracer in tracers], @@ -488,7 +499,12 @@ class Tracer: def __round__(self, ndigits=None): if ndigits is None: - return Tracer._trace_numpy_operation(np.around, self).astype(np.int64) + result = Tracer._trace_numpy_operation(np.around, self) + if self._is_direct: + raise RuntimeError( + "'round(x)' cannot be used in direct definition (you may use np.around instead)" + ) + return result.astype(np.int64) return Tracer._trace_numpy_operation(np.around, self, decimals=ndigits) @@ -551,13 +567,38 @@ class Tracer: else Tracer._trace_numpy_operation(np.not_equal, self, self.sanitize(other)) ) - def astype(self, dtype: DTypeLike) -> "Tracer": + def astype(self, dtype: Union[DTypeLike, Type["ScalarAnnotation"]]) -> "Tracer": """ Trace numpy.ndarray.astype(dtype). """ - normalized_dtype = np.dtype(dtype) - if np.issubdtype(normalized_dtype, np.integer) and normalized_dtype != np.int64: + if Tracer._is_direct: + output_value = deepcopy(self.output) + + if isinstance(dtype, type) and issubclass(dtype, ScalarAnnotation): + output_value.dtype = dtype.dtype + else: + raise ValueError( + "`astype` method must be called with a concrete.numpy type " + "for direct circuit definition (e.g., value.astype(cnp.uint4))" + ) + + computation = Node.generic( + "astype", + [self.output], + output_value, + lambda x: x, # unused for direct definition + ) + return Tracer(computation, [self]) + + if isinstance(dtype, type) and issubclass(dtype, ScalarAnnotation): + raise ValueError( + "`astype` method must be called with a " + "numpy type for compilation (e.g., value.astype(np.int64))" + ) + + dtype = np.dtype(dtype).type + if np.issubdtype(dtype, np.integer) and dtype != np.int64: print( "Warning: When using `value.astype(newtype)` " "with an integer newtype, " @@ -567,9 +608,9 @@ class Tracer: ) output_value = deepcopy(self.output) - output_value.dtype = Value.of(normalized_dtype.type(0)).dtype + output_value.dtype = Value.of(dtype(0)).dtype # type: ignore - if np.issubdtype(normalized_dtype.type, np.integer): + if np.issubdtype(dtype, np.integer): def evaluator(x, dtype): if np.any(np.isnan(x)): @@ -588,7 +629,7 @@ class Tracer: [self.output], output_value, evaluator, - kwargs={"dtype": normalized_dtype.type}, + kwargs={"dtype": dtype}, ) return Tracer(computation, [self]) @@ -765,3 +806,23 @@ class Tracer: """ return Tracer._trace_numpy_operation(np.transpose, self) + + +class Annotation(Tracer): + """ + Base annotation for direct definition. + """ + + +class ScalarAnnotation(Annotation): + """ + Base scalar annotation for direct definition. + """ + + dtype: BaseDataType + + +class TensorAnnotation(Annotation): + """ + Base tensor annotation for direct definition. + """ diff --git a/concrete/numpy/tracing/typing.py b/concrete/numpy/tracing/typing.py new file mode 100644 index 000000000..4574f51ff --- /dev/null +++ b/concrete/numpy/tracing/typing.py @@ -0,0 +1,1222 @@ +""" +Declaration of type annotation. +""" + +from typing import Any + +from ..dtypes import Float, SignedInteger, UnsignedInteger +from ..values import Value +from .tracer import ScalarAnnotation, TensorAnnotation + +# pylint: disable=function-redefined,invalid-name,no-self-use,too-many-lines,using-constant-test + + +# We'll pull a little trick on mypy +# Basically, this branch is never executed during runtime +# But, mypy will use the information within anyway +# So, it'll think our types are `Any` and it'll stop complaining when used with numpy + +if False: + + f32 = Any + f64 = Any + + int1 = Any + int2 = Any + int3 = Any + int4 = Any + int5 = Any + int6 = Any + int7 = Any + int8 = Any + int9 = Any + int10 = Any + int11 = Any + int12 = Any + int13 = Any + int14 = Any + int15 = Any + int16 = Any + int17 = Any + int18 = Any + int19 = Any + int20 = Any + int21 = Any + int22 = Any + int23 = Any + int24 = Any + int25 = Any + int26 = Any + int27 = Any + int28 = Any + int29 = Any + int30 = Any + int31 = Any + int32 = Any + int33 = Any + int34 = Any + int35 = Any + int36 = Any + int37 = Any + int38 = Any + int39 = Any + int40 = Any + int41 = Any + int42 = Any + int43 = Any + int44 = Any + int45 = Any + int46 = Any + int47 = Any + int48 = Any + int49 = Any + int50 = Any + int51 = Any + int52 = Any + int53 = Any + int54 = Any + int55 = Any + int56 = Any + int57 = Any + int58 = Any + int59 = Any + int60 = Any + int61 = Any + int62 = Any + int63 = Any + int64 = Any + + uint1 = Any + uint2 = Any + uint3 = Any + uint4 = Any + uint5 = Any + uint6 = Any + uint7 = Any + uint8 = Any + uint9 = Any + uint10 = Any + uint11 = Any + uint12 = Any + uint13 = Any + uint14 = Any + uint15 = Any + uint16 = Any + uint17 = Any + uint18 = Any + uint19 = Any + uint20 = Any + uint21 = Any + uint22 = Any + uint23 = Any + uint24 = Any + uint25 = Any + uint26 = Any + uint27 = Any + uint28 = Any + uint29 = Any + uint30 = Any + uint31 = Any + uint32 = Any + uint33 = Any + uint34 = Any + uint35 = Any + uint36 = Any + uint37 = Any + uint38 = Any + uint39 = Any + uint40 = Any + uint41 = Any + uint42 = Any + uint43 = Any + uint44 = Any + uint45 = Any + uint46 = Any + uint47 = Any + uint48 = Any + uint49 = Any + uint50 = Any + uint51 = Any + uint52 = Any + uint53 = Any + uint54 = Any + uint55 = Any + uint56 = Any + uint57 = Any + uint58 = Any + uint59 = Any + uint60 = Any + uint61 = Any + uint62 = Any + uint63 = Any + uint64 = Any + tensor = Any + + +class f32(ScalarAnnotation): # type: ignore + """ + Scalar f32 annotation. + """ + + dtype = Float(32) + + +class f64(ScalarAnnotation): # type: ignore + """ + Scalar f64 annotation. + """ + + dtype = Float(64) + + +class int1(ScalarAnnotation): # type: ignore + """ + Scalar int1 annotation. + """ + + dtype = SignedInteger(1) + + +class int2(ScalarAnnotation): # type: ignore + """ + Scalar int2 annotation. + """ + + dtype = SignedInteger(2) + + +class int3(ScalarAnnotation): # type: ignore + """ + Scalar int3 annotation. + """ + + dtype = SignedInteger(3) + + +class int4(ScalarAnnotation): # type: ignore + """ + Scalar int4 annotation. + """ + + dtype = SignedInteger(4) + + +class int5(ScalarAnnotation): # type: ignore + """ + Scalar int5 annotation. + """ + + dtype = SignedInteger(5) + + +class int6(ScalarAnnotation): # type: ignore + """ + Scalar int6 annotation. + """ + + dtype = SignedInteger(6) + + +class int7(ScalarAnnotation): # type: ignore + """ + Scalar int7 annotation. + """ + + dtype = SignedInteger(7) + + +class int8(ScalarAnnotation): # type: ignore + """ + Scalar int8 annotation. + """ + + dtype = SignedInteger(8) + + +class int9(ScalarAnnotation): # type: ignore + """ + Scalar int9 annotation. + """ + + dtype = SignedInteger(9) + + +class int10(ScalarAnnotation): # type: ignore + """ + Scalar int10 annotation. + """ + + dtype = SignedInteger(10) + + +class int11(ScalarAnnotation): # type: ignore + """ + Scalar int11 annotation. + """ + + dtype = SignedInteger(11) + + +class int12(ScalarAnnotation): # type: ignore + """ + Scalar int12 annotation. + """ + + dtype = SignedInteger(12) + + +class int13(ScalarAnnotation): # type: ignore + """ + Scalar int13 annotation. + """ + + dtype = SignedInteger(13) + + +class int14(ScalarAnnotation): # type: ignore + """ + Scalar int14 annotation. + """ + + dtype = SignedInteger(14) + + +class int15(ScalarAnnotation): # type: ignore + """ + Scalar int15 annotation. + """ + + dtype = SignedInteger(15) + + +class int16(ScalarAnnotation): # type: ignore + """ + Scalar int16 annotation. + """ + + dtype = SignedInteger(16) + + +class int17(ScalarAnnotation): # type: ignore + """ + Scalar int17 annotation. + """ + + dtype = SignedInteger(17) + + +class int18(ScalarAnnotation): # type: ignore + """ + Scalar int18 annotation. + """ + + dtype = SignedInteger(18) + + +class int19(ScalarAnnotation): # type: ignore + """ + Scalar int19 annotation. + """ + + dtype = SignedInteger(19) + + +class int20(ScalarAnnotation): # type: ignore + """ + Scalar int20 annotation. + """ + + dtype = SignedInteger(20) + + +class int21(ScalarAnnotation): # type: ignore + """ + Scalar int21 annotation. + """ + + dtype = SignedInteger(21) + + +class int22(ScalarAnnotation): # type: ignore + """ + Scalar int22 annotation. + """ + + dtype = SignedInteger(22) + + +class int23(ScalarAnnotation): # type: ignore + """ + Scalar int23 annotation. + """ + + dtype = SignedInteger(23) + + +class int24(ScalarAnnotation): # type: ignore + """ + Scalar int24 annotation. + """ + + dtype = SignedInteger(24) + + +class int25(ScalarAnnotation): # type: ignore + """ + Scalar int25 annotation. + """ + + dtype = SignedInteger(25) + + +class int26(ScalarAnnotation): # type: ignore + """ + Scalar int26 annotation. + """ + + dtype = SignedInteger(26) + + +class int27(ScalarAnnotation): # type: ignore + """ + Scalar int27 annotation. + """ + + dtype = SignedInteger(27) + + +class int28(ScalarAnnotation): # type: ignore + """ + Scalar int28 annotation. + """ + + dtype = SignedInteger(28) + + +class int29(ScalarAnnotation): # type: ignore + """ + Scalar int29 annotation. + """ + + dtype = SignedInteger(29) + + +class int30(ScalarAnnotation): # type: ignore + """ + Scalar int30 annotation. + """ + + dtype = SignedInteger(30) + + +class int31(ScalarAnnotation): # type: ignore + """ + Scalar int31 annotation. + """ + + dtype = SignedInteger(31) + + +class int32(ScalarAnnotation): # type: ignore + """ + Scalar int32 annotation. + """ + + dtype = SignedInteger(32) + + +class int33(ScalarAnnotation): # type: ignore + """ + Scalar int33 annotation. + """ + + dtype = SignedInteger(33) + + +class int34(ScalarAnnotation): # type: ignore + """ + Scalar int34 annotation. + """ + + dtype = SignedInteger(34) + + +class int35(ScalarAnnotation): # type: ignore + """ + Scalar int35 annotation. + """ + + dtype = SignedInteger(35) + + +class int36(ScalarAnnotation): # type: ignore + """ + Scalar int36 annotation. + """ + + dtype = SignedInteger(36) + + +class int37(ScalarAnnotation): # type: ignore + """ + Scalar int37 annotation. + """ + + dtype = SignedInteger(37) + + +class int38(ScalarAnnotation): # type: ignore + """ + Scalar int38 annotation. + """ + + dtype = SignedInteger(38) + + +class int39(ScalarAnnotation): # type: ignore + """ + Scalar int39 annotation. + """ + + dtype = SignedInteger(39) + + +class int40(ScalarAnnotation): # type: ignore + """ + Scalar int40 annotation. + """ + + dtype = SignedInteger(40) + + +class int41(ScalarAnnotation): # type: ignore + """ + Scalar int41 annotation. + """ + + dtype = SignedInteger(41) + + +class int42(ScalarAnnotation): # type: ignore + """ + Scalar int42 annotation. + """ + + dtype = SignedInteger(42) + + +class int43(ScalarAnnotation): # type: ignore + """ + Scalar int43 annotation. + """ + + dtype = SignedInteger(43) + + +class int44(ScalarAnnotation): # type: ignore + """ + Scalar int44 annotation. + """ + + dtype = SignedInteger(44) + + +class int45(ScalarAnnotation): # type: ignore + """ + Scalar int45 annotation. + """ + + dtype = SignedInteger(45) + + +class int46(ScalarAnnotation): # type: ignore + """ + Scalar int46 annotation. + """ + + dtype = SignedInteger(46) + + +class int47(ScalarAnnotation): # type: ignore + """ + Scalar int47 annotation. + """ + + dtype = SignedInteger(47) + + +class int48(ScalarAnnotation): # type: ignore + """ + Scalar int48 annotation. + """ + + dtype = SignedInteger(48) + + +class int49(ScalarAnnotation): # type: ignore + """ + Scalar int49 annotation. + """ + + dtype = SignedInteger(49) + + +class int50(ScalarAnnotation): # type: ignore + """ + Scalar int50 annotation. + """ + + dtype = SignedInteger(50) + + +class int51(ScalarAnnotation): # type: ignore + """ + Scalar int51 annotation. + """ + + dtype = SignedInteger(51) + + +class int52(ScalarAnnotation): # type: ignore + """ + Scalar int52 annotation. + """ + + dtype = SignedInteger(52) + + +class int53(ScalarAnnotation): # type: ignore + """ + Scalar int53 annotation. + """ + + dtype = SignedInteger(53) + + +class int54(ScalarAnnotation): # type: ignore + """ + Scalar int54 annotation. + """ + + dtype = SignedInteger(54) + + +class int55(ScalarAnnotation): # type: ignore + """ + Scalar int55 annotation. + """ + + dtype = SignedInteger(55) + + +class int56(ScalarAnnotation): # type: ignore + """ + Scalar int56 annotation. + """ + + dtype = SignedInteger(56) + + +class int57(ScalarAnnotation): # type: ignore + """ + Scalar int57 annotation. + """ + + dtype = SignedInteger(57) + + +class int58(ScalarAnnotation): # type: ignore + """ + Scalar int58 annotation. + """ + + dtype = SignedInteger(58) + + +class int59(ScalarAnnotation): # type: ignore + """ + Scalar int59 annotation. + """ + + dtype = SignedInteger(59) + + +class int60(ScalarAnnotation): # type: ignore + """ + Scalar int60 annotation. + """ + + dtype = SignedInteger(60) + + +class int61(ScalarAnnotation): # type: ignore + """ + Scalar int61 annotation. + """ + + dtype = SignedInteger(61) + + +class int62(ScalarAnnotation): # type: ignore + """ + Scalar int62 annotation. + """ + + dtype = SignedInteger(62) + + +class int63(ScalarAnnotation): # type: ignore + """ + Scalar int63 annotation. + """ + + dtype = SignedInteger(63) + + +class int64(ScalarAnnotation): # type: ignore + """ + Scalar int64 annotation. + """ + + dtype = SignedInteger(64) + + +class uint1(ScalarAnnotation): # type: ignore + """ + Scalar uint1 annotation. + """ + + dtype = UnsignedInteger(1) + + +class uint2(ScalarAnnotation): # type: ignore + """ + Scalar uint2 annotation. + """ + + dtype = UnsignedInteger(2) + + +class uint3(ScalarAnnotation): # type: ignore + """ + Scalar uint3 annotation. + """ + + dtype = UnsignedInteger(3) + + +class uint4(ScalarAnnotation): # type: ignore + """ + Scalar uint4 annotation. + """ + + dtype = UnsignedInteger(4) + + +class uint5(ScalarAnnotation): # type: ignore + """ + Scalar uint5 annotation. + """ + + dtype = UnsignedInteger(5) + + +class uint6(ScalarAnnotation): # type: ignore + """ + Scalar uint6 annotation. + """ + + dtype = UnsignedInteger(6) + + +class uint7(ScalarAnnotation): # type: ignore + """ + Scalar uint7 annotation. + """ + + dtype = UnsignedInteger(7) + + +class uint8(ScalarAnnotation): # type: ignore + """ + Scalar uint8 annotation. + """ + + dtype = UnsignedInteger(8) + + +class uint9(ScalarAnnotation): # type: ignore + """ + Scalar uint9 annotation. + """ + + dtype = UnsignedInteger(9) + + +class uint10(ScalarAnnotation): # type: ignore + """ + Scalar uint10 annotation. + """ + + dtype = UnsignedInteger(10) + + +class uint11(ScalarAnnotation): # type: ignore + """ + Scalar uint11 annotation. + """ + + dtype = UnsignedInteger(11) + + +class uint12(ScalarAnnotation): # type: ignore + """ + Scalar uint12 annotation. + """ + + dtype = UnsignedInteger(12) + + +class uint13(ScalarAnnotation): # type: ignore + """ + Scalar uint13 annotation. + """ + + dtype = UnsignedInteger(13) + + +class uint14(ScalarAnnotation): # type: ignore + """ + Scalar uint14 annotation. + """ + + dtype = UnsignedInteger(14) + + +class uint15(ScalarAnnotation): # type: ignore + """ + Scalar uint15 annotation. + """ + + dtype = UnsignedInteger(15) + + +class uint16(ScalarAnnotation): # type: ignore + """ + Scalar uint16 annotation. + """ + + dtype = UnsignedInteger(16) + + +class uint17(ScalarAnnotation): # type: ignore + """ + Scalar uint17 annotation. + """ + + dtype = UnsignedInteger(17) + + +class uint18(ScalarAnnotation): # type: ignore + """ + Scalar uint18 annotation. + """ + + dtype = UnsignedInteger(18) + + +class uint19(ScalarAnnotation): # type: ignore + """ + Scalar uint19 annotation. + """ + + dtype = UnsignedInteger(19) + + +class uint20(ScalarAnnotation): # type: ignore + """ + Scalar uint20 annotation. + """ + + dtype = UnsignedInteger(20) + + +class uint21(ScalarAnnotation): # type: ignore + """ + Scalar uint21 annotation. + """ + + dtype = UnsignedInteger(21) + + +class uint22(ScalarAnnotation): # type: ignore + """ + Scalar uint22 annotation. + """ + + dtype = UnsignedInteger(22) + + +class uint23(ScalarAnnotation): # type: ignore + """ + Scalar uint23 annotation. + """ + + dtype = UnsignedInteger(23) + + +class uint24(ScalarAnnotation): # type: ignore + """ + Scalar uint24 annotation. + """ + + dtype = UnsignedInteger(24) + + +class uint25(ScalarAnnotation): # type: ignore + """ + Scalar uint25 annotation. + """ + + dtype = UnsignedInteger(25) + + +class uint26(ScalarAnnotation): # type: ignore + """ + Scalar uint26 annotation. + """ + + dtype = UnsignedInteger(26) + + +class uint27(ScalarAnnotation): # type: ignore + """ + Scalar uint27 annotation. + """ + + dtype = UnsignedInteger(27) + + +class uint28(ScalarAnnotation): # type: ignore + """ + Scalar uint28 annotation. + """ + + dtype = UnsignedInteger(28) + + +class uint29(ScalarAnnotation): # type: ignore + """ + Scalar uint29 annotation. + """ + + dtype = UnsignedInteger(29) + + +class uint30(ScalarAnnotation): # type: ignore + """ + Scalar uint30 annotation. + """ + + dtype = UnsignedInteger(30) + + +class uint31(ScalarAnnotation): # type: ignore + """ + Scalar uint31 annotation. + """ + + dtype = UnsignedInteger(31) + + +class uint32(ScalarAnnotation): # type: ignore + """ + Scalar uint32 annotation. + """ + + dtype = UnsignedInteger(32) + + +class uint33(ScalarAnnotation): # type: ignore + """ + Scalar uint33 annotation. + """ + + dtype = UnsignedInteger(33) + + +class uint34(ScalarAnnotation): # type: ignore + """ + Scalar uint34 annotation. + """ + + dtype = UnsignedInteger(34) + + +class uint35(ScalarAnnotation): # type: ignore + """ + Scalar uint35 annotation. + """ + + dtype = UnsignedInteger(35) + + +class uint36(ScalarAnnotation): # type: ignore + """ + Scalar uint36 annotation. + """ + + dtype = UnsignedInteger(36) + + +class uint37(ScalarAnnotation): # type: ignore + """ + Scalar uint37 annotation. + """ + + dtype = UnsignedInteger(37) + + +class uint38(ScalarAnnotation): # type: ignore + """ + Scalar uint38 annotation. + """ + + dtype = UnsignedInteger(38) + + +class uint39(ScalarAnnotation): # type: ignore + """ + Scalar uint39 annotation. + """ + + dtype = UnsignedInteger(39) + + +class uint40(ScalarAnnotation): # type: ignore + """ + Scalar uint40 annotation. + """ + + dtype = UnsignedInteger(40) + + +class uint41(ScalarAnnotation): # type: ignore + """ + Scalar uint41 annotation. + """ + + dtype = UnsignedInteger(41) + + +class uint42(ScalarAnnotation): # type: ignore + """ + Scalar uint42 annotation. + """ + + dtype = UnsignedInteger(42) + + +class uint43(ScalarAnnotation): # type: ignore + """ + Scalar uint43 annotation. + """ + + dtype = UnsignedInteger(43) + + +class uint44(ScalarAnnotation): # type: ignore + """ + Scalar uint44 annotation. + """ + + dtype = UnsignedInteger(44) + + +class uint45(ScalarAnnotation): # type: ignore + """ + Scalar uint45 annotation. + """ + + dtype = UnsignedInteger(45) + + +class uint46(ScalarAnnotation): # type: ignore + """ + Scalar uint46 annotation. + """ + + dtype = UnsignedInteger(46) + + +class uint47(ScalarAnnotation): # type: ignore + """ + Scalar uint47 annotation. + """ + + dtype = UnsignedInteger(47) + + +class uint48(ScalarAnnotation): # type: ignore + """ + Scalar uint48 annotation. + """ + + dtype = UnsignedInteger(48) + + +class uint49(ScalarAnnotation): # type: ignore + """ + Scalar uint49 annotation. + """ + + dtype = UnsignedInteger(49) + + +class uint50(ScalarAnnotation): # type: ignore + """ + Scalar uint50 annotation. + """ + + dtype = UnsignedInteger(50) + + +class uint51(ScalarAnnotation): # type: ignore + """ + Scalar uint51 annotation. + """ + + dtype = UnsignedInteger(51) + + +class uint52(ScalarAnnotation): # type: ignore + """ + Scalar uint52 annotation. + """ + + dtype = UnsignedInteger(52) + + +class uint53(ScalarAnnotation): # type: ignore + """ + Scalar uint53 annotation. + """ + + dtype = UnsignedInteger(53) + + +class uint54(ScalarAnnotation): # type: ignore + """ + Scalar uint54 annotation. + """ + + dtype = UnsignedInteger(54) + + +class uint55(ScalarAnnotation): # type: ignore + """ + Scalar uint55 annotation. + """ + + dtype = UnsignedInteger(55) + + +class uint56(ScalarAnnotation): # type: ignore + """ + Scalar uint56 annotation. + """ + + dtype = UnsignedInteger(56) + + +class uint57(ScalarAnnotation): # type: ignore + """ + Scalar uint57 annotation. + """ + + dtype = UnsignedInteger(57) + + +class uint58(ScalarAnnotation): # type: ignore + """ + Scalar uint58 annotation. + """ + + dtype = UnsignedInteger(58) + + +class uint59(ScalarAnnotation): # type: ignore + """ + Scalar uint59 annotation. + """ + + dtype = UnsignedInteger(59) + + +class uint60(ScalarAnnotation): # type: ignore + """ + Scalar uint60 annotation. + """ + + dtype = UnsignedInteger(60) + + +class uint61(ScalarAnnotation): # type: ignore + """ + Scalar uint61 annotation. + """ + + dtype = UnsignedInteger(61) + + +class uint62(ScalarAnnotation): # type: ignore + """ + Scalar uint62 annotation. + """ + + dtype = UnsignedInteger(62) + + +class uint63(ScalarAnnotation): # type: ignore + """ + Scalar uint63 annotation. + """ + + dtype = UnsignedInteger(63) + + +class uint64(ScalarAnnotation): # type: ignore + """ + Scalar uint64 annotation. + """ + + dtype = UnsignedInteger(64) + + +class tensor(TensorAnnotation): # type: ignore + """ + Tensor annotation. + """ + + def __class_getitem__(cls, item): + if not isinstance(item, tuple): + item = (item,) + + annotation = item[0] + if not issubclass(annotation, ScalarAnnotation): + raise ValueError( + f"First argument to tensor annotations should be a " + f"concrete-numpy data type (e.g., cnp.uint4) " + f"not {annotation.__name__ if hasattr(annotation, '__name__') else str(annotation)}" + ) + + if len(item) == 1: + raise ValueError( + "Tensor annotations should have a shape (e.g., cnp.tensor[cnp.uint4, 3, 2])" + ) + + shape = item[1:] + if not all(isinstance(x, int) for x in shape): + raise ValueError("Tensor annotation shape elements must be 'int'") + + return Value(dtype=annotation.dtype, shape=shape, is_encrypted=False) diff --git a/tests/compilation/test_artifacts.py b/tests/compilation/test_artifacts.py index ffd319d2e..82e436dda 100644 --- a/tests/compilation/test_artifacts.py +++ b/tests/compilation/test_artifacts.py @@ -7,7 +7,7 @@ from pathlib import Path import numpy as np -from concrete.numpy.compilation import DebugArtifacts, compiler +from concrete.numpy import DebugArtifacts, compiler def test_artifacts_export(helpers): diff --git a/tests/compilation/test_circuit.py b/tests/compilation/test_circuit.py index 19f3c5d2a..1f08e8483 100644 --- a/tests/compilation/test_circuit.py +++ b/tests/compilation/test_circuit.py @@ -8,8 +8,7 @@ from pathlib import Path import numpy as np import pytest -from concrete.numpy import Client, ClientSpecs, EvaluationKeys, Server -from concrete.numpy.compilation import compiler +from concrete.numpy import Client, ClientSpecs, EvaluationKeys, Server, compiler def test_circuit_str(helpers): diff --git a/tests/compilation/test_decorator.py b/tests/compilation/test_decorator.py deleted file mode 100644 index ee0553830..000000000 --- a/tests/compilation/test_decorator.py +++ /dev/null @@ -1,127 +0,0 @@ -""" -Tests of `compiler` decorator. -""" - -from concrete.numpy.compilation import DebugArtifacts, compiler - - -def test_call_compile(helpers): - """ - Test `__call__` and `compile` methods of `compiler` decorator back to back. - """ - - configuration = helpers.configuration() - - @compiler({"x": "encrypted"}) - def function(x): - return x + 42 - - for i in range(10): - function(i) - - circuit = function.compile(configuration=configuration) - - sample = 5 - helpers.check_execution(circuit, function, sample) - - -def test_compiler_verbose_trace(helpers, capsys): - """ - Test `trace` method of `compiler` decorator with verbose flag. - """ - - configuration = helpers.configuration() - artifacts = DebugArtifacts() - - @compiler({"x": "encrypted"}) - def function(x): - return x + 42 - - inputset = range(10) - function.trace(inputset, configuration, artifacts, show_graph=True) - - captured = capsys.readouterr() - assert captured.out.strip() == ( - f""" - -Computation Graph ------------------------------------------------- -{str(list(artifacts.textual_representations_of_graphs.values())[-1][-1])} ------------------------------------------------- - - """.strip() - ) - - -def test_compiler_verbose_compile(helpers, capsys): - """ - Test `compile` method of `compiler` decorator with verbose flag. - """ - - configuration = helpers.configuration() - artifacts = DebugArtifacts() - - @compiler({"x": "encrypted"}) - def function(x): - return x + 42 - - inputset = range(10) - function.compile(inputset, configuration, artifacts, verbose=True) - - captured = capsys.readouterr() - assert captured.out.strip().startswith( - f""" - -Computation Graph --------------------------------------------------------------------------------- -{list(artifacts.textual_representations_of_graphs.values())[-1][-1]} --------------------------------------------------------------------------------- - -MLIR --------------------------------------------------------------------------------- -{artifacts.mlir_to_compile} --------------------------------------------------------------------------------- - -Optimizer --------------------------------------------------------------------------------- - - """.strip() - ) - - -def test_compiler_verbose_virtual_compile(helpers, capsys): - """ - Test `compile` method of `compiler` decorator with verbose flag. - """ - - configuration = helpers.configuration() - artifacts = DebugArtifacts() - - @compiler({"x": "encrypted"}) - def function(x): - return x + 42 - - inputset = range(10) - function.compile(inputset, configuration, artifacts, verbose=True, virtual=True) - - captured = capsys.readouterr() - assert captured.out.strip() == ( - f""" - -Computation Graph ------------------------------------------------- -{list(artifacts.textual_representations_of_graphs.values())[-1][-1]} ------------------------------------------------- - -MLIR ------------------------------------------------- -Virtual circuits don't have MLIR. ------------------------------------------------- - -Optimizer ------------------------------------------------- -Virtual circuits don't have optimizer output. ------------------------------------------------- - - """.strip() - ) diff --git a/tests/compilation/test_decorators.py b/tests/compilation/test_decorators.py new file mode 100644 index 000000000..a8790c7d1 --- /dev/null +++ b/tests/compilation/test_decorators.py @@ -0,0 +1,280 @@ +""" +Tests of `compiler` and `circuit` decorators. +""" + +import numpy as np +import pytest + +import concrete.numpy as cnp + + +def test_compiler_call_and_compile(helpers): + """ + Test `__call__` and `compile` methods of `compiler` decorator back to back. + """ + + configuration = helpers.configuration() + + @cnp.compiler({"x": "encrypted"}) + def function(x): + return x + 42 + + for i in range(10): + function(i) + + circuit = function.compile(configuration=configuration) + + sample = 5 + helpers.check_execution(circuit, function, sample) + + +def test_compiler_verbose_trace(helpers, capsys): + """ + Test `trace` method of `compiler` decorator with verbose flag. + """ + + configuration = helpers.configuration() + artifacts = cnp.DebugArtifacts() + + @cnp.compiler({"x": "encrypted"}) + def function(x): + return x + 42 + + inputset = range(10) + function.trace(inputset, configuration, artifacts, show_graph=True) + + captured = capsys.readouterr() + assert captured.out.strip() == ( + f""" + +Computation Graph +------------------------------------------------ +{str(list(artifacts.textual_representations_of_graphs.values())[-1][-1])} +------------------------------------------------ + + """.strip() + ) + + +def test_compiler_verbose_compile(helpers, capsys): + """ + Test `compile` method of `compiler` decorator with verbose flag. + """ + + configuration = helpers.configuration() + artifacts = cnp.DebugArtifacts() + + @cnp.compiler({"x": "encrypted"}) + def function(x): + return x + 42 + + inputset = range(10) + function.compile(inputset, configuration, artifacts, verbose=True) + + captured = capsys.readouterr() + assert captured.out.strip().startswith( + f""" + +Computation Graph +-------------------------------------------------------------------------------- +{list(artifacts.textual_representations_of_graphs.values())[-1][-1]} +-------------------------------------------------------------------------------- + +MLIR +-------------------------------------------------------------------------------- +{artifacts.mlir_to_compile} +-------------------------------------------------------------------------------- + +Optimizer +-------------------------------------------------------------------------------- + + """.strip() + ) + + +def test_compiler_verbose_virtual_compile(helpers, capsys): + """ + Test `compile` method of `compiler` decorator with verbose flag. + """ + + configuration = helpers.configuration() + artifacts = cnp.DebugArtifacts() + + @cnp.compiler({"x": "encrypted"}) + def function(x): + return x + 42 + + inputset = range(10) + function.compile(inputset, configuration, artifacts, verbose=True, virtual=True) + + captured = capsys.readouterr() + assert captured.out.strip() == ( + f""" + +Computation Graph +------------------------------------------------ +{list(artifacts.textual_representations_of_graphs.values())[-1][-1]} +------------------------------------------------ + +MLIR +------------------------------------------------ +Virtual circuits don't have MLIR. +------------------------------------------------ + +Optimizer +------------------------------------------------ +Virtual circuits don't have optimizer output. +------------------------------------------------ + + """.strip() + ) + + +def test_circuit(helpers): + """ + Test circuit decorator. + """ + + @cnp.circuit({"x": "encrypted"}, helpers.configuration()) + def circuit1(x: cnp.uint2): + return x + 42 + + helpers.check_str( + str(circuit1), + """ + +%0 = x # EncryptedScalar +%1 = 42 # ClearScalar +%2 = add(%0, %1) # EncryptedScalar +return %2 + + """.strip(), + ) + + # ====================================================================== + + @cnp.circuit({"x": "encrypted"}, helpers.configuration()) + def circuit2(x: cnp.tensor[cnp.uint2, 3, 2]): + return x + 42 + + helpers.check_str( + str(circuit2), + """ + +%0 = x # EncryptedTensor +%1 = 42 # ClearScalar +%2 = add(%0, %1) # EncryptedTensor +return %2 + + """.strip(), + ) + + # ====================================================================== + + @cnp.circuit({"x": "encrypted"}, helpers.configuration()) + def circuit3(x: cnp.uint3): + def square(x): + return x**2 + + return cnp.univariate(square, outputs=cnp.uint7)(x) + + helpers.check_str( + str(circuit3), + """ + +%0 = x # EncryptedScalar +%1 = square(%0) # EncryptedScalar +return %1 + + """.strip(), + ) + + # ====================================================================== + + @cnp.circuit({"x": "encrypted"}, helpers.configuration()) + def circuit4(x: cnp.uint3): + return ((np.sin(x) ** 2) + (np.cos(x) ** 2)).astype(cnp.uint3) + + helpers.check_str( + str(circuit4), + """ + +%0 = x # EncryptedScalar +%1 = subgraph(%0) # EncryptedScalar +return %1 + +Subgraphs: + + %1 = subgraph(%0): + + %0 = 2 # ClearScalar + %1 = 2 # ClearScalar + %2 = input # EncryptedScalar + %3 = sin(%2) # EncryptedScalar + %4 = cos(%2) # EncryptedScalar + %5 = power(%3, %0) # EncryptedScalar + %6 = power(%4, %1) # EncryptedScalar + %7 = add(%5, %6) # EncryptedScalar + %8 = astype(%7) # EncryptedScalar + return %8 + + """.strip(), + ) + + +def test_bad_circuit(helpers): + """ + Test circuit decorator with bad parameters. + """ + + # bad annotation + # -------------- + + with pytest.raises(ValueError) as excinfo: + + @cnp.circuit({"x": "encrypted"}, helpers.configuration()) + def circuit1(x: int): + return x + 42 + + assert str(excinfo.value) == ( + f"Annotation {str(int)} for argument 'x' is not valid " + f"(please use a cnp type such as `cnp.uint4` or 'cnp.tensor[cnp.uint4, 3, 2]')" + ) + + # missing encryption status + # ------------------------- + + with pytest.raises(ValueError) as excinfo: + + @cnp.circuit({}, helpers.configuration()) + def circuit2(x: cnp.uint3): + return x + 42 + + assert str(excinfo.value) == ( + "Encryption status of parameter 'x' of function 'circuit2' is not provided" + ) + + # bad astype + # ---------- + with pytest.raises(ValueError) as excinfo: + + @cnp.circuit({"x": "encrypted"}, helpers.configuration()) + def circuit3(x: cnp.uint3): + return x.astype(np.int64) + + assert str(excinfo.value) == ( + "`astype` method must be called with a concrete.numpy type " + "for direct circuit definition (e.g., value.astype(cnp.uint4))" + ) + + # round + # ----- + with pytest.raises(RuntimeError) as excinfo: + + @cnp.circuit({"x": "encrypted"}, helpers.configuration()) + def circuit4(x: cnp.uint3): + return round(x) + + assert str(excinfo.value) == ( + "'round(x)' cannot be used in direct definition (you may use np.around instead)" + ) diff --git a/tests/extensions/test_array.py b/tests/extensions/test_array.py index de3ba7e6f..1489d1552 100644 --- a/tests/extensions/test_array.py +++ b/tests/extensions/test_array.py @@ -1,5 +1,5 @@ """ -Tests of LookupTable. +Tests of 'array' extension. """ import pytest diff --git a/tests/extensions/test_table.py b/tests/extensions/test_table.py index 188311994..46912e35a 100644 --- a/tests/extensions/test_table.py +++ b/tests/extensions/test_table.py @@ -1,5 +1,5 @@ """ -Tests of LookupTable. +Tests of 'LookupTable' extension. """ import pytest diff --git a/tests/extensions/test_univariate.py b/tests/extensions/test_univariate.py new file mode 100644 index 000000000..e60bb57f4 --- /dev/null +++ b/tests/extensions/test_univariate.py @@ -0,0 +1,24 @@ +""" +Tests of 'univariate' extension. +""" + +import pytest + +import concrete.numpy as cnp + + +def test_bad_univariate(helpers): + """ + Test 'univariate' extension with bad parameters. + """ + + with pytest.raises(ValueError) as excinfo: + + @cnp.circuit({"x": "encrypted"}, helpers.configuration()) + def function(x: cnp.uint3): + return cnp.univariate(lambda x: x**2)(x) + + assert str(excinfo.value) == ( + "Univariate extension requires `outputs` argument for direct circuit definition " + "(e.g., cnp.univariate(function, outputs=cnp.uint4)(x))" + ) diff --git a/tests/tracing/test_tracer.py b/tests/tracing/test_tracer.py index 82a3298c8..25dc17ea0 100644 --- a/tests/tracing/test_tracer.py +++ b/tests/tracing/test_tracer.py @@ -7,6 +7,7 @@ import pytest from concrete.numpy.dtypes import UnsignedInteger from concrete.numpy.tracing import Tracer +from concrete.numpy.tracing.typing import uint4 from concrete.numpy.values import EncryptedTensor @@ -43,6 +44,13 @@ from concrete.numpy.values import EncryptedTensor RuntimeError, "Only __call__ hook is supported for numpy ufuncs", ), + pytest.param( + lambda x: x.astype(uint4), + {"x": EncryptedTensor(UnsignedInteger(7), shape=(4,))}, + ValueError, + "`astype` method must be called with a " + "numpy type for compilation (e.g., value.astype(np.int64))", + ), ], ) def test_tracer_bad_trace(function, parameters, expected_error, expected_message): diff --git a/tests/tracing/test_typing.py b/tests/tracing/test_typing.py new file mode 100644 index 000000000..d0d0bc33b --- /dev/null +++ b/tests/tracing/test_typing.py @@ -0,0 +1,54 @@ +""" +Test type annotations. +""" + +import pytest + +import concrete.numpy as cnp + + +def test_bad_tensor(): + """ + Test `tensor` type with bad parameters + """ + + # invalid dtype + # ------------- + + with pytest.raises(ValueError) as excinfo: + + def case1(x: cnp.tensor[int]): + return x + + case1(None) + + assert str(excinfo.value) == ( + "First argument to tensor annotations should be a " + "concrete-numpy data type (e.g., cnp.uint4) not int" + ) + + # no shape + # -------- + + with pytest.raises(ValueError) as excinfo: + + def case2(x: cnp.tensor[cnp.uint3]): + return x + + case2(None) + + assert str(excinfo.value) == ( + "Tensor annotations should have a shape (e.g., cnp.tensor[cnp.uint4, 3, 2])" + ) + + # bad shape + # --------- + + with pytest.raises(ValueError) as excinfo: + + def case3(x: cnp.tensor[cnp.uint3, 1.5]): + return x + + case3(None) + + assert str(excinfo.value) == "Tensor annotation shape elements must be 'int'"