feat: create array extension

This commit is contained in:
Umut
2022-07-12 10:22:04 +02:00
parent 5bc0ff42e1
commit 078512d55d
10 changed files with 379 additions and 83 deletions

View File

@@ -15,6 +15,6 @@ from .compilation import (
Server,
compiler,
)
from .extensions import LookupTable, one, ones, univariate, zero, zeros
from .extensions import LookupTable, array, one, ones, univariate, zero, zeros
from .mlir.utils import MAXIMUM_BIT_WIDTH
from .representation import Graph

View File

@@ -2,6 +2,7 @@
Provide additional features that are not present in numpy.
"""
from .array import array
from .ones import one, ones
from .table import LookupTable
from .univariate import univariate

View File

@@ -0,0 +1,58 @@
"""
Declaration of `array` function, to simplify creation of encrypted arrays.
"""
from typing import Any, Union
import numpy as np
from ..dtypes.utils import combine_dtypes
from ..representation import Node
from ..tracing import Tracer
from ..values import Value
def array(values: Any) -> Union[np.ndarray, Tracer]:
"""
Create an encrypted array from either encrypted or clear values.
Args:
values (Any):
array like object compatible with numpy to construct the resulting encrypted array
Returns:
Union[np.ndarray, Tracer]:
Tracer that respresents the operation during tracing
ndarray with values otherwise
"""
# pylint: disable=protected-access
is_tracing = Tracer._is_tracing
# pylint: enable=protected-access
if not isinstance(values, np.ndarray):
values = np.array(values)
if not is_tracing:
return values
shape = values.shape
values = values.flatten()
for i, value in enumerate(values):
if not isinstance(value, Tracer):
values[i] = Tracer.sanitize(value)
if not values[i].output.is_scalar:
raise ValueError("Encrypted arrays can only be created from scalars")
dtype = combine_dtypes([value.output.dtype for value in values])
is_encrypted = True
computation = Node.generic(
"array",
[value.output for value in values],
Value(dtype, shape, is_encrypted),
lambda *args: np.array(args).reshape(shape),
)
return Tracer(computation, values)

View File

@@ -83,6 +83,10 @@ class GraphConverter:
if name == "add":
assert_that(len(inputs) == 2)
elif name == "array":
assert_that(len(inputs) > 0)
assert_that(all(input.is_scalar for input in inputs))
elif name == "concatenate":
if not all(input.is_encrypted for input in inputs):
return "only all encrypted concatenate is supported"
@@ -416,6 +420,9 @@ class GraphConverter:
# { "%0": ["%c1_i5"] } == for %0 we need to convert %c1_i5 to 1d tensor
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]] = {}
# { "%0": "tensor.from_elements ..." } == we need to convert the part after "=" for %0
direct_replacements: Dict[str, str] = {}
with Context() as ctx, Location.unknown():
concretelang.register_dialects(ctx)
@@ -455,6 +462,7 @@ class GraphConverter:
nodes_to_mlir_names,
mlir_names_to_mlir_types,
scalar_to_1d_tensor_conversion_hacks,
direct_replacements,
)
ir_to_mlir[node] = node_converter.convert()
@@ -464,6 +472,12 @@ class GraphConverter:
module_lines_after_hacks_are_applied = []
for line in str(module).split("\n"):
mlir_name = line.split("=")[0].strip()
if mlir_name in direct_replacements:
new_value = direct_replacements[mlir_name]
module_lines_after_hacks_are_applied.append(f" {mlir_name} = {new_value}")
continue
if mlir_name not in scalar_to_1d_tensor_conversion_hacks:
module_lines_after_hacks_are_applied.append(line)
continue

View File

@@ -2,8 +2,9 @@
Declaration of `NodeConverter` class.
"""
# pylint: disable=no-member,no-name-in-module
# pylint: disable=no-member,no-name-in-module,too-many-lines
from copy import deepcopy
from typing import Dict, List, Tuple
import numpy as np
@@ -13,6 +14,7 @@ from mlir.dialects import arith, linalg, tensor
from mlir.ir import (
ArrayAttr,
Attribute,
BlockArgument,
BoolAttr,
Context,
DenseElementsAttr,
@@ -24,10 +26,10 @@ from mlir.ir import (
Type,
)
from ..dtypes import Integer
from ..dtypes import Integer, UnsignedInteger
from ..internal.utils import assert_that
from ..representation import Graph, Node, Operation
from ..values import Value
from ..values import EncryptedScalar, Value
from .utils import construct_deduplicated_tables
# pylint: enable=no-member,no-name-in-module
@@ -38,6 +40,8 @@ class NodeConverter:
NodeConverter class, to convert computation graph nodes to their MLIR equivalent.
"""
# pylint: disable=too-many-instance-attributes
ctx: Context
graph: Graph
node: Node
@@ -50,6 +54,9 @@ class NodeConverter:
nodes_to_mlir_names: Dict[Node, str]
mlir_names_to_mlir_types: Dict[str, str]
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]]
direct_replacements: Dict[str, str]
# pylint: enable=too-many-instance-attributes
@staticmethod
def value_to_mlir_type(ctx: Context, value: Value) -> Type:
@@ -83,6 +90,25 @@ class NodeConverter:
raise ValueError(f"{value} cannot be converted to MLIR") # pragma: no cover
@staticmethod
def mlir_name(result: OpResult) -> str:
"""
Extract the MLIR variable name of an `OpResult`.
Args:
result (OpResult):
op result to extract the name
Returns:
str:
MLIR variable name of `result`
"""
if isinstance(result, BlockArgument):
return f"%arg{result.arg_number}"
return str(result).replace("Value(", "").split("=", maxsplit=1)[0].strip()
def __init__(
self,
ctx: Context,
@@ -92,6 +118,7 @@ class NodeConverter:
nodes_to_mlir_names: Dict[OpResult, str],
mlir_names_to_mlir_types: Dict[str, str],
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]],
direct_replacements: Dict[str, str],
):
self.ctx = ctx
self.graph = graph
@@ -114,6 +141,7 @@ class NodeConverter:
self.nodes_to_mlir_names = nodes_to_mlir_names
self.mlir_names_to_mlir_types = mlir_names_to_mlir_types
self.scalar_to_1d_tensor_conversion_hacks = scalar_to_1d_tensor_conversion_hacks
self.direct_replacements = direct_replacements
def convert(self) -> OpResult:
"""
@@ -127,65 +155,68 @@ class NodeConverter:
# pylint: disable=too-many-branches,too-many-statements
if self.node.operation == Operation.Constant:
result = self.convert_constant()
result = self._convert_constant()
else:
assert_that(self.node.operation == Operation.Generic)
name = self.node.properties["name"]
if name == "add":
result = self.convert_add()
result = self._convert_add()
elif name == "array":
result = self._convert_array()
elif name == "concatenate":
result = self.convert_concat()
result = self._convert_concat()
elif name == "conv1d":
result = self.convert_conv1d()
result = self._convert_conv1d()
elif name == "conv2d":
result = self.convert_conv2d()
result = self._convert_conv2d()
elif name == "conv3d":
result = self.convert_conv3d()
result = self._convert_conv3d()
elif name == "dot":
result = self.convert_dot()
result = self._convert_dot()
elif name == "index.static":
result = self.convert_static_indexing()
result = self._convert_static_indexing()
elif name == "matmul":
result = self.convert_matmul()
result = self._convert_matmul()
elif name == "multiply":
result = self.convert_mul()
result = self._convert_mul()
elif name == "negative":
result = self.convert_neg()
result = self._convert_neg()
elif name == "ones":
result = self.convert_ones()
result = self._convert_ones()
elif name == "reshape":
result = self.convert_reshape()
result = self._convert_reshape()
elif name == "subtract":
result = self.convert_sub()
result = self._convert_sub()
elif name == "sum":
result = self.convert_sum()
result = self._convert_sum()
elif name == "transpose":
result = self.convert_transpose()
result = self._convert_transpose()
elif name == "zeros":
result = self.convert_zeros()
result = self._convert_zeros()
else:
assert_that(self.node.converted_to_table_lookup)
result = self.convert_tlu()
result = self._convert_tlu()
mlir_name = str(result).replace("Value(", "").split("=", maxsplit=1)[0].strip()
mlir_name = NodeConverter.mlir_name(result)
self.nodes_to_mlir_names[self.node] = mlir_name
self.mlir_names_to_mlir_types[mlir_name] = str(result.type)
@@ -204,7 +235,7 @@ class NodeConverter:
# pylint: enable=too-many-branches
def convert_add(self) -> OpResult:
def _convert_add(self) -> OpResult:
"""
Convert "add" node to its corresponding MLIR representation.
@@ -232,7 +263,67 @@ class NodeConverter:
return result
def convert_concat(self) -> OpResult:
def _convert_array(self) -> OpResult:
"""
Convert "array" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
preds = self.preds
number_of_values = len(preds)
intermediate_value = deepcopy(self.node.output)
intermediate_value.shape = (number_of_values,)
intermediate_type = NodeConverter.value_to_mlir_type(self.ctx, intermediate_value)
pred_names = []
for pred, value in zip(preds, self.node.inputs):
if value.is_encrypted:
pred_names.append(NodeConverter.mlir_name(pred))
continue
assert isinstance(value.dtype, Integer)
zero_value = EncryptedScalar(UnsignedInteger(value.dtype.bit_width - 1))
zero_type = NodeConverter.value_to_mlir_type(self.ctx, zero_value)
zero = fhe.ZeroEintOp(zero_type).result
encrypted_pred = fhe.AddEintIntOp(zero_type, zero, pred).result
pred_names.append(NodeConverter.mlir_name(encrypted_pred))
# `placeholder_result` will be replaced textually by `actual_value` below in graph converter
# `tensor.from_elements` cannot be created from python bindings
# that's why we use placeholder values and text manipulation
placeholder_result = fhe.ZeroTensorOp(intermediate_type).result
placeholder_result_name = NodeConverter.mlir_name(placeholder_result)
actual_value = f"tensor.from_elements {', '.join(pred_names)} : {intermediate_type}"
self.direct_replacements[placeholder_result_name] = actual_value
if self.node.output.shape == (number_of_values,):
return placeholder_result
index_type = IndexType.parse("index")
return linalg.TensorExpandShapeOp(
resulting_type,
placeholder_result,
ArrayAttr.get(
[
ArrayAttr.get(
[IntegerAttr.get(index_type, i) for i in range(len(self.node.output.shape))]
)
]
),
).result
def _convert_concat(self) -> OpResult:
"""
Convert "concatenate" node to its corresponding MLIR representation.
@@ -287,7 +378,7 @@ class NodeConverter:
IntegerAttr.get(IntegerType.get_signless(64), 0),
).result
def convert_constant(self) -> OpResult:
def _convert_constant(self) -> OpResult:
"""
Convert Operation.Constant node to its corresponding MLIR representation.
@@ -315,7 +406,7 @@ class NodeConverter:
return arith.ConstantOp(resulting_type, attr).result
def convert_conv1d(self) -> OpResult:
def _convert_conv1d(self) -> OpResult:
"""
Convert "conv1d" node to its corresponding MLIR representation.
@@ -326,7 +417,7 @@ class NodeConverter:
raise NotImplementedError("conv1d conversion to MLIR is not yet implemented")
def convert_conv2d(self) -> OpResult:
def _convert_conv2d(self) -> OpResult:
"""
Convert "conv2d" node to its corresponding MLIR representation.
@@ -362,7 +453,7 @@ class NodeConverter:
return fhelinalg.Conv2dOp(resulting_type, *preds, pads, strides, dilations).result
def convert_conv3d(self) -> OpResult:
def _convert_conv3d(self) -> OpResult:
"""
Convert "conv3d" node to its corresponding MLIR representation.
@@ -373,7 +464,7 @@ class NodeConverter:
raise NotImplementedError("conv3d conversion to MLIR is not yet implemented")
def convert_dot(self) -> OpResult:
def _convert_dot(self) -> OpResult:
"""
Convert "dot" node to its corresponding MLIR representation.
@@ -402,7 +493,7 @@ class NodeConverter:
return result
def convert_matmul(self) -> OpResult:
def _convert_matmul(self) -> OpResult:
"""Convert a MatMul node to its corresponding MLIR representation.
Returns:
@@ -424,7 +515,7 @@ class NodeConverter:
return result
def convert_mul(self) -> OpResult:
def _convert_mul(self) -> OpResult:
"""
Convert "multiply" node to its corresponding MLIR representation.
@@ -446,7 +537,7 @@ class NodeConverter:
return result
def convert_neg(self) -> OpResult:
def _convert_neg(self) -> OpResult:
"""
Convert "negative" node to its corresponding MLIR representation.
@@ -465,7 +556,7 @@ class NodeConverter:
return result
def convert_ones(self) -> OpResult:
def _convert_ones(self) -> OpResult:
"""
Convert "ones" node to its corresponding MLIR representation.
@@ -508,7 +599,7 @@ class NodeConverter:
return result
def convert_reshape(self) -> OpResult:
def _convert_reshape(self) -> OpResult:
"""
Convert "reshape" node to its corresponding MLIR representation.
@@ -627,7 +718,7 @@ class NodeConverter:
),
).result
def convert_static_indexing(self) -> OpResult:
def _convert_static_indexing(self) -> OpResult:
"""
Convert "index.static" node to its corresponding MLIR representation.
@@ -749,7 +840,7 @@ class NodeConverter:
),
).result
def convert_sub(self) -> OpResult:
def _convert_sub(self) -> OpResult:
"""
Convert "subtract" node to its corresponding MLIR representation.
@@ -768,7 +859,7 @@ class NodeConverter:
return result
def convert_sum(self) -> OpResult:
def _convert_sum(self) -> OpResult:
"""
Convert "sum" node to its corresponding MLIR representation.
@@ -799,7 +890,7 @@ class NodeConverter:
BoolAttr.get(keep_dims),
).result
def convert_tlu(self) -> OpResult:
def _convert_tlu(self) -> OpResult:
"""
Convert Operation.Generic node to its corresponding MLIR representation.
@@ -880,7 +971,7 @@ class NodeConverter:
return result
def convert_transpose(self) -> OpResult:
def _convert_transpose(self) -> OpResult:
"""
Convert "transpose" node to its corresponding MLIR representation.
@@ -894,7 +985,7 @@ class NodeConverter:
return fhelinalg.TransposeOp(resulting_type, *preds).result
def convert_zeros(self) -> OpResult:
def _convert_zeros(self) -> OpResult:
"""
Convert "zeros" node to its corresponding MLIR representation.

View File

@@ -257,6 +257,12 @@ class Node:
else:
args = deepcopy(predecessors)
if name == "array":
values = str(np.array(predecessors).reshape(self.output.shape).tolist()).replace(
"'", ""
)
return f"array({format_constant(values, maximum_constant_length)})"
args.extend(
format_constant(value, maximum_constant_length) for value in self.properties["args"]
)
@@ -300,6 +306,7 @@ class Node:
return self.operation == Operation.Generic and self.properties["name"] not in [
"add",
"array",
"concatenate",
"conv1d",
"conv2d",

View File

@@ -78,7 +78,7 @@ class Tracer:
continue
try:
sanitized_tracers.append(Tracer._sanitize(tracer))
sanitized_tracers.append(Tracer.sanitize(tracer))
except Exception as error:
raise ValueError(
f"Function '{function.__name__}' "
@@ -149,9 +149,21 @@ class Tracer:
return id(self)
@staticmethod
def _sanitize(value: Any) -> Any:
def sanitize(value: Any) -> Any:
"""
Try to create a tracer from a value.
Args:
value (Any):
value to use
Returns:
Any:
resulting tracer
"""
if isinstance(value, tuple):
return tuple(Tracer._sanitize(item) for item in value)
return tuple(Tracer.sanitize(item) for item in value)
if isinstance(value, Tracer):
return value
@@ -372,7 +384,7 @@ class Tracer:
"""
if method == "__call__":
sanitized_args = [self._sanitize(arg) for arg in args]
sanitized_args = [self.sanitize(arg) for arg in args]
return Tracer._trace_numpy_operation(ufunc, *sanitized_args, **kwargs)
raise RuntimeError("Only __call__ hook is supported for numpy ufuncs")
@@ -385,65 +397,65 @@ class Tracer:
"""
if func is np.reshape:
sanitized_args = [self._sanitize(args[0])]
sanitized_args = [self.sanitize(args[0])]
if len(args) > 1:
kwargs["newshape"] = args[1]
elif func is np.transpose:
sanitized_args = [self._sanitize(args[0])]
sanitized_args = [self.sanitize(args[0])]
if len(args) > 1:
kwargs["axes"] = args[1]
else:
sanitized_args = [self._sanitize(arg) for arg in args]
sanitized_args = [self.sanitize(arg) for arg in args]
return Tracer._trace_numpy_operation(func, *sanitized_args, **kwargs)
def __add__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.add, self, self._sanitize(other))
return Tracer._trace_numpy_operation(np.add, self, self.sanitize(other))
def __radd__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.add, self._sanitize(other), self)
return Tracer._trace_numpy_operation(np.add, self.sanitize(other), self)
def __sub__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.subtract, self, self._sanitize(other))
return Tracer._trace_numpy_operation(np.subtract, self, self.sanitize(other))
def __rsub__(self, other) -> "Tracer":
return Tracer._trace_numpy_operation(np.subtract, self._sanitize(other), self)
return Tracer._trace_numpy_operation(np.subtract, self.sanitize(other), self)
def __mul__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.multiply, self, self._sanitize(other))
return Tracer._trace_numpy_operation(np.multiply, self, self.sanitize(other))
def __rmul__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.multiply, self._sanitize(other), self)
return Tracer._trace_numpy_operation(np.multiply, self.sanitize(other), self)
def __truediv__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.true_divide, self, self._sanitize(other))
return Tracer._trace_numpy_operation(np.true_divide, self, self.sanitize(other))
def __rtruediv__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.true_divide, self._sanitize(other), self)
return Tracer._trace_numpy_operation(np.true_divide, self.sanitize(other), self)
def __floordiv__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.floor_divide, self, self._sanitize(other))
return Tracer._trace_numpy_operation(np.floor_divide, self, self.sanitize(other))
def __rfloordiv__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.floor_divide, self._sanitize(other), self)
return Tracer._trace_numpy_operation(np.floor_divide, self.sanitize(other), self)
def __pow__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.power, self, self._sanitize(other))
return Tracer._trace_numpy_operation(np.power, self, self.sanitize(other))
def __rpow__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.power, self._sanitize(other), self)
return Tracer._trace_numpy_operation(np.power, self.sanitize(other), self)
def __mod__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.mod, self, self._sanitize(other))
return Tracer._trace_numpy_operation(np.mod, self, self.sanitize(other))
def __rmod__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.mod, self._sanitize(other), self)
return Tracer._trace_numpy_operation(np.mod, self.sanitize(other), self)
def __matmul__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.matmul, self, self._sanitize(other))
return Tracer._trace_numpy_operation(np.matmul, self, self.sanitize(other))
def __rmatmul__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.matmul, self._sanitize(other), self)
return Tracer._trace_numpy_operation(np.matmul, self.sanitize(other), self)
def __neg__(self) -> "Tracer":
return Tracer._trace_numpy_operation(np.negative, self)
@@ -464,59 +476,59 @@ class Tracer:
return Tracer._trace_numpy_operation(np.invert, self)
def __and__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.bitwise_and, self, self._sanitize(other))
return Tracer._trace_numpy_operation(np.bitwise_and, self, self.sanitize(other))
def __rand__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.bitwise_and, self._sanitize(other), self)
return Tracer._trace_numpy_operation(np.bitwise_and, self.sanitize(other), self)
def __or__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.bitwise_or, self, self._sanitize(other))
return Tracer._trace_numpy_operation(np.bitwise_or, self, self.sanitize(other))
def __ror__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.bitwise_or, self._sanitize(other), self)
return Tracer._trace_numpy_operation(np.bitwise_or, self.sanitize(other), self)
def __xor__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.bitwise_xor, self, self._sanitize(other))
return Tracer._trace_numpy_operation(np.bitwise_xor, self, self.sanitize(other))
def __rxor__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.bitwise_xor, self._sanitize(other), self)
return Tracer._trace_numpy_operation(np.bitwise_xor, self.sanitize(other), self)
def __lshift__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.left_shift, self, self._sanitize(other))
return Tracer._trace_numpy_operation(np.left_shift, self, self.sanitize(other))
def __rlshift__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.left_shift, self._sanitize(other), self)
return Tracer._trace_numpy_operation(np.left_shift, self.sanitize(other), self)
def __rshift__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.right_shift, self, self._sanitize(other))
return Tracer._trace_numpy_operation(np.right_shift, self, self.sanitize(other))
def __rrshift__(self, other: Any) -> "Tracer":
return Tracer._trace_numpy_operation(np.right_shift, self._sanitize(other), self)
return Tracer._trace_numpy_operation(np.right_shift, self.sanitize(other), self)
def __gt__(self, other: Any) -> "Tracer": # type: ignore
return Tracer._trace_numpy_operation(np.greater, self, self._sanitize(other))
return Tracer._trace_numpy_operation(np.greater, self, self.sanitize(other))
def __ge__(self, other: Any) -> "Tracer": # type: ignore
return Tracer._trace_numpy_operation(np.greater_equal, self, self._sanitize(other))
return Tracer._trace_numpy_operation(np.greater_equal, self, self.sanitize(other))
def __lt__(self, other: Any) -> "Tracer": # type: ignore
return Tracer._trace_numpy_operation(np.less, self, self._sanitize(other))
return Tracer._trace_numpy_operation(np.less, self, self.sanitize(other))
def __le__(self, other: Any) -> "Tracer": # type: ignore
return Tracer._trace_numpy_operation(np.less_equal, self, self._sanitize(other))
return Tracer._trace_numpy_operation(np.less_equal, self, self.sanitize(other))
def __eq__(self, other: Any) -> Union[bool, "Tracer"]: # type: ignore
return (
self is other
if not self._is_tracing
else Tracer._trace_numpy_operation(np.equal, self, self._sanitize(other))
else Tracer._trace_numpy_operation(np.equal, self, self.sanitize(other))
)
def __ne__(self, other: Any) -> Union[bool, "Tracer"]: # type: ignore
return (
self is not other
if not self._is_tracing
else Tracer._trace_numpy_operation(np.not_equal, self, self._sanitize(other))
else Tracer._trace_numpy_operation(np.not_equal, self, self.sanitize(other))
)
def astype(self, dtype: DTypeLike) -> "Tracer":
@@ -552,7 +564,7 @@ class Tracer:
"""
return Tracer._trace_numpy_operation(
np.clip, self, self._sanitize(minimum), self._sanitize(maximum)
np.clip, self, self.sanitize(minimum), self.sanitize(maximum)
)
def dot(self, other: Any) -> "Tracer":
@@ -560,7 +572,7 @@ class Tracer:
Trace numpy.ndarray.dot().
"""
return Tracer._trace_numpy_operation(np.dot, self, self._sanitize(other))
return Tracer._trace_numpy_operation(np.dot, self, self.sanitize(other))
def flatten(self) -> "Tracer":
"""

View File

@@ -0,0 +1,61 @@
"""
Tests of execution of array operation.
"""
import pytest
import concrete.numpy as cnp
@pytest.mark.parametrize(
"function,parameters",
[
pytest.param(
lambda x: cnp.array([x, x + 1, 1]),
{
"x": {"range": [0, 10], "status": "encrypted", "shape": ()},
},
id="cnp.array([x, x + 1, 1])",
),
pytest.param(
lambda x, y: cnp.array([x, y]),
{
"x": {"range": [0, 10], "status": "encrypted", "shape": ()},
"y": {"range": [0, 10], "status": "clear", "shape": ()},
},
id="cnp.array([x, y])",
),
pytest.param(
lambda x, y: cnp.array([[x, y], [y, x]]),
{
"x": {"range": [0, 10], "status": "encrypted", "shape": ()},
"y": {"range": [0, 10], "status": "clear", "shape": ()},
},
id="cnp.array([[x, y], [y, x]])",
),
pytest.param(
lambda x, y, z: cnp.array([[x, 1], [y, 2], [z, 3]]),
{
"x": {"range": [0, 10], "status": "encrypted", "shape": ()},
"y": {"range": [0, 10], "status": "clear", "shape": ()},
"z": {"range": [0, 10], "status": "encrypted", "shape": ()},
},
id="cnp.array([[x, 1], [y, 2], [z, 3]])",
),
],
)
def test_array(function, parameters, helpers):
"""
Test array.
"""
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = cnp.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset, configuration)
sample = helpers.generate_sample(parameters)
helpers.check_execution(circuit, function, sample)

View File

@@ -0,0 +1,37 @@
"""
Tests of LookupTable.
"""
import pytest
import concrete.numpy as cnp
@pytest.mark.parametrize(
"function,parameters,expected_error",
[
pytest.param(
lambda x, y: cnp.array([x, y]),
{
"x": {"range": [0, 10], "status": "encrypted", "shape": ()},
"y": {"range": [0, 10], "status": "encrypted", "shape": (2, 3)},
},
"Encrypted arrays can only be created from scalars",
),
],
)
def test_bad_array(function, parameters, expected_error, helpers):
"""
Test array with bad parameters.
"""
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = cnp.Compiler(function, parameter_encryption_statuses)
with pytest.raises(ValueError) as excinfo:
inputset = helpers.generate_inputset(parameters)
compiler.compile(inputset, configuration)
assert str(excinfo.value) == expected_error

View File

@@ -7,7 +7,7 @@ import pytest
from concrete.numpy.dtypes import UnsignedInteger
from concrete.numpy.representation import Node
from concrete.numpy.values import EncryptedScalar, EncryptedTensor, Value
from concrete.numpy.values import ClearScalar, EncryptedScalar, EncryptedTensor, Value
@pytest.mark.parametrize(
@@ -191,6 +191,21 @@ def test_node_bad_call(node, args, expected_error, expected_message):
["%0", "%1", "%2"],
"concatenate((%0, %1, %2), axis=1)",
),
pytest.param(
Node.generic(
name="array",
inputs=[
EncryptedScalar(UnsignedInteger(3)),
ClearScalar(UnsignedInteger(3)),
ClearScalar(UnsignedInteger(3)),
EncryptedScalar(UnsignedInteger(3)),
],
output=EncryptedTensor(UnsignedInteger(3), shape=(2, 2)),
operation=lambda *args: np.array(args).reshape((2, 2)),
),
["%0", "%1", "%2", "%3"],
"array([[%0, %1], [%2, %3]])",
),
],
)
def test_node_format(node, predecessors, expected_result):