mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: create array extension
This commit is contained in:
@@ -15,6 +15,6 @@ from .compilation import (
|
||||
Server,
|
||||
compiler,
|
||||
)
|
||||
from .extensions import LookupTable, one, ones, univariate, zero, zeros
|
||||
from .extensions import LookupTable, array, one, ones, univariate, zero, zeros
|
||||
from .mlir.utils import MAXIMUM_BIT_WIDTH
|
||||
from .representation import Graph
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
Provide additional features that are not present in numpy.
|
||||
"""
|
||||
|
||||
from .array import array
|
||||
from .ones import one, ones
|
||||
from .table import LookupTable
|
||||
from .univariate import univariate
|
||||
|
||||
58
concrete/numpy/extensions/array.py
Normal file
58
concrete/numpy/extensions/array.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
Declaration of `array` function, to simplify creation of encrypted arrays.
|
||||
"""
|
||||
|
||||
from typing import Any, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..dtypes.utils import combine_dtypes
|
||||
from ..representation import Node
|
||||
from ..tracing import Tracer
|
||||
from ..values import Value
|
||||
|
||||
|
||||
def array(values: Any) -> Union[np.ndarray, Tracer]:
|
||||
"""
|
||||
Create an encrypted array from either encrypted or clear values.
|
||||
|
||||
Args:
|
||||
values (Any):
|
||||
array like object compatible with numpy to construct the resulting encrypted array
|
||||
|
||||
Returns:
|
||||
Union[np.ndarray, Tracer]:
|
||||
Tracer that respresents the operation during tracing
|
||||
ndarray with values otherwise
|
||||
"""
|
||||
|
||||
# pylint: disable=protected-access
|
||||
is_tracing = Tracer._is_tracing
|
||||
# pylint: enable=protected-access
|
||||
|
||||
if not isinstance(values, np.ndarray):
|
||||
values = np.array(values)
|
||||
|
||||
if not is_tracing:
|
||||
return values
|
||||
|
||||
shape = values.shape
|
||||
values = values.flatten()
|
||||
|
||||
for i, value in enumerate(values):
|
||||
if not isinstance(value, Tracer):
|
||||
values[i] = Tracer.sanitize(value)
|
||||
|
||||
if not values[i].output.is_scalar:
|
||||
raise ValueError("Encrypted arrays can only be created from scalars")
|
||||
|
||||
dtype = combine_dtypes([value.output.dtype for value in values])
|
||||
is_encrypted = True
|
||||
|
||||
computation = Node.generic(
|
||||
"array",
|
||||
[value.output for value in values],
|
||||
Value(dtype, shape, is_encrypted),
|
||||
lambda *args: np.array(args).reshape(shape),
|
||||
)
|
||||
return Tracer(computation, values)
|
||||
@@ -83,6 +83,10 @@ class GraphConverter:
|
||||
if name == "add":
|
||||
assert_that(len(inputs) == 2)
|
||||
|
||||
elif name == "array":
|
||||
assert_that(len(inputs) > 0)
|
||||
assert_that(all(input.is_scalar for input in inputs))
|
||||
|
||||
elif name == "concatenate":
|
||||
if not all(input.is_encrypted for input in inputs):
|
||||
return "only all encrypted concatenate is supported"
|
||||
@@ -416,6 +420,9 @@ class GraphConverter:
|
||||
# { "%0": ["%c1_i5"] } == for %0 we need to convert %c1_i5 to 1d tensor
|
||||
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]] = {}
|
||||
|
||||
# { "%0": "tensor.from_elements ..." } == we need to convert the part after "=" for %0
|
||||
direct_replacements: Dict[str, str] = {}
|
||||
|
||||
with Context() as ctx, Location.unknown():
|
||||
concretelang.register_dialects(ctx)
|
||||
|
||||
@@ -455,6 +462,7 @@ class GraphConverter:
|
||||
nodes_to_mlir_names,
|
||||
mlir_names_to_mlir_types,
|
||||
scalar_to_1d_tensor_conversion_hacks,
|
||||
direct_replacements,
|
||||
)
|
||||
ir_to_mlir[node] = node_converter.convert()
|
||||
|
||||
@@ -464,6 +472,12 @@ class GraphConverter:
|
||||
module_lines_after_hacks_are_applied = []
|
||||
for line in str(module).split("\n"):
|
||||
mlir_name = line.split("=")[0].strip()
|
||||
|
||||
if mlir_name in direct_replacements:
|
||||
new_value = direct_replacements[mlir_name]
|
||||
module_lines_after_hacks_are_applied.append(f" {mlir_name} = {new_value}")
|
||||
continue
|
||||
|
||||
if mlir_name not in scalar_to_1d_tensor_conversion_hacks:
|
||||
module_lines_after_hacks_are_applied.append(line)
|
||||
continue
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
Declaration of `NodeConverter` class.
|
||||
"""
|
||||
|
||||
# pylint: disable=no-member,no-name-in-module
|
||||
# pylint: disable=no-member,no-name-in-module,too-many-lines
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
@@ -13,6 +14,7 @@ from mlir.dialects import arith, linalg, tensor
|
||||
from mlir.ir import (
|
||||
ArrayAttr,
|
||||
Attribute,
|
||||
BlockArgument,
|
||||
BoolAttr,
|
||||
Context,
|
||||
DenseElementsAttr,
|
||||
@@ -24,10 +26,10 @@ from mlir.ir import (
|
||||
Type,
|
||||
)
|
||||
|
||||
from ..dtypes import Integer
|
||||
from ..dtypes import Integer, UnsignedInteger
|
||||
from ..internal.utils import assert_that
|
||||
from ..representation import Graph, Node, Operation
|
||||
from ..values import Value
|
||||
from ..values import EncryptedScalar, Value
|
||||
from .utils import construct_deduplicated_tables
|
||||
|
||||
# pylint: enable=no-member,no-name-in-module
|
||||
@@ -38,6 +40,8 @@ class NodeConverter:
|
||||
NodeConverter class, to convert computation graph nodes to their MLIR equivalent.
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
|
||||
ctx: Context
|
||||
graph: Graph
|
||||
node: Node
|
||||
@@ -50,6 +54,9 @@ class NodeConverter:
|
||||
nodes_to_mlir_names: Dict[Node, str]
|
||||
mlir_names_to_mlir_types: Dict[str, str]
|
||||
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]]
|
||||
direct_replacements: Dict[str, str]
|
||||
|
||||
# pylint: enable=too-many-instance-attributes
|
||||
|
||||
@staticmethod
|
||||
def value_to_mlir_type(ctx: Context, value: Value) -> Type:
|
||||
@@ -83,6 +90,25 @@ class NodeConverter:
|
||||
|
||||
raise ValueError(f"{value} cannot be converted to MLIR") # pragma: no cover
|
||||
|
||||
@staticmethod
|
||||
def mlir_name(result: OpResult) -> str:
|
||||
"""
|
||||
Extract the MLIR variable name of an `OpResult`.
|
||||
|
||||
Args:
|
||||
result (OpResult):
|
||||
op result to extract the name
|
||||
|
||||
Returns:
|
||||
str:
|
||||
MLIR variable name of `result`
|
||||
"""
|
||||
|
||||
if isinstance(result, BlockArgument):
|
||||
return f"%arg{result.arg_number}"
|
||||
|
||||
return str(result).replace("Value(", "").split("=", maxsplit=1)[0].strip()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ctx: Context,
|
||||
@@ -92,6 +118,7 @@ class NodeConverter:
|
||||
nodes_to_mlir_names: Dict[OpResult, str],
|
||||
mlir_names_to_mlir_types: Dict[str, str],
|
||||
scalar_to_1d_tensor_conversion_hacks: Dict[str, List[str]],
|
||||
direct_replacements: Dict[str, str],
|
||||
):
|
||||
self.ctx = ctx
|
||||
self.graph = graph
|
||||
@@ -114,6 +141,7 @@ class NodeConverter:
|
||||
self.nodes_to_mlir_names = nodes_to_mlir_names
|
||||
self.mlir_names_to_mlir_types = mlir_names_to_mlir_types
|
||||
self.scalar_to_1d_tensor_conversion_hacks = scalar_to_1d_tensor_conversion_hacks
|
||||
self.direct_replacements = direct_replacements
|
||||
|
||||
def convert(self) -> OpResult:
|
||||
"""
|
||||
@@ -127,65 +155,68 @@ class NodeConverter:
|
||||
# pylint: disable=too-many-branches,too-many-statements
|
||||
|
||||
if self.node.operation == Operation.Constant:
|
||||
result = self.convert_constant()
|
||||
result = self._convert_constant()
|
||||
else:
|
||||
assert_that(self.node.operation == Operation.Generic)
|
||||
|
||||
name = self.node.properties["name"]
|
||||
|
||||
if name == "add":
|
||||
result = self.convert_add()
|
||||
result = self._convert_add()
|
||||
|
||||
elif name == "array":
|
||||
result = self._convert_array()
|
||||
|
||||
elif name == "concatenate":
|
||||
result = self.convert_concat()
|
||||
result = self._convert_concat()
|
||||
|
||||
elif name == "conv1d":
|
||||
result = self.convert_conv1d()
|
||||
result = self._convert_conv1d()
|
||||
|
||||
elif name == "conv2d":
|
||||
result = self.convert_conv2d()
|
||||
result = self._convert_conv2d()
|
||||
|
||||
elif name == "conv3d":
|
||||
result = self.convert_conv3d()
|
||||
result = self._convert_conv3d()
|
||||
|
||||
elif name == "dot":
|
||||
result = self.convert_dot()
|
||||
result = self._convert_dot()
|
||||
|
||||
elif name == "index.static":
|
||||
result = self.convert_static_indexing()
|
||||
result = self._convert_static_indexing()
|
||||
|
||||
elif name == "matmul":
|
||||
result = self.convert_matmul()
|
||||
result = self._convert_matmul()
|
||||
|
||||
elif name == "multiply":
|
||||
result = self.convert_mul()
|
||||
result = self._convert_mul()
|
||||
|
||||
elif name == "negative":
|
||||
result = self.convert_neg()
|
||||
result = self._convert_neg()
|
||||
|
||||
elif name == "ones":
|
||||
result = self.convert_ones()
|
||||
result = self._convert_ones()
|
||||
|
||||
elif name == "reshape":
|
||||
result = self.convert_reshape()
|
||||
result = self._convert_reshape()
|
||||
|
||||
elif name == "subtract":
|
||||
result = self.convert_sub()
|
||||
result = self._convert_sub()
|
||||
|
||||
elif name == "sum":
|
||||
result = self.convert_sum()
|
||||
result = self._convert_sum()
|
||||
|
||||
elif name == "transpose":
|
||||
result = self.convert_transpose()
|
||||
result = self._convert_transpose()
|
||||
|
||||
elif name == "zeros":
|
||||
result = self.convert_zeros()
|
||||
result = self._convert_zeros()
|
||||
|
||||
else:
|
||||
assert_that(self.node.converted_to_table_lookup)
|
||||
result = self.convert_tlu()
|
||||
result = self._convert_tlu()
|
||||
|
||||
mlir_name = str(result).replace("Value(", "").split("=", maxsplit=1)[0].strip()
|
||||
mlir_name = NodeConverter.mlir_name(result)
|
||||
|
||||
self.nodes_to_mlir_names[self.node] = mlir_name
|
||||
self.mlir_names_to_mlir_types[mlir_name] = str(result.type)
|
||||
@@ -204,7 +235,7 @@ class NodeConverter:
|
||||
|
||||
# pylint: enable=too-many-branches
|
||||
|
||||
def convert_add(self) -> OpResult:
|
||||
def _convert_add(self) -> OpResult:
|
||||
"""
|
||||
Convert "add" node to its corresponding MLIR representation.
|
||||
|
||||
@@ -232,7 +263,67 @@ class NodeConverter:
|
||||
|
||||
return result
|
||||
|
||||
def convert_concat(self) -> OpResult:
|
||||
def _convert_array(self) -> OpResult:
|
||||
"""
|
||||
Convert "array" node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
OpResult:
|
||||
in-memory MLIR representation corresponding to `self.node`
|
||||
"""
|
||||
|
||||
resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
|
||||
preds = self.preds
|
||||
|
||||
number_of_values = len(preds)
|
||||
|
||||
intermediate_value = deepcopy(self.node.output)
|
||||
intermediate_value.shape = (number_of_values,)
|
||||
|
||||
intermediate_type = NodeConverter.value_to_mlir_type(self.ctx, intermediate_value)
|
||||
|
||||
pred_names = []
|
||||
for pred, value in zip(preds, self.node.inputs):
|
||||
if value.is_encrypted:
|
||||
pred_names.append(NodeConverter.mlir_name(pred))
|
||||
continue
|
||||
|
||||
assert isinstance(value.dtype, Integer)
|
||||
|
||||
zero_value = EncryptedScalar(UnsignedInteger(value.dtype.bit_width - 1))
|
||||
zero_type = NodeConverter.value_to_mlir_type(self.ctx, zero_value)
|
||||
zero = fhe.ZeroEintOp(zero_type).result
|
||||
|
||||
encrypted_pred = fhe.AddEintIntOp(zero_type, zero, pred).result
|
||||
pred_names.append(NodeConverter.mlir_name(encrypted_pred))
|
||||
|
||||
# `placeholder_result` will be replaced textually by `actual_value` below in graph converter
|
||||
# `tensor.from_elements` cannot be created from python bindings
|
||||
# that's why we use placeholder values and text manipulation
|
||||
|
||||
placeholder_result = fhe.ZeroTensorOp(intermediate_type).result
|
||||
placeholder_result_name = NodeConverter.mlir_name(placeholder_result)
|
||||
|
||||
actual_value = f"tensor.from_elements {', '.join(pred_names)} : {intermediate_type}"
|
||||
self.direct_replacements[placeholder_result_name] = actual_value
|
||||
|
||||
if self.node.output.shape == (number_of_values,):
|
||||
return placeholder_result
|
||||
|
||||
index_type = IndexType.parse("index")
|
||||
return linalg.TensorExpandShapeOp(
|
||||
resulting_type,
|
||||
placeholder_result,
|
||||
ArrayAttr.get(
|
||||
[
|
||||
ArrayAttr.get(
|
||||
[IntegerAttr.get(index_type, i) for i in range(len(self.node.output.shape))]
|
||||
)
|
||||
]
|
||||
),
|
||||
).result
|
||||
|
||||
def _convert_concat(self) -> OpResult:
|
||||
"""
|
||||
Convert "concatenate" node to its corresponding MLIR representation.
|
||||
|
||||
@@ -287,7 +378,7 @@ class NodeConverter:
|
||||
IntegerAttr.get(IntegerType.get_signless(64), 0),
|
||||
).result
|
||||
|
||||
def convert_constant(self) -> OpResult:
|
||||
def _convert_constant(self) -> OpResult:
|
||||
"""
|
||||
Convert Operation.Constant node to its corresponding MLIR representation.
|
||||
|
||||
@@ -315,7 +406,7 @@ class NodeConverter:
|
||||
|
||||
return arith.ConstantOp(resulting_type, attr).result
|
||||
|
||||
def convert_conv1d(self) -> OpResult:
|
||||
def _convert_conv1d(self) -> OpResult:
|
||||
"""
|
||||
Convert "conv1d" node to its corresponding MLIR representation.
|
||||
|
||||
@@ -326,7 +417,7 @@ class NodeConverter:
|
||||
|
||||
raise NotImplementedError("conv1d conversion to MLIR is not yet implemented")
|
||||
|
||||
def convert_conv2d(self) -> OpResult:
|
||||
def _convert_conv2d(self) -> OpResult:
|
||||
"""
|
||||
Convert "conv2d" node to its corresponding MLIR representation.
|
||||
|
||||
@@ -362,7 +453,7 @@ class NodeConverter:
|
||||
|
||||
return fhelinalg.Conv2dOp(resulting_type, *preds, pads, strides, dilations).result
|
||||
|
||||
def convert_conv3d(self) -> OpResult:
|
||||
def _convert_conv3d(self) -> OpResult:
|
||||
"""
|
||||
Convert "conv3d" node to its corresponding MLIR representation.
|
||||
|
||||
@@ -373,7 +464,7 @@ class NodeConverter:
|
||||
|
||||
raise NotImplementedError("conv3d conversion to MLIR is not yet implemented")
|
||||
|
||||
def convert_dot(self) -> OpResult:
|
||||
def _convert_dot(self) -> OpResult:
|
||||
"""
|
||||
Convert "dot" node to its corresponding MLIR representation.
|
||||
|
||||
@@ -402,7 +493,7 @@ class NodeConverter:
|
||||
|
||||
return result
|
||||
|
||||
def convert_matmul(self) -> OpResult:
|
||||
def _convert_matmul(self) -> OpResult:
|
||||
"""Convert a MatMul node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
@@ -424,7 +515,7 @@ class NodeConverter:
|
||||
|
||||
return result
|
||||
|
||||
def convert_mul(self) -> OpResult:
|
||||
def _convert_mul(self) -> OpResult:
|
||||
"""
|
||||
Convert "multiply" node to its corresponding MLIR representation.
|
||||
|
||||
@@ -446,7 +537,7 @@ class NodeConverter:
|
||||
|
||||
return result
|
||||
|
||||
def convert_neg(self) -> OpResult:
|
||||
def _convert_neg(self) -> OpResult:
|
||||
"""
|
||||
Convert "negative" node to its corresponding MLIR representation.
|
||||
|
||||
@@ -465,7 +556,7 @@ class NodeConverter:
|
||||
|
||||
return result
|
||||
|
||||
def convert_ones(self) -> OpResult:
|
||||
def _convert_ones(self) -> OpResult:
|
||||
"""
|
||||
Convert "ones" node to its corresponding MLIR representation.
|
||||
|
||||
@@ -508,7 +599,7 @@ class NodeConverter:
|
||||
|
||||
return result
|
||||
|
||||
def convert_reshape(self) -> OpResult:
|
||||
def _convert_reshape(self) -> OpResult:
|
||||
"""
|
||||
Convert "reshape" node to its corresponding MLIR representation.
|
||||
|
||||
@@ -627,7 +718,7 @@ class NodeConverter:
|
||||
),
|
||||
).result
|
||||
|
||||
def convert_static_indexing(self) -> OpResult:
|
||||
def _convert_static_indexing(self) -> OpResult:
|
||||
"""
|
||||
Convert "index.static" node to its corresponding MLIR representation.
|
||||
|
||||
@@ -749,7 +840,7 @@ class NodeConverter:
|
||||
),
|
||||
).result
|
||||
|
||||
def convert_sub(self) -> OpResult:
|
||||
def _convert_sub(self) -> OpResult:
|
||||
"""
|
||||
Convert "subtract" node to its corresponding MLIR representation.
|
||||
|
||||
@@ -768,7 +859,7 @@ class NodeConverter:
|
||||
|
||||
return result
|
||||
|
||||
def convert_sum(self) -> OpResult:
|
||||
def _convert_sum(self) -> OpResult:
|
||||
"""
|
||||
Convert "sum" node to its corresponding MLIR representation.
|
||||
|
||||
@@ -799,7 +890,7 @@ class NodeConverter:
|
||||
BoolAttr.get(keep_dims),
|
||||
).result
|
||||
|
||||
def convert_tlu(self) -> OpResult:
|
||||
def _convert_tlu(self) -> OpResult:
|
||||
"""
|
||||
Convert Operation.Generic node to its corresponding MLIR representation.
|
||||
|
||||
@@ -880,7 +971,7 @@ class NodeConverter:
|
||||
|
||||
return result
|
||||
|
||||
def convert_transpose(self) -> OpResult:
|
||||
def _convert_transpose(self) -> OpResult:
|
||||
"""
|
||||
Convert "transpose" node to its corresponding MLIR representation.
|
||||
|
||||
@@ -894,7 +985,7 @@ class NodeConverter:
|
||||
|
||||
return fhelinalg.TransposeOp(resulting_type, *preds).result
|
||||
|
||||
def convert_zeros(self) -> OpResult:
|
||||
def _convert_zeros(self) -> OpResult:
|
||||
"""
|
||||
Convert "zeros" node to its corresponding MLIR representation.
|
||||
|
||||
|
||||
@@ -257,6 +257,12 @@ class Node:
|
||||
else:
|
||||
args = deepcopy(predecessors)
|
||||
|
||||
if name == "array":
|
||||
values = str(np.array(predecessors).reshape(self.output.shape).tolist()).replace(
|
||||
"'", ""
|
||||
)
|
||||
return f"array({format_constant(values, maximum_constant_length)})"
|
||||
|
||||
args.extend(
|
||||
format_constant(value, maximum_constant_length) for value in self.properties["args"]
|
||||
)
|
||||
@@ -300,6 +306,7 @@ class Node:
|
||||
|
||||
return self.operation == Operation.Generic and self.properties["name"] not in [
|
||||
"add",
|
||||
"array",
|
||||
"concatenate",
|
||||
"conv1d",
|
||||
"conv2d",
|
||||
|
||||
@@ -78,7 +78,7 @@ class Tracer:
|
||||
continue
|
||||
|
||||
try:
|
||||
sanitized_tracers.append(Tracer._sanitize(tracer))
|
||||
sanitized_tracers.append(Tracer.sanitize(tracer))
|
||||
except Exception as error:
|
||||
raise ValueError(
|
||||
f"Function '{function.__name__}' "
|
||||
@@ -149,9 +149,21 @@ class Tracer:
|
||||
return id(self)
|
||||
|
||||
@staticmethod
|
||||
def _sanitize(value: Any) -> Any:
|
||||
def sanitize(value: Any) -> Any:
|
||||
"""
|
||||
Try to create a tracer from a value.
|
||||
|
||||
Args:
|
||||
value (Any):
|
||||
value to use
|
||||
|
||||
Returns:
|
||||
Any:
|
||||
resulting tracer
|
||||
"""
|
||||
|
||||
if isinstance(value, tuple):
|
||||
return tuple(Tracer._sanitize(item) for item in value)
|
||||
return tuple(Tracer.sanitize(item) for item in value)
|
||||
|
||||
if isinstance(value, Tracer):
|
||||
return value
|
||||
@@ -372,7 +384,7 @@ class Tracer:
|
||||
"""
|
||||
|
||||
if method == "__call__":
|
||||
sanitized_args = [self._sanitize(arg) for arg in args]
|
||||
sanitized_args = [self.sanitize(arg) for arg in args]
|
||||
return Tracer._trace_numpy_operation(ufunc, *sanitized_args, **kwargs)
|
||||
|
||||
raise RuntimeError("Only __call__ hook is supported for numpy ufuncs")
|
||||
@@ -385,65 +397,65 @@ class Tracer:
|
||||
"""
|
||||
|
||||
if func is np.reshape:
|
||||
sanitized_args = [self._sanitize(args[0])]
|
||||
sanitized_args = [self.sanitize(args[0])]
|
||||
if len(args) > 1:
|
||||
kwargs["newshape"] = args[1]
|
||||
elif func is np.transpose:
|
||||
sanitized_args = [self._sanitize(args[0])]
|
||||
sanitized_args = [self.sanitize(args[0])]
|
||||
if len(args) > 1:
|
||||
kwargs["axes"] = args[1]
|
||||
else:
|
||||
sanitized_args = [self._sanitize(arg) for arg in args]
|
||||
sanitized_args = [self.sanitize(arg) for arg in args]
|
||||
|
||||
return Tracer._trace_numpy_operation(func, *sanitized_args, **kwargs)
|
||||
|
||||
def __add__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.add, self, self._sanitize(other))
|
||||
return Tracer._trace_numpy_operation(np.add, self, self.sanitize(other))
|
||||
|
||||
def __radd__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.add, self._sanitize(other), self)
|
||||
return Tracer._trace_numpy_operation(np.add, self.sanitize(other), self)
|
||||
|
||||
def __sub__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.subtract, self, self._sanitize(other))
|
||||
return Tracer._trace_numpy_operation(np.subtract, self, self.sanitize(other))
|
||||
|
||||
def __rsub__(self, other) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.subtract, self._sanitize(other), self)
|
||||
return Tracer._trace_numpy_operation(np.subtract, self.sanitize(other), self)
|
||||
|
||||
def __mul__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.multiply, self, self._sanitize(other))
|
||||
return Tracer._trace_numpy_operation(np.multiply, self, self.sanitize(other))
|
||||
|
||||
def __rmul__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.multiply, self._sanitize(other), self)
|
||||
return Tracer._trace_numpy_operation(np.multiply, self.sanitize(other), self)
|
||||
|
||||
def __truediv__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.true_divide, self, self._sanitize(other))
|
||||
return Tracer._trace_numpy_operation(np.true_divide, self, self.sanitize(other))
|
||||
|
||||
def __rtruediv__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.true_divide, self._sanitize(other), self)
|
||||
return Tracer._trace_numpy_operation(np.true_divide, self.sanitize(other), self)
|
||||
|
||||
def __floordiv__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.floor_divide, self, self._sanitize(other))
|
||||
return Tracer._trace_numpy_operation(np.floor_divide, self, self.sanitize(other))
|
||||
|
||||
def __rfloordiv__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.floor_divide, self._sanitize(other), self)
|
||||
return Tracer._trace_numpy_operation(np.floor_divide, self.sanitize(other), self)
|
||||
|
||||
def __pow__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.power, self, self._sanitize(other))
|
||||
return Tracer._trace_numpy_operation(np.power, self, self.sanitize(other))
|
||||
|
||||
def __rpow__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.power, self._sanitize(other), self)
|
||||
return Tracer._trace_numpy_operation(np.power, self.sanitize(other), self)
|
||||
|
||||
def __mod__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.mod, self, self._sanitize(other))
|
||||
return Tracer._trace_numpy_operation(np.mod, self, self.sanitize(other))
|
||||
|
||||
def __rmod__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.mod, self._sanitize(other), self)
|
||||
return Tracer._trace_numpy_operation(np.mod, self.sanitize(other), self)
|
||||
|
||||
def __matmul__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.matmul, self, self._sanitize(other))
|
||||
return Tracer._trace_numpy_operation(np.matmul, self, self.sanitize(other))
|
||||
|
||||
def __rmatmul__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.matmul, self._sanitize(other), self)
|
||||
return Tracer._trace_numpy_operation(np.matmul, self.sanitize(other), self)
|
||||
|
||||
def __neg__(self) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.negative, self)
|
||||
@@ -464,59 +476,59 @@ class Tracer:
|
||||
return Tracer._trace_numpy_operation(np.invert, self)
|
||||
|
||||
def __and__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.bitwise_and, self, self._sanitize(other))
|
||||
return Tracer._trace_numpy_operation(np.bitwise_and, self, self.sanitize(other))
|
||||
|
||||
def __rand__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.bitwise_and, self._sanitize(other), self)
|
||||
return Tracer._trace_numpy_operation(np.bitwise_and, self.sanitize(other), self)
|
||||
|
||||
def __or__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.bitwise_or, self, self._sanitize(other))
|
||||
return Tracer._trace_numpy_operation(np.bitwise_or, self, self.sanitize(other))
|
||||
|
||||
def __ror__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.bitwise_or, self._sanitize(other), self)
|
||||
return Tracer._trace_numpy_operation(np.bitwise_or, self.sanitize(other), self)
|
||||
|
||||
def __xor__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.bitwise_xor, self, self._sanitize(other))
|
||||
return Tracer._trace_numpy_operation(np.bitwise_xor, self, self.sanitize(other))
|
||||
|
||||
def __rxor__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.bitwise_xor, self._sanitize(other), self)
|
||||
return Tracer._trace_numpy_operation(np.bitwise_xor, self.sanitize(other), self)
|
||||
|
||||
def __lshift__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.left_shift, self, self._sanitize(other))
|
||||
return Tracer._trace_numpy_operation(np.left_shift, self, self.sanitize(other))
|
||||
|
||||
def __rlshift__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.left_shift, self._sanitize(other), self)
|
||||
return Tracer._trace_numpy_operation(np.left_shift, self.sanitize(other), self)
|
||||
|
||||
def __rshift__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.right_shift, self, self._sanitize(other))
|
||||
return Tracer._trace_numpy_operation(np.right_shift, self, self.sanitize(other))
|
||||
|
||||
def __rrshift__(self, other: Any) -> "Tracer":
|
||||
return Tracer._trace_numpy_operation(np.right_shift, self._sanitize(other), self)
|
||||
return Tracer._trace_numpy_operation(np.right_shift, self.sanitize(other), self)
|
||||
|
||||
def __gt__(self, other: Any) -> "Tracer": # type: ignore
|
||||
return Tracer._trace_numpy_operation(np.greater, self, self._sanitize(other))
|
||||
return Tracer._trace_numpy_operation(np.greater, self, self.sanitize(other))
|
||||
|
||||
def __ge__(self, other: Any) -> "Tracer": # type: ignore
|
||||
return Tracer._trace_numpy_operation(np.greater_equal, self, self._sanitize(other))
|
||||
return Tracer._trace_numpy_operation(np.greater_equal, self, self.sanitize(other))
|
||||
|
||||
def __lt__(self, other: Any) -> "Tracer": # type: ignore
|
||||
return Tracer._trace_numpy_operation(np.less, self, self._sanitize(other))
|
||||
return Tracer._trace_numpy_operation(np.less, self, self.sanitize(other))
|
||||
|
||||
def __le__(self, other: Any) -> "Tracer": # type: ignore
|
||||
return Tracer._trace_numpy_operation(np.less_equal, self, self._sanitize(other))
|
||||
return Tracer._trace_numpy_operation(np.less_equal, self, self.sanitize(other))
|
||||
|
||||
def __eq__(self, other: Any) -> Union[bool, "Tracer"]: # type: ignore
|
||||
return (
|
||||
self is other
|
||||
if not self._is_tracing
|
||||
else Tracer._trace_numpy_operation(np.equal, self, self._sanitize(other))
|
||||
else Tracer._trace_numpy_operation(np.equal, self, self.sanitize(other))
|
||||
)
|
||||
|
||||
def __ne__(self, other: Any) -> Union[bool, "Tracer"]: # type: ignore
|
||||
return (
|
||||
self is not other
|
||||
if not self._is_tracing
|
||||
else Tracer._trace_numpy_operation(np.not_equal, self, self._sanitize(other))
|
||||
else Tracer._trace_numpy_operation(np.not_equal, self, self.sanitize(other))
|
||||
)
|
||||
|
||||
def astype(self, dtype: DTypeLike) -> "Tracer":
|
||||
@@ -552,7 +564,7 @@ class Tracer:
|
||||
"""
|
||||
|
||||
return Tracer._trace_numpy_operation(
|
||||
np.clip, self, self._sanitize(minimum), self._sanitize(maximum)
|
||||
np.clip, self, self.sanitize(minimum), self.sanitize(maximum)
|
||||
)
|
||||
|
||||
def dot(self, other: Any) -> "Tracer":
|
||||
@@ -560,7 +572,7 @@ class Tracer:
|
||||
Trace numpy.ndarray.dot().
|
||||
"""
|
||||
|
||||
return Tracer._trace_numpy_operation(np.dot, self, self._sanitize(other))
|
||||
return Tracer._trace_numpy_operation(np.dot, self, self.sanitize(other))
|
||||
|
||||
def flatten(self) -> "Tracer":
|
||||
"""
|
||||
|
||||
61
tests/execution/test_array.py
Normal file
61
tests/execution/test_array.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Tests of execution of array operation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
import concrete.numpy as cnp
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: cnp.array([x, x + 1, 1]),
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "encrypted", "shape": ()},
|
||||
},
|
||||
id="cnp.array([x, x + 1, 1])",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: cnp.array([x, y]),
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "encrypted", "shape": ()},
|
||||
"y": {"range": [0, 10], "status": "clear", "shape": ()},
|
||||
},
|
||||
id="cnp.array([x, y])",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: cnp.array([[x, y], [y, x]]),
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "encrypted", "shape": ()},
|
||||
"y": {"range": [0, 10], "status": "clear", "shape": ()},
|
||||
},
|
||||
id="cnp.array([[x, y], [y, x]])",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y, z: cnp.array([[x, 1], [y, 2], [z, 3]]),
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "encrypted", "shape": ()},
|
||||
"y": {"range": [0, 10], "status": "clear", "shape": ()},
|
||||
"z": {"range": [0, 10], "status": "encrypted", "shape": ()},
|
||||
},
|
||||
id="cnp.array([[x, 1], [y, 2], [z, 3]])",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_array(function, parameters, helpers):
|
||||
"""
|
||||
Test array.
|
||||
"""
|
||||
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = helpers.generate_sample(parameters)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
37
tests/extensions/test_array.py
Normal file
37
tests/extensions/test_array.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
Tests of LookupTable.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
import concrete.numpy as cnp
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,expected_error",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: cnp.array([x, y]),
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "encrypted", "shape": ()},
|
||||
"y": {"range": [0, 10], "status": "encrypted", "shape": (2, 3)},
|
||||
},
|
||||
"Encrypted arrays can only be created from scalars",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_bad_array(function, parameters, expected_error, helpers):
|
||||
"""
|
||||
Test array with bad parameters.
|
||||
"""
|
||||
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = cnp.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
compiler.compile(inputset, configuration)
|
||||
|
||||
assert str(excinfo.value) == expected_error
|
||||
@@ -7,7 +7,7 @@ import pytest
|
||||
|
||||
from concrete.numpy.dtypes import UnsignedInteger
|
||||
from concrete.numpy.representation import Node
|
||||
from concrete.numpy.values import EncryptedScalar, EncryptedTensor, Value
|
||||
from concrete.numpy.values import ClearScalar, EncryptedScalar, EncryptedTensor, Value
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -191,6 +191,21 @@ def test_node_bad_call(node, args, expected_error, expected_message):
|
||||
["%0", "%1", "%2"],
|
||||
"concatenate((%0, %1, %2), axis=1)",
|
||||
),
|
||||
pytest.param(
|
||||
Node.generic(
|
||||
name="array",
|
||||
inputs=[
|
||||
EncryptedScalar(UnsignedInteger(3)),
|
||||
ClearScalar(UnsignedInteger(3)),
|
||||
ClearScalar(UnsignedInteger(3)),
|
||||
EncryptedScalar(UnsignedInteger(3)),
|
||||
],
|
||||
output=EncryptedTensor(UnsignedInteger(3), shape=(2, 2)),
|
||||
operation=lambda *args: np.array(args).reshape((2, 2)),
|
||||
),
|
||||
["%0", "%1", "%2", "%3"],
|
||||
"array([[%0, %1], [%2, %3]])",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_node_format(node, predecessors, expected_result):
|
||||
|
||||
Reference in New Issue
Block a user