feat: support assignments to tensors

This commit is contained in:
Umut
2022-07-27 09:10:20 +02:00
parent 48014ed60a
commit ef7e0d762f
7 changed files with 847 additions and 2 deletions

View File

@@ -26,7 +26,7 @@ from mlir.ir import (
from ..dtypes import Integer, SignedInteger
from ..internal.utils import assert_that
from ..representation import Graph, Node, Operation
from ..values import ClearScalar
from ..values import ClearScalar, EncryptedScalar
from .node_converter import NodeConverter
from .utils import MAXIMUM_BIT_WIDTH
@@ -88,6 +88,10 @@ class GraphConverter:
assert_that(len(inputs) > 0)
assert_that(all(input.is_scalar for input in inputs))
elif name == "assign.static":
if not inputs[0].is_encrypted:
return "only assignment to encrypted tensors are supported"
elif name == "broadcast_to":
assert_that(len(inputs) == 1)
if not inputs[0].is_encrypted:
@@ -303,6 +307,116 @@ class GraphConverter:
nx_graph.add_edge(add_offset, node, input_idx=variable_input_index)
@staticmethod
def _broadcast_assignments(graph: Graph):
"""
Broadcast assignments.
Args:
graph (Graph):
computation graph to transform
"""
nx_graph = graph.graph
for node in list(nx_graph.nodes):
if node.operation == Operation.Generic and node.properties["name"] == "assign.static":
shape = node.inputs[0].shape
index = node.properties["kwargs"]["index"]
assert_that(isinstance(index, tuple))
while len(index) < len(shape):
index = (*index, slice(None, None, None))
required_value_shape_list = []
for i, indexing_element in enumerate(index):
if isinstance(indexing_element, slice):
n = len(np.zeros(shape[i])[indexing_element])
required_value_shape_list.append(n)
else:
required_value_shape_list.append(1)
required_value_shape = tuple(required_value_shape_list)
actual_value_shape = node.inputs[1].shape
if required_value_shape != actual_value_shape:
preds = graph.ordered_preds_of(node)
pred_to_modify = preds[1]
modified_value = deepcopy(pred_to_modify.output)
modified_value.shape = required_value_shape
try:
np.broadcast_to(np.zeros(actual_value_shape), required_value_shape)
modified_value.is_encrypted = True
modified_value.dtype = node.output.dtype
modified_pred = Node.generic(
"broadcast_to",
[pred_to_modify.output],
modified_value,
np.broadcast_to,
kwargs={"shape": required_value_shape},
)
except Exception: # pylint: disable=broad-except
np.reshape(np.zeros(actual_value_shape), required_value_shape)
modified_pred = Node.generic(
"reshape",
[pred_to_modify.output],
modified_value,
np.reshape,
kwargs={"newshape": required_value_shape},
)
nx_graph.add_edge(pred_to_modify, modified_pred, input_idx=0)
nx_graph.remove_edge(pred_to_modify, node)
nx_graph.add_edge(modified_pred, node, input_idx=1)
node.inputs[1] = modified_value
@staticmethod
def _encrypt_clear_assignments(graph: Graph):
"""
Encrypt clear assignments.
Args:
graph (Graph):
computation graph to transform
"""
nx_graph = graph.graph
for node in list(nx_graph.nodes):
if node.operation == Operation.Generic and node.properties["name"] == "assign.static":
assigned_value = node.inputs[1]
if assigned_value.is_clear:
preds = graph.ordered_preds_of(node)
assigned_pred = preds[1]
new_assigned_pred_value = deepcopy(assigned_value)
new_assigned_pred_value.is_encrypted = True
new_assigned_pred_value.dtype = preds[0].output.dtype
zero = Node.generic(
"zeros",
[],
EncryptedScalar(new_assigned_pred_value.dtype),
lambda: np.zeros((), dtype=np.int64),
)
new_assigned_pred = Node.generic(
"add",
[assigned_pred.output, zero.output],
new_assigned_pred_value,
np.add,
)
nx_graph.remove_edge(preds[1], node)
nx_graph.add_edge(preds[1], new_assigned_pred, input_idx=0)
nx_graph.add_edge(zero, new_assigned_pred, input_idx=1)
nx_graph.add_edge(new_assigned_pred, node, input_idx=1)
@staticmethod
def _tensorize_scalars_for_fhelinalg(graph: Graph):
"""
@@ -462,6 +576,8 @@ class GraphConverter:
GraphConverter._update_bit_widths(graph)
GraphConverter._offset_negative_lookup_table_inputs(graph)
GraphConverter._broadcast_assignments(graph)
GraphConverter._encrypt_clear_assignments(graph)
GraphConverter._tensorize_scalars_for_fhelinalg(graph)
from_elements_operations: Dict[OpResult, List[OpResult]] = {}

View File

@@ -157,6 +157,9 @@ class NodeConverter:
if name == "add":
result = self._convert_add()
elif name == "assign.static":
result = self._convert_static_assignment()
elif name == "array":
result = self._convert_array()
@@ -716,6 +719,68 @@ class NodeConverter:
),
).result
def _convert_static_assignment(self) -> OpResult:
"""
Convert "assign.static" node to its corresponding MLIR representation.
Returns:
OpResult:
in-memory MLIR representation corresponding to `self.node`
"""
input_value = self.node.inputs[0]
input_shape = input_value.shape
index = list(self.node.properties["kwargs"]["index"])
while len(index) < input_value.ndim:
index.append(slice(None, None, None))
output_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
offsets = []
sizes = []
strides = []
for indexing_element, dimension_size in zip(index, input_shape):
if isinstance(indexing_element, slice):
size = np.zeros(dimension_size)[indexing_element].shape[0]
stride = indexing_element.step if isinstance(indexing_element.step, int) else 1
offset = (
(
indexing_element.start
if indexing_element.start >= 0
else indexing_element.start + dimension_size
)
if isinstance(indexing_element.start, int)
else (0 if stride > 0 else dimension_size - 1)
)
else:
size = 1
stride = 1
offset = int(
indexing_element if indexing_element >= 0 else indexing_element + dimension_size
)
offsets.append(offset)
sizes.append(size)
strides.append(stride)
i64_type = IntegerType.get_signless(64)
return tensor.InsertSliceOp(
output_type,
self.preds[1],
self.preds[0],
[],
[],
[],
ArrayAttr.get([IntegerAttr.get(i64_type, value) for value in offsets]),
ArrayAttr.get([IntegerAttr.get(i64_type, value) for value in sizes]),
ArrayAttr.get([IntegerAttr.get(i64_type, value) for value in strides]),
).result
def _convert_static_indexing(self) -> OpResult:
"""
Convert "index.static" node to its corresponding MLIR representation.

View File

@@ -252,6 +252,11 @@ class Node:
elements = [format_indexing_element(element) for element in index]
return f"{predecessors[0]}[{', '.join(elements)}]"
if name == "assign.static":
index = self.properties["kwargs"]["index"]
elements = [format_indexing_element(element) for element in index]
return f"({predecessors[0]}[{', '.join(elements)}] = {predecessors[1]})"
if name == "concatenate":
args = [f"({', '.join(predecessors)})"]
else:
@@ -292,7 +297,14 @@ class Node:
assert_that(self.operation == Operation.Generic)
name = self.properties["name"]
return name if name != "index.static" else self.format([""])
if name == "index.static":
name = self.format([""])
if name == "assign.static":
name = self.format(["", ""])[1:-1]
return name
@property
def converted_to_table_lookup(self) -> bool:
@@ -307,6 +319,7 @@ class Node:
return self.operation == Operation.Generic and self.properties["name"] not in [
"add",
"array",
"assign.static",
"broadcast_to",
"concatenate",
"conv1d",

View File

@@ -26,6 +26,9 @@ class Tracer:
input_tracers: List["Tracer"]
output: Value
# property to keep track of assignments
last_version: Optional["Tracer"] = None
# variable to control the behavior of __eq__
# so that it can be traced but still allow
# using Tracers in dicts when not tracing
@@ -71,6 +74,12 @@ class Tracer:
if not isinstance(output_tracers, tuple):
output_tracers = (output_tracers,)
output_tracer_list = list(output_tracers)
for i, output_tracer in enumerate(output_tracer_list):
if isinstance(output_tracer, Tracer) and output_tracer.last_version is not None:
output_tracer_list[i] = output_tracer.last_version
output_tracers = tuple(output_tracer_list)
sanitized_tracers = []
for tracer in output_tracers:
if isinstance(tracer, Tracer):
@@ -145,6 +154,9 @@ class Tracer:
self.input_tracers = input_tracers
self.output = computation.output
for i, tracer in enumerate(self.input_tracers):
self.input_tracers[i] = tracer if tracer.last_version is None else tracer.last_version
def __hash__(self) -> int:
return id(self)
@@ -671,6 +683,57 @@ class Tracer:
)
return Tracer(computation, [self])
def __setitem__(
self,
index: Union[int, np.integer, slice, Tuple[Union[int, np.integer, slice], ...]],
value: Any,
):
if not isinstance(index, tuple):
index = (index,)
for indexing_element in index:
valid = isinstance(indexing_element, (int, np.integer, slice))
if isinstance(indexing_element, slice):
if (
not (
indexing_element.start is None
or isinstance(indexing_element.start, (int, np.integer))
)
or not (
indexing_element.stop is None
or isinstance(indexing_element.stop, (int, np.integer))
)
or not (
indexing_element.step is None
or isinstance(indexing_element.step, (int, np.integer))
)
):
valid = False
if not valid:
raise ValueError(
f"Assigning to '{format_indexing_element(indexing_element)}' is not supported"
)
np.zeros(self.output.shape)[index] = 1
def assign(x, value, index):
x[index] = value
return x
sanitized_value = self.sanitize(value)
computation = Node.generic(
"assign.static",
[self.output, sanitized_value.output],
self.output,
assign,
kwargs={"index": index},
)
new_version = Tracer(computation, [self, sanitized_value])
self.last_version = new_version
@property
def shape(self) -> Tuple[int, ...]:
"""

View File

@@ -0,0 +1,530 @@
"""
Tests of execution of static assignment operation.
"""
import numpy as np
import pytest
import concrete.numpy as cnp
def assignment_case_0():
"""
Assignment test case.
"""
shape = (3,)
value = np.random.randint(0, 2**7, size=())
def assign(x):
x[:] = value
return x
return shape, assign
def assignment_case_1():
"""
Assignment test case.
"""
shape = (3,)
value = np.random.randint(0, 2**7, size=())
def assign(x):
x[0] = value
return x
return shape, assign
def assignment_case_2():
"""
Assignment test case.
"""
shape = (3,)
value = np.random.randint(0, 2**7, size=())
def assign(x):
x[1] = value
return x
return shape, assign
def assignment_case_3():
"""
Assignment test case.
"""
shape = (3,)
value = np.random.randint(0, 2**7, size=())
def assign(x):
x[2] = value
return x
return shape, assign
def assignment_case_4():
"""
Assignment test case.
"""
shape = (5,)
value = np.random.randint(0, 2**7, size=())
def assign(x):
x[0:3] = value
return x
return shape, assign
def assignment_case_5():
"""
Assignment test case.
"""
shape = (5,)
value = np.random.randint(0, 2**7, size=())
def assign(x):
x[1:4] = value
return x
return shape, assign
def assignment_case_6():
"""
Assignment test case.
"""
shape = (5,)
value = np.random.randint(0, 2**7, size=())
def assign(x):
x[1:4:2] = value
return x
return shape, assign
def assignment_case_7():
"""
Assignment test case.
"""
shape = (10,)
value = np.random.randint(0, 2**7, size=())
def assign(x):
x[::2] = value
return x
return shape, assign
def assignment_case_8():
"""
Assignment test case.
"""
shape = (5,)
value = np.random.randint(0, 2**7, size=())
def assign(x):
x[2:0:-1] = value
return x
return shape, assign
def assignment_case_9():
"""
Assignment test case.
"""
shape = (5,)
value = np.random.randint(0, 2**7, size=())
def assign(x):
x[4:0:-2] = value
return x
return shape, assign
def assignment_case_10():
"""
Assignment test case.
"""
shape = (5,)
value = np.random.randint(0, 2**7, size=(3,))
def assign(x):
x[1:4] = value
return x
return shape, assign
def assignment_case_11():
"""
Assignment test case.
"""
shape = (5,)
value = np.random.randint(0, 2**7, size=(3,))
def assign(x):
x[4:1:-1] = value
return x
return shape, assign
def assignment_case_12():
"""
Assignment test case.
"""
shape = (10,)
value = np.random.randint(0, 2**7, size=(3,))
def assign(x):
x[1:7:2] = value
return x
return shape, assign
def assignment_case_13():
"""
Assignment test case.
"""
shape = (10,)
value = np.random.randint(0, 2**7, size=(3,))
def assign(x):
x[7:1:-2] = value
return x
return shape, assign
def assignment_case_14():
"""
Assignment test case.
"""
shape = (5, 4)
value = np.random.randint(0, 2**7, size=())
def assign(x):
x[0, 0] = value
return x
return shape, assign
def assignment_case_15():
"""
Assignment test case.
"""
shape = (5, 4)
value = np.random.randint(0, 2**7, size=())
def assign(x):
x[3, 1] = value
return x
return shape, assign
def assignment_case_16():
"""
Assignment test case.
"""
shape = (5, 4)
value = np.random.randint(0, 2**7, size=())
def assign(x):
x[0] = value
return x
return shape, assign
def assignment_case_17():
"""
Assignment test case.
"""
shape = (5, 4)
value = np.random.randint(0, 2**7, size=(4,))
def assign(x):
x[0] = value
return x
return shape, assign
def assignment_case_18():
"""
Assignment test case.
"""
shape = (5, 4)
value = np.random.randint(0, 2**7, size=(5,))
def assign(x):
x[:, 0] = value
return x
return shape, assign
def assignment_case_19():
"""
Assignment test case.
"""
shape = (5, 4)
value = np.random.randint(0, 2**7, size=(5,))
def assign(x):
x[:, 1] = value
return x
return shape, assign
def assignment_case_20():
"""
Assignment test case.
"""
shape = (5, 4)
value = np.random.randint(0, 2**7, size=())
def assign(x):
x[0:3, :] = value
return x
return shape, assign
def assignment_case_21():
"""
Assignment test case.
"""
shape = (5, 4)
value = np.random.randint(0, 2**7, size=(3, 4))
def assign(x):
x[0:3, :] = value
return x
return shape, assign
def assignment_case_22():
"""
Assignment test case.
"""
shape = (5, 4)
value = np.random.randint(0, 2**7, size=(4,))
def assign(x):
x[0:3, :] = value
return x
return shape, assign
def assignment_case_23():
"""
Assignment test case.
"""
shape = (5, 4)
value = np.random.randint(0, 2**7, size=(3,))
def assign(x):
x[0:3, 1:4] = value
return x
return shape, assign
def assignment_case_24():
"""
Assignment test case.
"""
shape = (5, 4)
value = np.random.randint(0, 2**7, size=(3, 3))
def assign(x):
x[0:3, 1:4] = value
return x
return shape, assign
def assignment_case_25():
"""
Assignment test case.
"""
shape = (5, 4)
value = np.random.randint(0, 2**7, size=(3, 3))
def assign(x):
x[4:1:-1, 3:0:-1] = value
return x
return shape, assign
def assignment_case_26():
"""
Assignment test case.
"""
shape = (5, 4)
value = np.random.randint(0, 2**7, size=(3,))
def assign(x):
x[3:0:-1, 0] = value
return x
return shape, assign
def assignment_case_27():
"""
Assignment test case.
"""
shape = (5, 4)
value = np.random.randint(0, 2**7, size=(2,))
def assign(x):
x[0, 1:3] = value
return x
return shape, assign
def assignment_case_28():
"""
Assignment test case.
"""
shape = (5, 4)
value = np.random.randint(0, 2**7, size=())
def assign(x):
x[2:4, 1:3] = value
return x
return shape, assign
@pytest.mark.parametrize(
"shape,function",
[
pytest.param(*assignment_case_0()),
pytest.param(*assignment_case_1()),
pytest.param(*assignment_case_2()),
pytest.param(*assignment_case_3()),
pytest.param(*assignment_case_4()),
pytest.param(*assignment_case_5()),
pytest.param(*assignment_case_6()),
pytest.param(*assignment_case_7()),
pytest.param(*assignment_case_8()),
pytest.param(*assignment_case_9()),
pytest.param(*assignment_case_10()),
pytest.param(*assignment_case_11()),
pytest.param(*assignment_case_12()),
pytest.param(*assignment_case_13()),
pytest.param(*assignment_case_14()),
pytest.param(*assignment_case_15()),
pytest.param(*assignment_case_16()),
pytest.param(*assignment_case_17()),
pytest.param(*assignment_case_18()),
pytest.param(*assignment_case_19()),
pytest.param(*assignment_case_20()),
pytest.param(*assignment_case_21()),
pytest.param(*assignment_case_22()),
pytest.param(*assignment_case_23()),
pytest.param(*assignment_case_24()),
pytest.param(*assignment_case_25()),
pytest.param(*assignment_case_26()),
pytest.param(*assignment_case_27()),
pytest.param(*assignment_case_28()),
],
)
def test_static_assignment(shape, function, helpers):
"""
Test static assignment.
"""
configuration = helpers.configuration()
compiler = cnp.Compiler(function, {"x": "encrypted"})
inputset = [np.random.randint(0, 2**7, size=shape) for _ in range(100)]
circuit = compiler.compile(inputset, configuration)
sample = np.random.randint(0, 2**7, size=shape)
helpers.check_execution(circuit, function, sample)
def test_bad_static_assignment(helpers):
"""
Test static assingment with bad parameters.
"""
configuration = helpers.configuration()
# with float
# ----------
def f(x):
x[1.5] = 0
return x
compiler = cnp.Compiler(f, {"x": "encrypted"})
inputset = [np.random.randint(0, 2**3, size=(3,)) for _ in range(100)]
with pytest.raises(ValueError) as excinfo:
compiler.compile(inputset, configuration)
assert str(excinfo.value) == "Assigning to '1.5' is not supported"
# with bad slice
# --------------
def g(x):
x[slice(1.5, 2.5, None)] = 0
return x
compiler = cnp.Compiler(g, {"x": "encrypted"})
inputset = [np.random.randint(0, 2**3, size=(3,)) for _ in range(100)]
with pytest.raises(ValueError) as excinfo:
compiler.compile(inputset, configuration)
assert str(excinfo.value) == "Assigning to '1.5:2.5' is not supported"

View File

@@ -11,6 +11,15 @@ import concrete.onnx as connx
# pylint: disable=line-too-long
def assign(x):
"""
Simple assignment to a vector.
"""
x[0] = 0
return x
@pytest.mark.parametrize(
"function,encryption_statuses,inputset,expected_error,expected_message",
[
@@ -386,6 +395,24 @@ Function you are trying to compile cannot be converted to MLIR
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted broadcasting is supported
return %1
""", # noqa: E501
),
pytest.param(
assign,
{"x": "clear"},
[np.random.randint(0, 2, size=(3,)) for _ in range(100)],
RuntimeError,
"""
Function you are trying to compile cannot be converted to MLIR
%0 = x # ClearTensor<uint1, shape=(3,)>
%1 = 0 # ClearScalar<uint1>
%2 = (%0[0] = %1) # ClearTensor<uint1, shape=(3,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only assignment to encrypted tensors are supported
return %2
""", # noqa: E501
),
],

View File

@@ -206,6 +206,17 @@ def test_node_bad_call(node, args, expected_error, expected_message):
["%0", "%1", "%2", "%3"],
"array([[%0, %1], [%2, %3]])",
),
pytest.param(
Node.generic(
name="assign.static",
inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3, 4))],
output=EncryptedTensor(UnsignedInteger(3), shape=(3, 4)),
operation=lambda *args: args,
kwargs={"index": (1, 2)},
),
["%0", "%1"],
"(%0[1, 2] = %1)",
),
],
)
def test_node_format(node, predecessors, expected_result):
@@ -253,6 +264,26 @@ def test_node_format(node, predecessors, expected_result):
),
"concatenate",
),
pytest.param(
Node.generic(
name="index.static",
inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3, 4))],
output=EncryptedTensor(UnsignedInteger(3), shape=()),
operation=lambda *args: args,
kwargs={"index": (1, 2)},
),
"□[1, 2]",
),
pytest.param(
Node.generic(
name="assign.static",
inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3, 4))],
output=EncryptedTensor(UnsignedInteger(3), shape=(3, 4)),
operation=lambda *args: args,
kwargs={"index": (1, 2)},
),
"□[1, 2] = □",
),
],
)
def test_node_label(node, expected_result):