diff --git a/concrete/numpy/mlir/graph_converter.py b/concrete/numpy/mlir/graph_converter.py
index aab61e13d..ae77b7c34 100644
--- a/concrete/numpy/mlir/graph_converter.py
+++ b/concrete/numpy/mlir/graph_converter.py
@@ -26,7 +26,7 @@ from mlir.ir import (
 from ..dtypes import Integer, SignedInteger
 from ..internal.utils import assert_that
 from ..representation import Graph, Node, Operation
-from ..values import ClearScalar
+from ..values import ClearScalar, EncryptedScalar
 from .node_converter import NodeConverter
 from .utils import MAXIMUM_BIT_WIDTH
 
@@ -88,6 +88,10 @@ class GraphConverter:
             assert_that(len(inputs) > 0)
             assert_that(all(input.is_scalar for input in inputs))
 
+        elif name == "assign.static":
+            if not inputs[0].is_encrypted:
+                return "only assignment to encrypted tensors are supported"
+
         elif name == "broadcast_to":
             assert_that(len(inputs) == 1)
             if not inputs[0].is_encrypted:
@@ -303,6 +307,116 @@ class GraphConverter:
 
                 nx_graph.add_edge(add_offset, node, input_idx=variable_input_index)
 
+    @staticmethod
+    def _broadcast_assignments(graph: Graph):
+        """
+        Broadcast assignments.
+
+        Args:
+            graph (Graph):
+                computation graph to transform
+        """
+
+        nx_graph = graph.graph
+        for node in list(nx_graph.nodes):
+            if node.operation == Operation.Generic and node.properties["name"] == "assign.static":
+                shape = node.inputs[0].shape
+                index = node.properties["kwargs"]["index"]
+
+                assert_that(isinstance(index, tuple))
+                while len(index) < len(shape):
+                    index = (*index, slice(None, None, None))
+
+                required_value_shape_list = []
+
+                for i, indexing_element in enumerate(index):
+                    if isinstance(indexing_element, slice):
+                        n = len(np.zeros(shape[i])[indexing_element])
+                        required_value_shape_list.append(n)
+                    else:
+                        required_value_shape_list.append(1)
+
+                required_value_shape = tuple(required_value_shape_list)
+                actual_value_shape = node.inputs[1].shape
+
+                if required_value_shape != actual_value_shape:
+                    preds = graph.ordered_preds_of(node)
+                    pred_to_modify = preds[1]
+
+                    modified_value = deepcopy(pred_to_modify.output)
+                    modified_value.shape = required_value_shape
+
+                    try:
+                        np.broadcast_to(np.zeros(actual_value_shape), required_value_shape)
+                        modified_value.is_encrypted = True
+                        modified_value.dtype = node.output.dtype
+                        modified_pred = Node.generic(
+                            "broadcast_to",
+                            [pred_to_modify.output],
+                            modified_value,
+                            np.broadcast_to,
+                            kwargs={"shape": required_value_shape},
+                        )
+                    except Exception:  # pylint: disable=broad-except
+                        np.reshape(np.zeros(actual_value_shape), required_value_shape)
+                        modified_pred = Node.generic(
+                            "reshape",
+                            [pred_to_modify.output],
+                            modified_value,
+                            np.reshape,
+                            kwargs={"newshape": required_value_shape},
+                        )
+
+                    nx_graph.add_edge(pred_to_modify, modified_pred, input_idx=0)
+
+                    nx_graph.remove_edge(pred_to_modify, node)
+                    nx_graph.add_edge(modified_pred, node, input_idx=1)
+
+                    node.inputs[1] = modified_value
+
+    @staticmethod
+    def _encrypt_clear_assignments(graph: Graph):
+        """
+        Encrypt clear assignments.
+
+        Args:
+            graph (Graph):
+                computation graph to transform
+        """
+
+        nx_graph = graph.graph
+        for node in list(nx_graph.nodes):
+            if node.operation == Operation.Generic and node.properties["name"] == "assign.static":
+                assigned_value = node.inputs[1]
+                if assigned_value.is_clear:
+                    preds = graph.ordered_preds_of(node)
+                    assigned_pred = preds[1]
+
+                    new_assigned_pred_value = deepcopy(assigned_value)
+                    new_assigned_pred_value.is_encrypted = True
+                    new_assigned_pred_value.dtype = preds[0].output.dtype
+
+                    zero = Node.generic(
+                        "zeros",
+                        [],
+                        EncryptedScalar(new_assigned_pred_value.dtype),
+                        lambda: np.zeros((), dtype=np.int64),
+                    )
+
+                    new_assigned_pred = Node.generic(
+                        "add",
+                        [assigned_pred.output, zero.output],
+                        new_assigned_pred_value,
+                        np.add,
+                    )
+
+                    nx_graph.remove_edge(preds[1], node)
+
+                    nx_graph.add_edge(preds[1], new_assigned_pred, input_idx=0)
+                    nx_graph.add_edge(zero, new_assigned_pred, input_idx=1)
+
+                    nx_graph.add_edge(new_assigned_pred, node, input_idx=1)
+
     @staticmethod
     def _tensorize_scalars_for_fhelinalg(graph: Graph):
         """
@@ -462,6 +576,8 @@ class GraphConverter:
 
         GraphConverter._update_bit_widths(graph)
         GraphConverter._offset_negative_lookup_table_inputs(graph)
+        GraphConverter._broadcast_assignments(graph)
+        GraphConverter._encrypt_clear_assignments(graph)
         GraphConverter._tensorize_scalars_for_fhelinalg(graph)
 
         from_elements_operations: Dict[OpResult, List[OpResult]] = {}
diff --git a/concrete/numpy/mlir/node_converter.py b/concrete/numpy/mlir/node_converter.py
index bcadb5448..505cf9eac 100644
--- a/concrete/numpy/mlir/node_converter.py
+++ b/concrete/numpy/mlir/node_converter.py
@@ -157,6 +157,9 @@ class NodeConverter:
         if name == "add":
             result = self._convert_add()
 
+        elif name == "assign.static":
+            result = self._convert_static_assignment()
+
         elif name == "array":
             result = self._convert_array()
 
@@ -716,6 +719,68 @@ class NodeConverter:
             ),
         ).result
 
+    def _convert_static_assignment(self) -> OpResult:
+        """
+        Convert "assign.static" node to its corresponding MLIR representation.
+
+        Returns:
+            OpResult:
+                in-memory MLIR representation corresponding to `self.node`
+        """
+
+        input_value = self.node.inputs[0]
+        input_shape = input_value.shape
+
+        index = list(self.node.properties["kwargs"]["index"])
+
+        while len(index) < input_value.ndim:
+            index.append(slice(None, None, None))
+
+        output_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output)
+
+        offsets = []
+        sizes = []
+        strides = []
+
+        for indexing_element, dimension_size in zip(index, input_shape):
+
+            if isinstance(indexing_element, slice):
+                size = np.zeros(dimension_size)[indexing_element].shape[0]
+                stride = indexing_element.step if isinstance(indexing_element.step, int) else 1
+                offset = (
+                    (
+                        indexing_element.start
+                        if indexing_element.start >= 0
+                        else indexing_element.start + dimension_size
+                    )
+                    if isinstance(indexing_element.start, int)
+                    else (0 if stride > 0 else dimension_size - 1)
+                )
+
+            else:
+                size = 1
+                stride = 1
+                offset = int(
+                    indexing_element if indexing_element >= 0 else indexing_element + dimension_size
+                )
+
+            offsets.append(offset)
+            sizes.append(size)
+            strides.append(stride)
+
+        i64_type = IntegerType.get_signless(64)
+        return tensor.InsertSliceOp(
+            output_type,
+            self.preds[1],
+            self.preds[0],
+            [],
+            [],
+            [],
+            ArrayAttr.get([IntegerAttr.get(i64_type, value) for value in offsets]),
+            ArrayAttr.get([IntegerAttr.get(i64_type, value) for value in sizes]),
+            ArrayAttr.get([IntegerAttr.get(i64_type, value) for value in strides]),
+        ).result
+
     def _convert_static_indexing(self) -> OpResult:
         """
         Convert "index.static" node to its corresponding MLIR representation.
diff --git a/concrete/numpy/representation/node.py b/concrete/numpy/representation/node.py
index 5f2b6de00..e9423c15e 100644
--- a/concrete/numpy/representation/node.py
+++ b/concrete/numpy/representation/node.py
@@ -252,6 +252,11 @@ class Node:
             elements = [format_indexing_element(element) for element in index]
             return f"{predecessors[0]}[{', '.join(elements)}]"
 
+        if name == "assign.static":
+            index = self.properties["kwargs"]["index"]
+            elements = [format_indexing_element(element) for element in index]
+            return f"({predecessors[0]}[{', '.join(elements)}] = {predecessors[1]})"
+
         if name == "concatenate":
             args = [f"({', '.join(predecessors)})"]
         else:
@@ -292,7 +297,14 @@ class Node:
         assert_that(self.operation == Operation.Generic)
 
         name = self.properties["name"]
-        return name if name != "index.static" else self.format(["□"])
+
+        if name == "index.static":
+            name = self.format(["□"])
+
+        if name == "assign.static":
+            name = self.format(["□", "□"])[1:-1]
+
+        return name
 
     @property
     def converted_to_table_lookup(self) -> bool:
@@ -307,6 +319,7 @@ class Node:
         return self.operation == Operation.Generic and self.properties["name"] not in [
             "add",
             "array",
+            "assign.static",
             "broadcast_to",
             "concatenate",
             "conv1d",
diff --git a/concrete/numpy/tracing/tracer.py b/concrete/numpy/tracing/tracer.py
index 606c4801e..b98aa0987 100644
--- a/concrete/numpy/tracing/tracer.py
+++ b/concrete/numpy/tracing/tracer.py
@@ -26,6 +26,9 @@ class Tracer:
     input_tracers: List["Tracer"]
     output: Value
 
+    # property to keep track of assignments
+    last_version: Optional["Tracer"] = None
+
     # variable to control the behavior of __eq__
     # so that it can be traced but still allow
     # using Tracers in dicts when not tracing
@@ -71,6 +74,12 @@ class Tracer:
         if not isinstance(output_tracers, tuple):
             output_tracers = (output_tracers,)
 
+        output_tracer_list = list(output_tracers)
+        for i, output_tracer in enumerate(output_tracer_list):
+            if isinstance(output_tracer, Tracer) and output_tracer.last_version is not None:
+                output_tracer_list[i] = output_tracer.last_version
+        output_tracers = tuple(output_tracer_list)
+
         sanitized_tracers = []
         for tracer in output_tracers:
             if isinstance(tracer, Tracer):
@@ -145,6 +154,9 @@ class Tracer:
         self.input_tracers = input_tracers
         self.output = computation.output
 
+        for i, tracer in enumerate(self.input_tracers):
+            self.input_tracers[i] = tracer if tracer.last_version is None else tracer.last_version
+
     def __hash__(self) -> int:
         return id(self)
 
@@ -671,6 +683,57 @@ class Tracer:
         )
         return Tracer(computation, [self])
 
+    def __setitem__(
+        self,
+        index: Union[int, np.integer, slice, Tuple[Union[int, np.integer, slice], ...]],
+        value: Any,
+    ):
+        if not isinstance(index, tuple):
+            index = (index,)
+
+        for indexing_element in index:
+            valid = isinstance(indexing_element, (int, np.integer, slice))
+
+            if isinstance(indexing_element, slice):
+                if (
+                    not (
+                        indexing_element.start is None
+                        or isinstance(indexing_element.start, (int, np.integer))
+                    )
+                    or not (
+                        indexing_element.stop is None
+                        or isinstance(indexing_element.stop, (int, np.integer))
+                    )
+                    or not (
+                        indexing_element.step is None
+                        or isinstance(indexing_element.step, (int, np.integer))
+                    )
+                ):
+                    valid = False
+
+            if not valid:
+                raise ValueError(
+                    f"Assigning to '{format_indexing_element(indexing_element)}' is not supported"
+                )
+
+        np.zeros(self.output.shape)[index] = 1
+
+        def assign(x, value, index):
+            x[index] = value
+            return x
+
+        sanitized_value = self.sanitize(value)
+        computation = Node.generic(
+            "assign.static",
+            [self.output, sanitized_value.output],
+            self.output,
+            assign,
+            kwargs={"index": index},
+        )
+        new_version = Tracer(computation, [self, sanitized_value])
+
+        self.last_version = new_version
+
     @property
     def shape(self) -> Tuple[int, ...]:
         """
diff --git a/tests/execution/test_static_assignment.py b/tests/execution/test_static_assignment.py
new file mode 100644
index 000000000..2f98076b7
--- /dev/null
+++ b/tests/execution/test_static_assignment.py
@@ -0,0 +1,530 @@
+"""
+Tests of execution of static assignment operation.
+"""
+
+import numpy as np
+import pytest
+
+import concrete.numpy as cnp
+
+
+def assignment_case_0():
+    """
+    Assignment test case.
+    """
+
+    shape = (3,)
+    value = np.random.randint(0, 2**7, size=())
+
+    def assign(x):
+        x[:] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_1():
+    """
+    Assignment test case.
+    """
+
+    shape = (3,)
+    value = np.random.randint(0, 2**7, size=())
+
+    def assign(x):
+        x[0] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_2():
+    """
+    Assignment test case.
+    """
+
+    shape = (3,)
+    value = np.random.randint(0, 2**7, size=())
+
+    def assign(x):
+        x[1] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_3():
+    """
+    Assignment test case.
+    """
+
+    shape = (3,)
+    value = np.random.randint(0, 2**7, size=())
+
+    def assign(x):
+        x[2] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_4():
+    """
+    Assignment test case.
+    """
+
+    shape = (5,)
+    value = np.random.randint(0, 2**7, size=())
+
+    def assign(x):
+        x[0:3] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_5():
+    """
+    Assignment test case.
+    """
+
+    shape = (5,)
+    value = np.random.randint(0, 2**7, size=())
+
+    def assign(x):
+        x[1:4] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_6():
+    """
+    Assignment test case.
+    """
+
+    shape = (5,)
+    value = np.random.randint(0, 2**7, size=())
+
+    def assign(x):
+        x[1:4:2] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_7():
+    """
+    Assignment test case.
+    """
+
+    shape = (10,)
+    value = np.random.randint(0, 2**7, size=())
+
+    def assign(x):
+        x[::2] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_8():
+    """
+    Assignment test case.
+    """
+
+    shape = (5,)
+    value = np.random.randint(0, 2**7, size=())
+
+    def assign(x):
+        x[2:0:-1] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_9():
+    """
+    Assignment test case.
+    """
+
+    shape = (5,)
+    value = np.random.randint(0, 2**7, size=())
+
+    def assign(x):
+        x[4:0:-2] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_10():
+    """
+    Assignment test case.
+    """
+
+    shape = (5,)
+    value = np.random.randint(0, 2**7, size=(3,))
+
+    def assign(x):
+        x[1:4] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_11():
+    """
+    Assignment test case.
+    """
+
+    shape = (5,)
+    value = np.random.randint(0, 2**7, size=(3,))
+
+    def assign(x):
+        x[4:1:-1] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_12():
+    """
+    Assignment test case.
+    """
+
+    shape = (10,)
+    value = np.random.randint(0, 2**7, size=(3,))
+
+    def assign(x):
+        x[1:7:2] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_13():
+    """
+    Assignment test case.
+    """
+
+    shape = (10,)
+    value = np.random.randint(0, 2**7, size=(3,))
+
+    def assign(x):
+        x[7:1:-2] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_14():
+    """
+    Assignment test case.
+    """
+
+    shape = (5, 4)
+    value = np.random.randint(0, 2**7, size=())
+
+    def assign(x):
+        x[0, 0] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_15():
+    """
+    Assignment test case.
+    """
+
+    shape = (5, 4)
+    value = np.random.randint(0, 2**7, size=())
+
+    def assign(x):
+        x[3, 1] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_16():
+    """
+    Assignment test case.
+    """
+
+    shape = (5, 4)
+    value = np.random.randint(0, 2**7, size=())
+
+    def assign(x):
+        x[0] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_17():
+    """
+    Assignment test case.
+    """
+
+    shape = (5, 4)
+    value = np.random.randint(0, 2**7, size=(4,))
+
+    def assign(x):
+        x[0] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_18():
+    """
+    Assignment test case.
+    """
+
+    shape = (5, 4)
+    value = np.random.randint(0, 2**7, size=(5,))
+
+    def assign(x):
+        x[:, 0] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_19():
+    """
+    Assignment test case.
+    """
+
+    shape = (5, 4)
+    value = np.random.randint(0, 2**7, size=(5,))
+
+    def assign(x):
+        x[:, 1] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_20():
+    """
+    Assignment test case.
+    """
+
+    shape = (5, 4)
+    value = np.random.randint(0, 2**7, size=())
+
+    def assign(x):
+        x[0:3, :] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_21():
+    """
+    Assignment test case.
+    """
+
+    shape = (5, 4)
+    value = np.random.randint(0, 2**7, size=(3, 4))
+
+    def assign(x):
+        x[0:3, :] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_22():
+    """
+    Assignment test case.
+    """
+
+    shape = (5, 4)
+    value = np.random.randint(0, 2**7, size=(4,))
+
+    def assign(x):
+        x[0:3, :] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_23():
+    """
+    Assignment test case.
+    """
+
+    shape = (5, 4)
+    value = np.random.randint(0, 2**7, size=(3,))
+
+    def assign(x):
+        x[0:3, 1:4] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_24():
+    """
+    Assignment test case.
+    """
+
+    shape = (5, 4)
+    value = np.random.randint(0, 2**7, size=(3, 3))
+
+    def assign(x):
+        x[0:3, 1:4] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_25():
+    """
+    Assignment test case.
+    """
+
+    shape = (5, 4)
+    value = np.random.randint(0, 2**7, size=(3, 3))
+
+    def assign(x):
+        x[4:1:-1, 3:0:-1] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_26():
+    """
+    Assignment test case.
+    """
+
+    shape = (5, 4)
+    value = np.random.randint(0, 2**7, size=(3,))
+
+    def assign(x):
+        x[3:0:-1, 0] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_27():
+    """
+    Assignment test case.
+    """
+
+    shape = (5, 4)
+    value = np.random.randint(0, 2**7, size=(2,))
+
+    def assign(x):
+        x[0, 1:3] = value
+        return x
+
+    return shape, assign
+
+
+def assignment_case_28():
+    """
+    Assignment test case.
+    """
+
+    shape = (5, 4)
+    value = np.random.randint(0, 2**7, size=())
+
+    def assign(x):
+        x[2:4, 1:3] = value
+        return x
+
+    return shape, assign
+
+
+@pytest.mark.parametrize(
+    "shape,function",
+    [
+        pytest.param(*assignment_case_0()),
+        pytest.param(*assignment_case_1()),
+        pytest.param(*assignment_case_2()),
+        pytest.param(*assignment_case_3()),
+        pytest.param(*assignment_case_4()),
+        pytest.param(*assignment_case_5()),
+        pytest.param(*assignment_case_6()),
+        pytest.param(*assignment_case_7()),
+        pytest.param(*assignment_case_8()),
+        pytest.param(*assignment_case_9()),
+        pytest.param(*assignment_case_10()),
+        pytest.param(*assignment_case_11()),
+        pytest.param(*assignment_case_12()),
+        pytest.param(*assignment_case_13()),
+        pytest.param(*assignment_case_14()),
+        pytest.param(*assignment_case_15()),
+        pytest.param(*assignment_case_16()),
+        pytest.param(*assignment_case_17()),
+        pytest.param(*assignment_case_18()),
+        pytest.param(*assignment_case_19()),
+        pytest.param(*assignment_case_20()),
+        pytest.param(*assignment_case_21()),
+        pytest.param(*assignment_case_22()),
+        pytest.param(*assignment_case_23()),
+        pytest.param(*assignment_case_24()),
+        pytest.param(*assignment_case_25()),
+        pytest.param(*assignment_case_26()),
+        pytest.param(*assignment_case_27()),
+        pytest.param(*assignment_case_28()),
+    ],
+)
+def test_static_assignment(shape, function, helpers):
+    """
+    Test static assignment.
+    """
+
+    configuration = helpers.configuration()
+    compiler = cnp.Compiler(function, {"x": "encrypted"})
+
+    inputset = [np.random.randint(0, 2**7, size=shape) for _ in range(100)]
+    circuit = compiler.compile(inputset, configuration)
+
+    sample = np.random.randint(0, 2**7, size=shape)
+    helpers.check_execution(circuit, function, sample)
+
+
+def test_bad_static_assignment(helpers):
+    """
+    Test static assignment with bad parameters.
+    """
+
+    configuration = helpers.configuration()
+
+    # with float
+    # ----------
+
+    def f(x):
+        x[1.5] = 0
+        return x
+
+    compiler = cnp.Compiler(f, {"x": "encrypted"})
+
+    inputset = [np.random.randint(0, 2**3, size=(3,)) for _ in range(100)]
+    with pytest.raises(ValueError) as excinfo:
+        compiler.compile(inputset, configuration)
+
+    assert str(excinfo.value) == "Assigning to '1.5' is not supported"
+
+    # with bad slice
+    # --------------
+
+    def g(x):
+        x[slice(1.5, 2.5, None)] = 0
+        return x
+
+    compiler = cnp.Compiler(g, {"x": "encrypted"})
+
+    inputset = [np.random.randint(0, 2**3, size=(3,)) for _ in range(100)]
+    with pytest.raises(ValueError) as excinfo:
+        compiler.compile(inputset, configuration)
+
+    assert str(excinfo.value) == "Assigning to '1.5:2.5' is not supported"
diff --git a/tests/mlir/test_graph_converter.py b/tests/mlir/test_graph_converter.py
index b826a090f..db27379b7 100644
--- a/tests/mlir/test_graph_converter.py
+++ b/tests/mlir/test_graph_converter.py
@@ -11,6 +11,15 @@ import concrete.onnx as connx
 # pylint: disable=line-too-long
 
 
+def assign(x):
+    """
+    Simple assignment to a vector.
+    """
+
+    x[0] = 0
+    return x
+
+
 @pytest.mark.parametrize(
     "function,encryption_statuses,inputset,expected_error,expected_message",
     [
@@ -386,6 +395,24 @@ Function you are trying to compile cannot be converted to MLIR
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted broadcasting is supported
 return %1
 
+            """,  # noqa: E501
+        ),
+        pytest.param(
+            assign,
+            {"x": "clear"},
+            [np.random.randint(0, 2, size=(3,)) for _ in range(100)],
+            RuntimeError,
+            """
+
+Function you are trying to compile cannot be converted to MLIR
+
+%0 = x # ClearTensor
+%1 = 0 # ClearScalar
+%2 = (%0[0] = %1) # ClearTensor
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only assignment to encrypted tensors are supported
+return %2
+
             """,  # noqa: E501
         ),
     ],
 )
diff --git a/tests/representation/test_node.py b/tests/representation/test_node.py
index 1ff27f38e..b5a45515e 100644
--- a/tests/representation/test_node.py
+++ b/tests/representation/test_node.py
@@ -206,6 +206,17 @@ def test_node_bad_call(node, args, expected_error, expected_message):
             ["%0", "%1", "%2", "%3"],
             "array([[%0, %1], [%2, %3]])",
         ),
+        pytest.param(
+            Node.generic(
+                name="assign.static",
+                inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3, 4))],
+                output=EncryptedTensor(UnsignedInteger(3), shape=(3, 4)),
+                operation=lambda *args: args,
+                kwargs={"index": (1, 2)},
+            ),
+            ["%0", "%1"],
+            "(%0[1, 2] = %1)",
+        ),
     ],
 )
 def test_node_format(node, predecessors, expected_result):
@@ -253,6 +264,26 @@ def test_node_format(node, predecessors, expected_result):
             ),
             "concatenate",
         ),
+        pytest.param(
+            Node.generic(
+                name="index.static",
+                inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3, 4))],
+                output=EncryptedTensor(UnsignedInteger(3), shape=()),
+                operation=lambda *args: args,
+                kwargs={"index": (1, 2)},
+            ),
+            "□[1, 2]",
+        ),
+        pytest.param(
+            Node.generic(
+                name="assign.static",
+                inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3, 4))],
+                output=EncryptedTensor(UnsignedInteger(3), shape=(3, 4)),
+                operation=lambda *args: args,
+                kwargs={"index": (1, 2)},
+            ),
+            "□[1, 2] = □",
+        ),
     ],
 )
 def test_node_label(node, expected_result):
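Reviewer note (not part of the patch): `_convert_static_assignment` lets NumPy decide how many elements a slice touches (`np.zeros(dimension_size)[indexing_element]`) and derives the offsets, sizes and strides that are handed to `tensor.InsertSliceOp`. The standalone sketch below mirrors that computation outside the compiler so it can be checked against plain NumPy slicing; the helper name `slice_parameters` is made up for this illustration and does not exist in the codebase.

# sketch, illustrative only -- mirrors the offset/size/stride logic of _convert_static_assignment
import numpy as np


def slice_parameters(index, shape):
    """Compute per-dimension (offset, size, stride) the way the MLIR lowering does."""

    index = list(index)
    while len(index) < len(shape):
        # missing trailing indices behave like full slices, as in the tracer and graph converter
        index.append(slice(None, None, None))

    offsets, sizes, strides = [], [], []
    for indexing_element, dimension_size in zip(index, shape):
        if isinstance(indexing_element, slice):
            # let NumPy count how many elements the slice selects
            size = np.zeros(dimension_size)[indexing_element].shape[0]
            stride = indexing_element.step if isinstance(indexing_element.step, int) else 1
            offset = (
                (
                    indexing_element.start
                    if indexing_element.start >= 0
                    else indexing_element.start + dimension_size
                )
                if isinstance(indexing_element.start, int)
                else (0 if stride > 0 else dimension_size - 1)
            )
        else:
            # integer indexing selects a single element; negative indices wrap around
            size = 1
            stride = 1
            offset = int(indexing_element)
            if offset < 0:
                offset += dimension_size
        offsets.append(offset)
        sizes.append(size)
        strides.append(stride)
    return offsets, sizes, strides


# x[4:0:-2] = value on a (5,)-shaped tensor touches indices 4 and 2:
# offset 4, size 2, stride -2 along the only dimension.
print(slice_parameters((slice(4, 0, -2),), (5,)))      # ([4], [2], [-2])

# x[0:3, :] = value on a (5, 4)-shaped tensor; the missing second index is padded.
print(slice_parameters((slice(0, 3, None),), (5, 4)))  # ([0, 0], [3, 4], [1, 1])

The same sizes also explain `_broadcast_assignments`: for `x[0:3, :]` on a `(5, 4)` tensor the required value shape is `(3, 4)`, so a `(4,)`-shaped value gets a `broadcast_to` node inserted in front of the assignment (test case 22), while a `(3, 4)` value is used as is (test case 21).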