feat: introduce circuit decorator to directly define circuits

This commit is contained in:
Umut
2022-10-07 10:55:03 +02:00
parent c5e43616a5
commit 66c707cd69
17 changed files with 1950 additions and 155 deletions

View File

@@ -13,8 +13,141 @@ from .compilation import (
DebugArtifacts,
EncryptionStatus,
Server,
compiler,
)
from .compilation.decorators import circuit, compiler
from .extensions import LookupTable, array, one, ones, univariate, zero, zeros
from .mlir.utils import MAXIMUM_TLU_BIT_WIDTH
from .representation import Graph
from .tracing.typing import (
f32,
f64,
int1,
int2,
int3,
int4,
int5,
int6,
int7,
int8,
int9,
int10,
int11,
int12,
int13,
int14,
int15,
int16,
int17,
int18,
int19,
int20,
int21,
int22,
int23,
int24,
int25,
int26,
int27,
int28,
int29,
int30,
int31,
int32,
int33,
int34,
int35,
int36,
int37,
int38,
int39,
int40,
int41,
int42,
int43,
int44,
int45,
int46,
int47,
int48,
int49,
int50,
int51,
int52,
int53,
int54,
int55,
int56,
int57,
int58,
int59,
int60,
int61,
int62,
int63,
int64,
tensor,
uint1,
uint2,
uint3,
uint4,
uint5,
uint6,
uint7,
uint8,
uint9,
uint10,
uint11,
uint12,
uint13,
uint14,
uint15,
uint16,
uint17,
uint18,
uint19,
uint20,
uint21,
uint22,
uint23,
uint24,
uint25,
uint26,
uint27,
uint28,
uint29,
uint30,
uint31,
uint32,
uint33,
uint34,
uint35,
uint36,
uint37,
uint38,
uint39,
uint40,
uint41,
uint42,
uint43,
uint44,
uint45,
uint46,
uint47,
uint48,
uint49,
uint50,
uint51,
uint52,
uint53,
uint54,
uint55,
uint56,
uint57,
uint58,
uint59,
uint60,
uint61,
uint62,
uint63,
uint64,
)

View File

@@ -7,6 +7,5 @@ from .circuit import Circuit
from .client import Client
from .compiler import Compiler, EncryptionStatus
from .configuration import Configuration
from .decorator import compiler
from .server import Server
from .specs import ClientSpecs

View File

@@ -45,6 +45,56 @@ class Compiler:
inputset: List[Any]
graph: Optional[Graph]
_is_direct: bool
_parameter_values: Dict[str, Value]
@staticmethod
def assemble(
function: Callable,
parameter_values: Dict[str, Value],
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
**kwargs,
) -> Circuit:
"""
Assemble a circuit from the raw parameter values, used in direct circuit definition.
Args:
function (Callable):
function to convert to a circuit
parameter_values (Dict[str, Value]):
parameter values of the function
configuration(Optional[Configuration], default = None):
configuration to use
artifacts (Optional[DebugArtifacts], default = None):
artifacts to store information about the process
kwargs (Dict[str, Any]):
configuration options to overwrite
Returns:
Circuit:
assembled circuit
"""
compiler = Compiler(
function,
{
name: "encrypted" if value.is_encrypted else "clear"
for name, value in parameter_values.items()
},
)
# pylint: disable=protected-access
compiler._is_direct = True
compiler._parameter_values = parameter_values
# pylint: enable=protected-access
return compiler.compile(None, configuration, artifacts, **kwargs)
def __init__(
self,
function: Callable,
@@ -102,6 +152,9 @@ class Compiler:
self.inputset = []
self.graph = None
self._is_direct = False
self._parameter_values = {}
def __call__(
self,
*args: Any,
@@ -171,6 +224,18 @@ class Compiler:
optional inputset to extend accumulated inputset before bounds measurement
"""
if self._is_direct:
self.graph = Tracer.trace(self.function, self._parameter_values, is_direct=True)
if self.artifacts is not None:
self.artifacts.add_graph("initial", self.graph) # pragma: no cover
fuse(self.graph, self.artifacts)
if self.artifacts is not None:
self.artifacts.add_graph("final", self.graph) # pragma: no cover
return
if inputset is not None:
previous_inputset_length = len(self.inputset)
for index, sample in enumerate(iter(inputset)):

View File

@@ -1,22 +1,82 @@
"""
Declaration of `compiler` decorator.
Declaration of `circuit` and `compiler` decorators.
"""
from typing import Any, Callable, Iterable, Mapping, Optional, Tuple, Union
import inspect
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union
from ..representation import Graph
from ..tracing.typing import ScalarAnnotation
from ..values import Value
from .artifacts import DebugArtifacts
from .circuit import Circuit
from .compiler import Compiler, EncryptionStatus
from .configuration import Configuration
def compiler(parameters: Mapping[str, EncryptionStatus]):
def circuit(
parameters: Mapping[str, Union[str, EncryptionStatus]],
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
**kwargs,
):
"""
Provide a direct interface for compilation.
Args:
parameters (Mapping[str, Union[str, EncryptionStatus]]):
encryption statuses of the parameters of the function to compile
configuration(Optional[Configuration], default = None):
configuration to use
artifacts (Optional[DebugArtifacts], default = None):
artifacts to store information about the process
kwargs (Dict[str, Any]):
configuration options to overwrite
"""
def decoration(function: Callable):
signature = inspect.signature(function)
parameter_values: Dict[str, Value] = {}
for name, details in signature.parameters.items():
if name not in parameters:
continue
annotation = details.annotation
is_value = isinstance(annotation, Value)
is_scalar_annotation = isinstance(annotation, type) and issubclass(
annotation, ScalarAnnotation
)
if not (is_value or is_scalar_annotation):
raise ValueError(
f"Annotation {annotation} for argument '{name}' is not valid "
f"(please use a cnp type such as "
f"`cnp.uint4` or 'cnp.tensor[cnp.uint4, 3, 2]')"
)
parameter_values[name] = (
annotation if is_value else Value(annotation.dtype, shape=(), is_encrypted=False)
)
status = EncryptionStatus(parameters[name].lower())
parameter_values[name].is_encrypted = status == "encrypted"
return Compiler.assemble(function, parameter_values, configuration, artifacts, **kwargs)
return decoration
def compiler(parameters: Mapping[str, Union[str, EncryptionStatus]]):
"""
Provide an easy interface for compilation.
Args:
parameters (Dict[str, EncryptionStatus]):
parameters (Mapping[str, Union[str, EncryptionStatus]]):
encryption statuses of the parameters of the function to compile
"""

View File

@@ -2,18 +2,19 @@
Declaration of `univariate` function.
"""
from typing import Any, Callable, Union
from typing import Any, Callable, Optional, Type, Union
import numpy as np
from ..dtypes import Float
from ..dtypes import BaseDataType, Float
from ..representation import Node
from ..tracing import Tracer
from ..tracing import ScalarAnnotation, Tracer
from ..values import Value
def univariate(
function: Callable[[Any], Any],
outputs: Optional[Union[BaseDataType, Type[ScalarAnnotation]]] = None,
) -> Callable[[Union[Tracer, Any]], Union[Tracer, Any]]:
"""
Wrap a univariate function so that it is traced into a single generic node.
@@ -22,6 +23,9 @@ def univariate(
function (Callable[[Any], Any]):
univariate function to wrap
outputs (Optional[Union[BaseDataType, Type[ScalarAnnotation]]], default = None):
data type of the result, unused during compilation, required for direct definition
Returns:
Callable[[Union[Tracer, Any]], Union[Tracer, Any]]:
another univariate function that can be called with a Tracer as well
@@ -57,6 +61,19 @@ def univariate(
if output_value.shape != x.output.shape:
raise ValueError(f"Function {function.__name__} cannot be used with cnp.univariate")
# pylint: disable=protected-access
is_direct = Tracer._is_direct
# pylint: enable=protected-access
if is_direct:
if outputs is None:
raise ValueError(
"Univariate extension requires "
"`outputs` argument for direct circuit definition "
"(e.g., cnp.univariate(function, outputs=cnp.uint4)(x))"
)
output_value.dtype = outputs if isinstance(outputs, BaseDataType) else outputs.dtype
computation = Node.generic(
function.__name__,
[x.output],

View File

@@ -2,4 +2,4 @@
Provide `function` to `computation graph` functionality.
"""
from .tracer import Tracer
from .tracer import ScalarAnnotation, TensorAnnotation, Tracer

View File

@@ -4,13 +4,13 @@ Declaration of `Tracer` class.
import inspect
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
import networkx as nx
import numpy as np
from numpy.typing import DTypeLike
from ..dtypes import Float, Integer
from ..dtypes import BaseDataType, Float, Integer
from ..internal.utils import assert_that
from ..representation import Graph, Node, Operation
from ..representation.utils import format_indexing_element
@@ -29,13 +29,12 @@ class Tracer:
# property to keep track of assignments
last_version: Optional["Tracer"] = None
# variable to control the behavior of __eq__
# so that it can be traced but still allow
# using Tracers in dicts when not tracing
# variables to control the behavior of certain functions
_is_tracing: bool = False
_is_direct: bool = False
@staticmethod
def trace(function: Callable, parameters: Dict[str, Value]) -> Graph:
def trace(function: Callable, parameters: Dict[str, Value], is_direct: bool = False) -> Graph:
"""
Trace `function` and create the `Graph` that represents it.
@@ -47,6 +46,9 @@ class Tracer:
parameters of function to trace
e.g. parameter x is an EncryptedScalar holding a 7-bit UnsignedInteger
is_direct (bool, default = False):
whether the tracing is done on actual parameters or placeholders
Returns:
Graph:
computation graph corresponding to `function`
@@ -67,6 +69,8 @@ class Tracer:
arguments[param] = Tracer(node, [])
input_indices[node] = index
Tracer._is_direct = is_direct
Tracer._is_tracing = True
output_tracers: Any = function(**arguments)
Tracer._is_tracing = False
@@ -383,6 +387,13 @@ class Tracer:
output_value = Value.of(evaluation)
output_value.is_encrypted = any(tracer.output.is_encrypted for tracer in tracers)
if Tracer._is_direct and isinstance(output_value.dtype, Integer):
resulting_bit_width = 0
for tracer in tracers:
assert isinstance(tracer.output.dtype, Integer)
resulting_bit_width = max(resulting_bit_width, tracer.output.dtype.bit_width)
output_value.dtype.bit_width = resulting_bit_width
computation = Node.generic(
operation.__name__,
[tracer.output for tracer in tracers],
@@ -488,7 +499,12 @@ class Tracer:
def __round__(self, ndigits=None):
if ndigits is None:
return Tracer._trace_numpy_operation(np.around, self).astype(np.int64)
result = Tracer._trace_numpy_operation(np.around, self)
if self._is_direct:
raise RuntimeError(
"'round(x)' cannot be used in direct definition (you may use np.around instead)"
)
return result.astype(np.int64)
return Tracer._trace_numpy_operation(np.around, self, decimals=ndigits)
@@ -551,13 +567,38 @@ class Tracer:
else Tracer._trace_numpy_operation(np.not_equal, self, self.sanitize(other))
)
def astype(self, dtype: DTypeLike) -> "Tracer":
def astype(self, dtype: Union[DTypeLike, Type["ScalarAnnotation"]]) -> "Tracer":
"""
Trace numpy.ndarray.astype(dtype).
"""
normalized_dtype = np.dtype(dtype)
if np.issubdtype(normalized_dtype, np.integer) and normalized_dtype != np.int64:
if Tracer._is_direct:
output_value = deepcopy(self.output)
if isinstance(dtype, type) and issubclass(dtype, ScalarAnnotation):
output_value.dtype = dtype.dtype
else:
raise ValueError(
"`astype` method must be called with a concrete.numpy type "
"for direct circuit definition (e.g., value.astype(cnp.uint4))"
)
computation = Node.generic(
"astype",
[self.output],
output_value,
lambda x: x, # unused for direct definition
)
return Tracer(computation, [self])
if isinstance(dtype, type) and issubclass(dtype, ScalarAnnotation):
raise ValueError(
"`astype` method must be called with a "
"numpy type for compilation (e.g., value.astype(np.int64))"
)
dtype = np.dtype(dtype).type
if np.issubdtype(dtype, np.integer) and dtype != np.int64:
print(
"Warning: When using `value.astype(newtype)` "
"with an integer newtype, "
@@ -567,9 +608,9 @@ class Tracer:
)
output_value = deepcopy(self.output)
output_value.dtype = Value.of(normalized_dtype.type(0)).dtype
output_value.dtype = Value.of(dtype(0)).dtype # type: ignore
if np.issubdtype(normalized_dtype.type, np.integer):
if np.issubdtype(dtype, np.integer):
def evaluator(x, dtype):
if np.any(np.isnan(x)):
@@ -588,7 +629,7 @@ class Tracer:
[self.output],
output_value,
evaluator,
kwargs={"dtype": normalized_dtype.type},
kwargs={"dtype": dtype},
)
return Tracer(computation, [self])
@@ -765,3 +806,23 @@ class Tracer:
"""
return Tracer._trace_numpy_operation(np.transpose, self)
class Annotation(Tracer):
"""
Base annotation for direct definition.
"""
class ScalarAnnotation(Annotation):
"""
Base scalar annotation for direct definition.
"""
dtype: BaseDataType
class TensorAnnotation(Annotation):
"""
Base tensor annotation for direct definition.
"""

File diff suppressed because it is too large Load Diff

View File

@@ -7,7 +7,7 @@ from pathlib import Path
import numpy as np
from concrete.numpy.compilation import DebugArtifacts, compiler
from concrete.numpy import DebugArtifacts, compiler
def test_artifacts_export(helpers):

View File

@@ -8,8 +8,7 @@ from pathlib import Path
import numpy as np
import pytest
from concrete.numpy import Client, ClientSpecs, EvaluationKeys, Server
from concrete.numpy.compilation import compiler
from concrete.numpy import Client, ClientSpecs, EvaluationKeys, Server, compiler
def test_circuit_str(helpers):

View File

@@ -1,127 +0,0 @@
"""
Tests of `compiler` decorator.
"""
from concrete.numpy.compilation import DebugArtifacts, compiler
def test_call_compile(helpers):
"""
Test `__call__` and `compile` methods of `compiler` decorator back to back.
"""
configuration = helpers.configuration()
@compiler({"x": "encrypted"})
def function(x):
return x + 42
for i in range(10):
function(i)
circuit = function.compile(configuration=configuration)
sample = 5
helpers.check_execution(circuit, function, sample)
def test_compiler_verbose_trace(helpers, capsys):
"""
Test `trace` method of `compiler` decorator with verbose flag.
"""
configuration = helpers.configuration()
artifacts = DebugArtifacts()
@compiler({"x": "encrypted"})
def function(x):
return x + 42
inputset = range(10)
function.trace(inputset, configuration, artifacts, show_graph=True)
captured = capsys.readouterr()
assert captured.out.strip() == (
f"""
Computation Graph
------------------------------------------------
{str(list(artifacts.textual_representations_of_graphs.values())[-1][-1])}
------------------------------------------------
""".strip()
)
def test_compiler_verbose_compile(helpers, capsys):
"""
Test `compile` method of `compiler` decorator with verbose flag.
"""
configuration = helpers.configuration()
artifacts = DebugArtifacts()
@compiler({"x": "encrypted"})
def function(x):
return x + 42
inputset = range(10)
function.compile(inputset, configuration, artifacts, verbose=True)
captured = capsys.readouterr()
assert captured.out.strip().startswith(
f"""
Computation Graph
--------------------------------------------------------------------------------
{list(artifacts.textual_representations_of_graphs.values())[-1][-1]}
--------------------------------------------------------------------------------
MLIR
--------------------------------------------------------------------------------
{artifacts.mlir_to_compile}
--------------------------------------------------------------------------------
Optimizer
--------------------------------------------------------------------------------
""".strip()
)
def test_compiler_verbose_virtual_compile(helpers, capsys):
"""
Test `compile` method of `compiler` decorator with verbose flag.
"""
configuration = helpers.configuration()
artifacts = DebugArtifacts()
@compiler({"x": "encrypted"})
def function(x):
return x + 42
inputset = range(10)
function.compile(inputset, configuration, artifacts, verbose=True, virtual=True)
captured = capsys.readouterr()
assert captured.out.strip() == (
f"""
Computation Graph
------------------------------------------------
{list(artifacts.textual_representations_of_graphs.values())[-1][-1]}
------------------------------------------------
MLIR
------------------------------------------------
Virtual circuits don't have MLIR.
------------------------------------------------
Optimizer
------------------------------------------------
Virtual circuits don't have optimizer output.
------------------------------------------------
""".strip()
)

View File

@@ -0,0 +1,280 @@
"""
Tests of `compiler` and `circuit` decorators.
"""
import numpy as np
import pytest
import concrete.numpy as cnp
def test_compiler_call_and_compile(helpers):
"""
Test `__call__` and `compile` methods of `compiler` decorator back to back.
"""
configuration = helpers.configuration()
@cnp.compiler({"x": "encrypted"})
def function(x):
return x + 42
for i in range(10):
function(i)
circuit = function.compile(configuration=configuration)
sample = 5
helpers.check_execution(circuit, function, sample)
def test_compiler_verbose_trace(helpers, capsys):
"""
Test `trace` method of `compiler` decorator with verbose flag.
"""
configuration = helpers.configuration()
artifacts = cnp.DebugArtifacts()
@cnp.compiler({"x": "encrypted"})
def function(x):
return x + 42
inputset = range(10)
function.trace(inputset, configuration, artifacts, show_graph=True)
captured = capsys.readouterr()
assert captured.out.strip() == (
f"""
Computation Graph
------------------------------------------------
{str(list(artifacts.textual_representations_of_graphs.values())[-1][-1])}
------------------------------------------------
""".strip()
)
def test_compiler_verbose_compile(helpers, capsys):
"""
Test `compile` method of `compiler` decorator with verbose flag.
"""
configuration = helpers.configuration()
artifacts = cnp.DebugArtifacts()
@cnp.compiler({"x": "encrypted"})
def function(x):
return x + 42
inputset = range(10)
function.compile(inputset, configuration, artifacts, verbose=True)
captured = capsys.readouterr()
assert captured.out.strip().startswith(
f"""
Computation Graph
--------------------------------------------------------------------------------
{list(artifacts.textual_representations_of_graphs.values())[-1][-1]}
--------------------------------------------------------------------------------
MLIR
--------------------------------------------------------------------------------
{artifacts.mlir_to_compile}
--------------------------------------------------------------------------------
Optimizer
--------------------------------------------------------------------------------
""".strip()
)
def test_compiler_verbose_virtual_compile(helpers, capsys):
"""
Test `compile` method of `compiler` decorator with verbose flag.
"""
configuration = helpers.configuration()
artifacts = cnp.DebugArtifacts()
@cnp.compiler({"x": "encrypted"})
def function(x):
return x + 42
inputset = range(10)
function.compile(inputset, configuration, artifacts, verbose=True, virtual=True)
captured = capsys.readouterr()
assert captured.out.strip() == (
f"""
Computation Graph
------------------------------------------------
{list(artifacts.textual_representations_of_graphs.values())[-1][-1]}
------------------------------------------------
MLIR
------------------------------------------------
Virtual circuits don't have MLIR.
------------------------------------------------
Optimizer
------------------------------------------------
Virtual circuits don't have optimizer output.
------------------------------------------------
""".strip()
)
def test_circuit(helpers):
"""
Test circuit decorator.
"""
@cnp.circuit({"x": "encrypted"}, helpers.configuration())
def circuit1(x: cnp.uint2):
return x + 42
helpers.check_str(
str(circuit1),
"""
%0 = x # EncryptedScalar<uint2>
%1 = 42 # ClearScalar<uint6>
%2 = add(%0, %1) # EncryptedScalar<uint6>
return %2
""".strip(),
)
# ======================================================================
@cnp.circuit({"x": "encrypted"}, helpers.configuration())
def circuit2(x: cnp.tensor[cnp.uint2, 3, 2]):
return x + 42
helpers.check_str(
str(circuit2),
"""
%0 = x # EncryptedTensor<uint2, shape=(3, 2)>
%1 = 42 # ClearScalar<uint6>
%2 = add(%0, %1) # EncryptedTensor<uint6, shape=(3, 2)>
return %2
""".strip(),
)
# ======================================================================
@cnp.circuit({"x": "encrypted"}, helpers.configuration())
def circuit3(x: cnp.uint3):
def square(x):
return x**2
return cnp.univariate(square, outputs=cnp.uint7)(x)
helpers.check_str(
str(circuit3),
"""
%0 = x # EncryptedScalar<uint3>
%1 = square(%0) # EncryptedScalar<uint7>
return %1
""".strip(),
)
# ======================================================================
@cnp.circuit({"x": "encrypted"}, helpers.configuration())
def circuit4(x: cnp.uint3):
return ((np.sin(x) ** 2) + (np.cos(x) ** 2)).astype(cnp.uint3)
helpers.check_str(
str(circuit4),
"""
%0 = x # EncryptedScalar<uint3>
%1 = subgraph(%0) # EncryptedScalar<uint3>
return %1
Subgraphs:
%1 = subgraph(%0):
%0 = 2 # ClearScalar<uint2>
%1 = 2 # ClearScalar<uint2>
%2 = input # EncryptedScalar<uint3>
%3 = sin(%2) # EncryptedScalar<float64>
%4 = cos(%2) # EncryptedScalar<float64>
%5 = power(%3, %0) # EncryptedScalar<float64>
%6 = power(%4, %1) # EncryptedScalar<float64>
%7 = add(%5, %6) # EncryptedScalar<float64>
%8 = astype(%7) # EncryptedScalar<uint3>
return %8
""".strip(),
)
def test_bad_circuit(helpers):
"""
Test circuit decorator with bad parameters.
"""
# bad annotation
# --------------
with pytest.raises(ValueError) as excinfo:
@cnp.circuit({"x": "encrypted"}, helpers.configuration())
def circuit1(x: int):
return x + 42
assert str(excinfo.value) == (
f"Annotation {str(int)} for argument 'x' is not valid "
f"(please use a cnp type such as `cnp.uint4` or 'cnp.tensor[cnp.uint4, 3, 2]')"
)
# missing encryption status
# -------------------------
with pytest.raises(ValueError) as excinfo:
@cnp.circuit({}, helpers.configuration())
def circuit2(x: cnp.uint3):
return x + 42
assert str(excinfo.value) == (
"Encryption status of parameter 'x' of function 'circuit2' is not provided"
)
# bad astype
# ----------
with pytest.raises(ValueError) as excinfo:
@cnp.circuit({"x": "encrypted"}, helpers.configuration())
def circuit3(x: cnp.uint3):
return x.astype(np.int64)
assert str(excinfo.value) == (
"`astype` method must be called with a concrete.numpy type "
"for direct circuit definition (e.g., value.astype(cnp.uint4))"
)
# round
# -----
with pytest.raises(RuntimeError) as excinfo:
@cnp.circuit({"x": "encrypted"}, helpers.configuration())
def circuit4(x: cnp.uint3):
return round(x)
assert str(excinfo.value) == (
"'round(x)' cannot be used in direct definition (you may use np.around instead)"
)

View File

@@ -1,5 +1,5 @@
"""
Tests of LookupTable.
Tests of 'array' extension.
"""
import pytest

View File

@@ -1,5 +1,5 @@
"""
Tests of LookupTable.
Tests of 'LookupTable' extension.
"""
import pytest

View File

@@ -0,0 +1,24 @@
"""
Tests of 'univariate' extension.
"""
import pytest
import concrete.numpy as cnp
def test_bad_univariate(helpers):
"""
Test 'univariate' extension with bad parameters.
"""
with pytest.raises(ValueError) as excinfo:
@cnp.circuit({"x": "encrypted"}, helpers.configuration())
def function(x: cnp.uint3):
return cnp.univariate(lambda x: x**2)(x)
assert str(excinfo.value) == (
"Univariate extension requires `outputs` argument for direct circuit definition "
"(e.g., cnp.univariate(function, outputs=cnp.uint4)(x))"
)

View File

@@ -7,6 +7,7 @@ import pytest
from concrete.numpy.dtypes import UnsignedInteger
from concrete.numpy.tracing import Tracer
from concrete.numpy.tracing.typing import uint4
from concrete.numpy.values import EncryptedTensor
@@ -43,6 +44,13 @@ from concrete.numpy.values import EncryptedTensor
RuntimeError,
"Only __call__ hook is supported for numpy ufuncs",
),
pytest.param(
lambda x: x.astype(uint4),
{"x": EncryptedTensor(UnsignedInteger(7), shape=(4,))},
ValueError,
"`astype` method must be called with a "
"numpy type for compilation (e.g., value.astype(np.int64))",
),
],
)
def test_tracer_bad_trace(function, parameters, expected_error, expected_message):

View File

@@ -0,0 +1,54 @@
"""
Test type annotations.
"""
import pytest
import concrete.numpy as cnp
def test_bad_tensor():
"""
Test `tensor` type with bad parameters
"""
# invalid dtype
# -------------
with pytest.raises(ValueError) as excinfo:
def case1(x: cnp.tensor[int]):
return x
case1(None)
assert str(excinfo.value) == (
"First argument to tensor annotations should be a "
"concrete-numpy data type (e.g., cnp.uint4) not int"
)
# no shape
# --------
with pytest.raises(ValueError) as excinfo:
def case2(x: cnp.tensor[cnp.uint3]):
return x
case2(None)
assert str(excinfo.value) == (
"Tensor annotations should have a shape (e.g., cnp.tensor[cnp.uint4, 3, 2])"
)
# bad shape
# ---------
with pytest.raises(ValueError) as excinfo:
def case3(x: cnp.tensor[cnp.uint3, 1.5]):
return x
case3(None)
assert str(excinfo.value) == "Tensor annotation shape elements must be 'int'"