mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: support assignments to tensors
This commit is contained in:
@@ -26,7 +26,7 @@ from mlir.ir import (
|
||||
from ..dtypes import Integer, SignedInteger
|
||||
from ..internal.utils import assert_that
|
||||
from ..representation import Graph, Node, Operation
|
||||
from ..values import ClearScalar
|
||||
from ..values import ClearScalar, EncryptedScalar
|
||||
from .node_converter import NodeConverter
|
||||
from .utils import MAXIMUM_BIT_WIDTH
|
||||
|
||||
@@ -88,6 +88,10 @@ class GraphConverter:
|
||||
assert_that(len(inputs) > 0)
|
||||
assert_that(all(input.is_scalar for input in inputs))
|
||||
|
||||
elif name == "assign.static":
|
||||
if not inputs[0].is_encrypted:
|
||||
return "only assignment to encrypted tensors are supported"
|
||||
|
||||
elif name == "broadcast_to":
|
||||
assert_that(len(inputs) == 1)
|
||||
if not inputs[0].is_encrypted:
|
||||
@@ -303,6 +307,116 @@ class GraphConverter:
|
||||
|
||||
nx_graph.add_edge(add_offset, node, input_idx=variable_input_index)
|
||||
|
||||
@staticmethod
|
||||
def _broadcast_assignments(graph: Graph):
|
||||
"""
|
||||
Broadcast assignments.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
computation graph to transform
|
||||
"""
|
||||
|
||||
nx_graph = graph.graph
|
||||
for node in list(nx_graph.nodes):
|
||||
if node.operation == Operation.Generic and node.properties["name"] == "assign.static":
|
||||
shape = node.inputs[0].shape
|
||||
index = node.properties["kwargs"]["index"]
|
||||
|
||||
assert_that(isinstance(index, tuple))
|
||||
while len(index) < len(shape):
|
||||
index = (*index, slice(None, None, None))
|
||||
|
||||
required_value_shape_list = []
|
||||
|
||||
for i, indexing_element in enumerate(index):
|
||||
if isinstance(indexing_element, slice):
|
||||
n = len(np.zeros(shape[i])[indexing_element])
|
||||
required_value_shape_list.append(n)
|
||||
else:
|
||||
required_value_shape_list.append(1)
|
||||
|
||||
required_value_shape = tuple(required_value_shape_list)
|
||||
actual_value_shape = node.inputs[1].shape
|
||||
|
||||
if required_value_shape != actual_value_shape:
|
||||
preds = graph.ordered_preds_of(node)
|
||||
pred_to_modify = preds[1]
|
||||
|
||||
modified_value = deepcopy(pred_to_modify.output)
|
||||
modified_value.shape = required_value_shape
|
||||
|
||||
try:
|
||||
np.broadcast_to(np.zeros(actual_value_shape), required_value_shape)
|
||||
modified_value.is_encrypted = True
|
||||
modified_value.dtype = node.output.dtype
|
||||
modified_pred = Node.generic(
|
||||
"broadcast_to",
|
||||
[pred_to_modify.output],
|
||||
modified_value,
|
||||
np.broadcast_to,
|
||||
kwargs={"shape": required_value_shape},
|
||||
)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
np.reshape(np.zeros(actual_value_shape), required_value_shape)
|
||||
modified_pred = Node.generic(
|
||||
"reshape",
|
||||
[pred_to_modify.output],
|
||||
modified_value,
|
||||
np.reshape,
|
||||
kwargs={"newshape": required_value_shape},
|
||||
)
|
||||
|
||||
nx_graph.add_edge(pred_to_modify, modified_pred, input_idx=0)
|
||||
|
||||
nx_graph.remove_edge(pred_to_modify, node)
|
||||
nx_graph.add_edge(modified_pred, node, input_idx=1)
|
||||
|
||||
node.inputs[1] = modified_value
|
||||
|
||||
@staticmethod
|
||||
def _encrypt_clear_assignments(graph: Graph):
|
||||
"""
|
||||
Encrypt clear assignments.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
computation graph to transform
|
||||
"""
|
||||
|
||||
nx_graph = graph.graph
|
||||
for node in list(nx_graph.nodes):
|
||||
if node.operation == Operation.Generic and node.properties["name"] == "assign.static":
|
||||
assigned_value = node.inputs[1]
|
||||
if assigned_value.is_clear:
|
||||
preds = graph.ordered_preds_of(node)
|
||||
assigned_pred = preds[1]
|
||||
|
||||
new_assigned_pred_value = deepcopy(assigned_value)
|
||||
new_assigned_pred_value.is_encrypted = True
|
||||
new_assigned_pred_value.dtype = preds[0].output.dtype
|
||||
|
||||
zero = Node.generic(
|
||||
"zeros",
|
||||
[],
|
||||
EncryptedScalar(new_assigned_pred_value.dtype),
|
||||
lambda: np.zeros((), dtype=np.int64),
|
||||
)
|
||||
|
||||
new_assigned_pred = Node.generic(
|
||||
"add",
|
||||
[assigned_pred.output, zero.output],
|
||||
new_assigned_pred_value,
|
||||
np.add,
|
||||
)
|
||||
|
||||
nx_graph.remove_edge(preds[1], node)
|
||||
|
||||
nx_graph.add_edge(preds[1], new_assigned_pred, input_idx=0)
|
||||
nx_graph.add_edge(zero, new_assigned_pred, input_idx=1)
|
||||
|
||||
nx_graph.add_edge(new_assigned_pred, node, input_idx=1)
|
||||
|
||||
@staticmethod
|
||||
def _tensorize_scalars_for_fhelinalg(graph: Graph):
|
||||
"""
|
||||
@@ -462,6 +576,8 @@ class GraphConverter:
|
||||
|
||||
GraphConverter._update_bit_widths(graph)
|
||||
GraphConverter._offset_negative_lookup_table_inputs(graph)
|
||||
GraphConverter._broadcast_assignments(graph)
|
||||
GraphConverter._encrypt_clear_assignments(graph)
|
||||
GraphConverter._tensorize_scalars_for_fhelinalg(graph)
|
||||
|
||||
from_elements_operations: Dict[OpResult, List[OpResult]] = {}
|
||||
|
||||
@@ -157,6 +157,9 @@ class NodeConverter:
|
||||
if name == "add":
|
||||
result = self._convert_add()
|
||||
|
||||
elif name == "assign.static":
|
||||
result = self._convert_static_assignment()
|
||||
|
||||
elif name == "array":
|
||||
result = self._convert_array()
|
||||
|
||||
@@ -716,6 +719,68 @@ class NodeConverter:
|
||||
),
|
||||
).result
|
||||
|
||||
def _convert_static_assignment(self) -> OpResult:
|
||||
"""
|
||||
Convert "assign.static" node to its corresponding MLIR representation.
|
||||
|
||||
Returns:
|
||||
OpResult:
|
||||
in-memory MLIR representation corresponding to `self.node`
|
||||
"""
|
||||
|
||||
input_value = self.node.inputs[0]
|
||||
input_shape = input_value.shape
|
||||
|
||||
index = list(self.node.properties["kwargs"]["index"])
|
||||
|
||||
while len(index) < input_value.ndim:
|
||||
index.append(slice(None, None, None))
|
||||
|
||||
output_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
|
||||
|
||||
offsets = []
|
||||
sizes = []
|
||||
strides = []
|
||||
|
||||
for indexing_element, dimension_size in zip(index, input_shape):
|
||||
|
||||
if isinstance(indexing_element, slice):
|
||||
size = np.zeros(dimension_size)[indexing_element].shape[0]
|
||||
stride = indexing_element.step if isinstance(indexing_element.step, int) else 1
|
||||
offset = (
|
||||
(
|
||||
indexing_element.start
|
||||
if indexing_element.start >= 0
|
||||
else indexing_element.start + dimension_size
|
||||
)
|
||||
if isinstance(indexing_element.start, int)
|
||||
else (0 if stride > 0 else dimension_size - 1)
|
||||
)
|
||||
|
||||
else:
|
||||
size = 1
|
||||
stride = 1
|
||||
offset = int(
|
||||
indexing_element if indexing_element >= 0 else indexing_element + dimension_size
|
||||
)
|
||||
|
||||
offsets.append(offset)
|
||||
sizes.append(size)
|
||||
strides.append(stride)
|
||||
|
||||
i64_type = IntegerType.get_signless(64)
|
||||
return tensor.InsertSliceOp(
|
||||
output_type,
|
||||
self.preds[1],
|
||||
self.preds[0],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
ArrayAttr.get([IntegerAttr.get(i64_type, value) for value in offsets]),
|
||||
ArrayAttr.get([IntegerAttr.get(i64_type, value) for value in sizes]),
|
||||
ArrayAttr.get([IntegerAttr.get(i64_type, value) for value in strides]),
|
||||
).result
|
||||
|
||||
def _convert_static_indexing(self) -> OpResult:
|
||||
"""
|
||||
Convert "index.static" node to its corresponding MLIR representation.
|
||||
|
||||
@@ -252,6 +252,11 @@ class Node:
|
||||
elements = [format_indexing_element(element) for element in index]
|
||||
return f"{predecessors[0]}[{', '.join(elements)}]"
|
||||
|
||||
if name == "assign.static":
|
||||
index = self.properties["kwargs"]["index"]
|
||||
elements = [format_indexing_element(element) for element in index]
|
||||
return f"({predecessors[0]}[{', '.join(elements)}] = {predecessors[1]})"
|
||||
|
||||
if name == "concatenate":
|
||||
args = [f"({', '.join(predecessors)})"]
|
||||
else:
|
||||
@@ -292,7 +297,14 @@ class Node:
|
||||
assert_that(self.operation == Operation.Generic)
|
||||
|
||||
name = self.properties["name"]
|
||||
return name if name != "index.static" else self.format(["□"])
|
||||
|
||||
if name == "index.static":
|
||||
name = self.format(["□"])
|
||||
|
||||
if name == "assign.static":
|
||||
name = self.format(["□", "□"])[1:-1]
|
||||
|
||||
return name
|
||||
|
||||
@property
|
||||
def converted_to_table_lookup(self) -> bool:
|
||||
@@ -307,6 +319,7 @@ class Node:
|
||||
return self.operation == Operation.Generic and self.properties["name"] not in [
|
||||
"add",
|
||||
"array",
|
||||
"assign.static",
|
||||
"broadcast_to",
|
||||
"concatenate",
|
||||
"conv1d",
|
||||
|
||||
@@ -26,6 +26,9 @@ class Tracer:
|
||||
input_tracers: List["Tracer"]
|
||||
output: Value
|
||||
|
||||
# property to keep track of assignments
|
||||
last_version: Optional["Tracer"] = None
|
||||
|
||||
# variable to control the behavior of __eq__
|
||||
# so that it can be traced but still allow
|
||||
# using Tracers in dicts when not tracing
|
||||
@@ -71,6 +74,12 @@ class Tracer:
|
||||
if not isinstance(output_tracers, tuple):
|
||||
output_tracers = (output_tracers,)
|
||||
|
||||
output_tracer_list = list(output_tracers)
|
||||
for i, output_tracer in enumerate(output_tracer_list):
|
||||
if isinstance(output_tracer, Tracer) and output_tracer.last_version is not None:
|
||||
output_tracer_list[i] = output_tracer.last_version
|
||||
output_tracers = tuple(output_tracer_list)
|
||||
|
||||
sanitized_tracers = []
|
||||
for tracer in output_tracers:
|
||||
if isinstance(tracer, Tracer):
|
||||
@@ -145,6 +154,9 @@ class Tracer:
|
||||
self.input_tracers = input_tracers
|
||||
self.output = computation.output
|
||||
|
||||
for i, tracer in enumerate(self.input_tracers):
|
||||
self.input_tracers[i] = tracer if tracer.last_version is None else tracer.last_version
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return id(self)
|
||||
|
||||
@@ -671,6 +683,57 @@ class Tracer:
|
||||
)
|
||||
return Tracer(computation, [self])
|
||||
|
||||
def __setitem__(
|
||||
self,
|
||||
index: Union[int, np.integer, slice, Tuple[Union[int, np.integer, slice], ...]],
|
||||
value: Any,
|
||||
):
|
||||
if not isinstance(index, tuple):
|
||||
index = (index,)
|
||||
|
||||
for indexing_element in index:
|
||||
valid = isinstance(indexing_element, (int, np.integer, slice))
|
||||
|
||||
if isinstance(indexing_element, slice):
|
||||
if (
|
||||
not (
|
||||
indexing_element.start is None
|
||||
or isinstance(indexing_element.start, (int, np.integer))
|
||||
)
|
||||
or not (
|
||||
indexing_element.stop is None
|
||||
or isinstance(indexing_element.stop, (int, np.integer))
|
||||
)
|
||||
or not (
|
||||
indexing_element.step is None
|
||||
or isinstance(indexing_element.step, (int, np.integer))
|
||||
)
|
||||
):
|
||||
valid = False
|
||||
|
||||
if not valid:
|
||||
raise ValueError(
|
||||
f"Assigning to '{format_indexing_element(indexing_element)}' is not supported"
|
||||
)
|
||||
|
||||
np.zeros(self.output.shape)[index] = 1
|
||||
|
||||
def assign(x, value, index):
|
||||
x[index] = value
|
||||
return x
|
||||
|
||||
sanitized_value = self.sanitize(value)
|
||||
computation = Node.generic(
|
||||
"assign.static",
|
||||
[self.output, sanitized_value.output],
|
||||
self.output,
|
||||
assign,
|
||||
kwargs={"index": index},
|
||||
)
|
||||
new_version = Tracer(computation, [self, sanitized_value])
|
||||
|
||||
self.last_version = new_version
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[int, ...]:
|
||||
"""
|
||||
|
||||
530
tests/execution/test_static_assignment.py
Normal file
530
tests/execution/test_static_assignment.py
Normal file
@@ -0,0 +1,530 @@
|
||||
"""
|
||||
Tests of execution of static assignment operation.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import concrete.numpy as cnp
|
||||
|
||||
|
||||
def assignment_case_0():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (3,)
|
||||
value = np.random.randint(0, 2**7, size=())
|
||||
|
||||
def assign(x):
|
||||
x[:] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_1():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (3,)
|
||||
value = np.random.randint(0, 2**7, size=())
|
||||
|
||||
def assign(x):
|
||||
x[0] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_2():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (3,)
|
||||
value = np.random.randint(0, 2**7, size=())
|
||||
|
||||
def assign(x):
|
||||
x[1] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_3():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (3,)
|
||||
value = np.random.randint(0, 2**7, size=())
|
||||
|
||||
def assign(x):
|
||||
x[2] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_4():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (5,)
|
||||
value = np.random.randint(0, 2**7, size=())
|
||||
|
||||
def assign(x):
|
||||
x[0:3] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_5():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (5,)
|
||||
value = np.random.randint(0, 2**7, size=())
|
||||
|
||||
def assign(x):
|
||||
x[1:4] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_6():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (5,)
|
||||
value = np.random.randint(0, 2**7, size=())
|
||||
|
||||
def assign(x):
|
||||
x[1:4:2] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_7():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (10,)
|
||||
value = np.random.randint(0, 2**7, size=())
|
||||
|
||||
def assign(x):
|
||||
x[::2] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_8():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (5,)
|
||||
value = np.random.randint(0, 2**7, size=())
|
||||
|
||||
def assign(x):
|
||||
x[2:0:-1] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_9():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (5,)
|
||||
value = np.random.randint(0, 2**7, size=())
|
||||
|
||||
def assign(x):
|
||||
x[4:0:-2] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_10():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (5,)
|
||||
value = np.random.randint(0, 2**7, size=(3,))
|
||||
|
||||
def assign(x):
|
||||
x[1:4] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_11():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (5,)
|
||||
value = np.random.randint(0, 2**7, size=(3,))
|
||||
|
||||
def assign(x):
|
||||
x[4:1:-1] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_12():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (10,)
|
||||
value = np.random.randint(0, 2**7, size=(3,))
|
||||
|
||||
def assign(x):
|
||||
x[1:7:2] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_13():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (10,)
|
||||
value = np.random.randint(0, 2**7, size=(3,))
|
||||
|
||||
def assign(x):
|
||||
x[7:1:-2] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_14():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (5, 4)
|
||||
value = np.random.randint(0, 2**7, size=())
|
||||
|
||||
def assign(x):
|
||||
x[0, 0] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_15():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (5, 4)
|
||||
value = np.random.randint(0, 2**7, size=())
|
||||
|
||||
def assign(x):
|
||||
x[3, 1] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_16():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (5, 4)
|
||||
value = np.random.randint(0, 2**7, size=())
|
||||
|
||||
def assign(x):
|
||||
x[0] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_17():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (5, 4)
|
||||
value = np.random.randint(0, 2**7, size=(4,))
|
||||
|
||||
def assign(x):
|
||||
x[0] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_18():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (5, 4)
|
||||
value = np.random.randint(0, 2**7, size=(5,))
|
||||
|
||||
def assign(x):
|
||||
x[:, 0] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_19():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (5, 4)
|
||||
value = np.random.randint(0, 2**7, size=(5,))
|
||||
|
||||
def assign(x):
|
||||
x[:, 1] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_20():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (5, 4)
|
||||
value = np.random.randint(0, 2**7, size=())
|
||||
|
||||
def assign(x):
|
||||
x[0:3, :] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_21():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (5, 4)
|
||||
value = np.random.randint(0, 2**7, size=(3, 4))
|
||||
|
||||
def assign(x):
|
||||
x[0:3, :] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_22():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (5, 4)
|
||||
value = np.random.randint(0, 2**7, size=(4,))
|
||||
|
||||
def assign(x):
|
||||
x[0:3, :] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_23():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (5, 4)
|
||||
value = np.random.randint(0, 2**7, size=(3,))
|
||||
|
||||
def assign(x):
|
||||
x[0:3, 1:4] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_24():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (5, 4)
|
||||
value = np.random.randint(0, 2**7, size=(3, 3))
|
||||
|
||||
def assign(x):
|
||||
x[0:3, 1:4] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_25():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (5, 4)
|
||||
value = np.random.randint(0, 2**7, size=(3, 3))
|
||||
|
||||
def assign(x):
|
||||
x[4:1:-1, 3:0:-1] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_26():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (5, 4)
|
||||
value = np.random.randint(0, 2**7, size=(3,))
|
||||
|
||||
def assign(x):
|
||||
x[3:0:-1, 0] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_27():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (5, 4)
|
||||
value = np.random.randint(0, 2**7, size=(2,))
|
||||
|
||||
def assign(x):
|
||||
x[0, 1:3] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
def assignment_case_28():
|
||||
"""
|
||||
Assignment test case.
|
||||
"""
|
||||
|
||||
shape = (5, 4)
|
||||
value = np.random.randint(0, 2**7, size=())
|
||||
|
||||
def assign(x):
|
||||
x[2:4, 1:3] = value
|
||||
return x
|
||||
|
||||
return shape, assign
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"shape,function",
|
||||
[
|
||||
pytest.param(*assignment_case_0()),
|
||||
pytest.param(*assignment_case_1()),
|
||||
pytest.param(*assignment_case_2()),
|
||||
pytest.param(*assignment_case_3()),
|
||||
pytest.param(*assignment_case_4()),
|
||||
pytest.param(*assignment_case_5()),
|
||||
pytest.param(*assignment_case_6()),
|
||||
pytest.param(*assignment_case_7()),
|
||||
pytest.param(*assignment_case_8()),
|
||||
pytest.param(*assignment_case_9()),
|
||||
pytest.param(*assignment_case_10()),
|
||||
pytest.param(*assignment_case_11()),
|
||||
pytest.param(*assignment_case_12()),
|
||||
pytest.param(*assignment_case_13()),
|
||||
pytest.param(*assignment_case_14()),
|
||||
pytest.param(*assignment_case_15()),
|
||||
pytest.param(*assignment_case_16()),
|
||||
pytest.param(*assignment_case_17()),
|
||||
pytest.param(*assignment_case_18()),
|
||||
pytest.param(*assignment_case_19()),
|
||||
pytest.param(*assignment_case_20()),
|
||||
pytest.param(*assignment_case_21()),
|
||||
pytest.param(*assignment_case_22()),
|
||||
pytest.param(*assignment_case_23()),
|
||||
pytest.param(*assignment_case_24()),
|
||||
pytest.param(*assignment_case_25()),
|
||||
pytest.param(*assignment_case_26()),
|
||||
pytest.param(*assignment_case_27()),
|
||||
pytest.param(*assignment_case_28()),
|
||||
],
|
||||
)
|
||||
def test_static_assignment(shape, function, helpers):
|
||||
"""
|
||||
Test static assignment.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
compiler = cnp.Compiler(function, {"x": "encrypted"})
|
||||
|
||||
inputset = [np.random.randint(0, 2**7, size=shape) for _ in range(100)]
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
sample = np.random.randint(0, 2**7, size=shape)
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
|
||||
def test_bad_static_assignment(helpers):
|
||||
"""
|
||||
Test static assingment with bad parameters.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
# with float
|
||||
# ----------
|
||||
|
||||
def f(x):
|
||||
x[1.5] = 0
|
||||
return x
|
||||
|
||||
compiler = cnp.Compiler(f, {"x": "encrypted"})
|
||||
|
||||
inputset = [np.random.randint(0, 2**3, size=(3,)) for _ in range(100)]
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
compiler.compile(inputset, configuration)
|
||||
|
||||
assert str(excinfo.value) == "Assigning to '1.5' is not supported"
|
||||
|
||||
# with bad slice
|
||||
# --------------
|
||||
|
||||
def g(x):
|
||||
x[slice(1.5, 2.5, None)] = 0
|
||||
return x
|
||||
|
||||
compiler = cnp.Compiler(g, {"x": "encrypted"})
|
||||
|
||||
inputset = [np.random.randint(0, 2**3, size=(3,)) for _ in range(100)]
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
compiler.compile(inputset, configuration)
|
||||
|
||||
assert str(excinfo.value) == "Assigning to '1.5:2.5' is not supported"
|
||||
@@ -11,6 +11,15 @@ import concrete.onnx as connx
|
||||
# pylint: disable=line-too-long
|
||||
|
||||
|
||||
def assign(x):
|
||||
"""
|
||||
Simple assignment to a vector.
|
||||
"""
|
||||
|
||||
x[0] = 0
|
||||
return x
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,encryption_statuses,inputset,expected_error,expected_message",
|
||||
[
|
||||
@@ -386,6 +395,24 @@ Function you are trying to compile cannot be converted to MLIR
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted broadcasting is supported
|
||||
return %1
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
pytest.param(
|
||||
assign,
|
||||
{"x": "clear"},
|
||||
[np.random.randint(0, 2, size=(3,)) for _ in range(100)],
|
||||
RuntimeError,
|
||||
"""
|
||||
|
||||
Function you are trying to compile cannot be converted to MLIR
|
||||
|
||||
%0 = x # ClearTensor<uint1, shape=(3,)>
|
||||
%1 = 0 # ClearScalar<uint1>
|
||||
%2 = (%0[0] = %1) # ClearTensor<uint1, shape=(3,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only assignment to encrypted tensors are supported
|
||||
return %2
|
||||
|
||||
|
||||
""", # noqa: E501
|
||||
),
|
||||
],
|
||||
|
||||
@@ -206,6 +206,17 @@ def test_node_bad_call(node, args, expected_error, expected_message):
|
||||
["%0", "%1", "%2", "%3"],
|
||||
"array([[%0, %1], [%2, %3]])",
|
||||
),
|
||||
pytest.param(
|
||||
Node.generic(
|
||||
name="assign.static",
|
||||
inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3, 4))],
|
||||
output=EncryptedTensor(UnsignedInteger(3), shape=(3, 4)),
|
||||
operation=lambda *args: args,
|
||||
kwargs={"index": (1, 2)},
|
||||
),
|
||||
["%0", "%1"],
|
||||
"(%0[1, 2] = %1)",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_node_format(node, predecessors, expected_result):
|
||||
@@ -253,6 +264,26 @@ def test_node_format(node, predecessors, expected_result):
|
||||
),
|
||||
"concatenate",
|
||||
),
|
||||
pytest.param(
|
||||
Node.generic(
|
||||
name="index.static",
|
||||
inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3, 4))],
|
||||
output=EncryptedTensor(UnsignedInteger(3), shape=()),
|
||||
operation=lambda *args: args,
|
||||
kwargs={"index": (1, 2)},
|
||||
),
|
||||
"□[1, 2]",
|
||||
),
|
||||
pytest.param(
|
||||
Node.generic(
|
||||
name="assign.static",
|
||||
inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3, 4))],
|
||||
output=EncryptedTensor(UnsignedInteger(3), shape=(3, 4)),
|
||||
operation=lambda *args: args,
|
||||
kwargs={"index": (1, 2)},
|
||||
),
|
||||
"□[1, 2] = □",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_node_label(node, expected_result):
|
||||
|
||||
Reference in New Issue
Block a user