diff --git a/concrete/numpy/mlir/graph_converter.py b/concrete/numpy/mlir/graph_converter.py index bae096d3e..bad5a9b13 100644 --- a/concrete/numpy/mlir/graph_converter.py +++ b/concrete/numpy/mlir/graph_converter.py @@ -473,6 +473,7 @@ class GraphConverter: for arg_num, node in graph.input_nodes.items(): ir_to_mlir[node] = sanitized_args[arg_num] + constant_cache = {} for node in nx.topological_sort(graph.graph): if node.operation == Operation.Input: continue @@ -483,6 +484,7 @@ class GraphConverter: graph, node, preds, + constant_cache, direct_replacements, ) ir_to_mlir[node] = node_converter.convert() diff --git a/concrete/numpy/mlir/node_converter.py b/concrete/numpy/mlir/node_converter.py index 036bd82fb..1e088327c 100644 --- a/concrete/numpy/mlir/node_converter.py +++ b/concrete/numpy/mlir/node_converter.py @@ -51,6 +51,7 @@ class NodeConverter: all_of_the_inputs_are_tensors: bool one_of_the_inputs_is_a_tensor: bool + constant_cache: Dict[Tuple[Type, Attribute], OpResult] direct_replacements: Dict[str, str] # pylint: enable=too-many-instance-attributes @@ -112,6 +113,7 @@ class NodeConverter: graph: Graph, node: Node, preds: List[OpResult], + constant_cache: Dict[Tuple[Type, Attribute], OpResult], direct_replacements: Dict[str, str], ): self.ctx = ctx @@ -132,6 +134,7 @@ class NodeConverter: else: self.one_of_the_inputs_is_a_tensor = True + self.constant_cache = constant_cache self.direct_replacements = direct_replacements def convert(self) -> OpResult: @@ -380,7 +383,7 @@ class NodeConverter: attr = Attribute.parse(f"dense<{str(data.tolist())}> : {resulting_type}") - return arith.ConstantOp(resulting_type, attr).result + return self._create_constant(resulting_type, attr).result def _convert_conv1d(self) -> OpResult: """ @@ -556,7 +559,7 @@ class NodeConverter: constant_attr = IntegerAttr.get(constant_type, 1) zero = fhe.ZeroEintOp(resulting_type).result - one = arith.ConstantOp(constant_type, constant_attr).result + one = self._create_constant(constant_type, constant_attr).result result = fhe.AddEintIntOp(resulting_type, zero, one).result else: @@ -569,7 +572,7 @@ class NodeConverter: constant_attr = Attribute.parse(f"dense<[1]> : {constant_type}") zeros = fhe.ZeroTensorOp(resulting_type).result - ones = arith.ConstantOp(constant_type, constant_attr).result + ones = self._create_constant(constant_type, constant_attr).result result = fhelinalg.AddEintIntOp(resulting_type, zeros, ones).result @@ -718,7 +721,7 @@ class NodeConverter: for value, dimension_size in zip(index, input_shape): value = int(value) attr = IntegerAttr.get(index_type, value if value >= 0 else value + dimension_size) - indices.append(arith.ConstantOp(index_type, attr).result) + indices.append(self._create_constant(index_type, attr).result) return tensor.ExtractOp(output_type, self.preds[0], indices).result offsets = [] @@ -920,7 +923,7 @@ class NodeConverter: lut_type = RankedTensorType.get(lut_shape, IntegerType.get_signless(64, context=self.ctx)) lut_attr = DenseElementsAttr.get(lut_values, context=self.ctx) - lut = arith.ConstantOp(lut_type, lut_attr).result + lut = self._create_constant(lut_type, lut_attr).result resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output) pred = self.preds[variable_input_index] @@ -940,7 +943,7 @@ class NodeConverter: resulting_type, pred, lut, - arith.ConstantOp(map_type, map_attr).result, + self._create_constant(map_type, map_attr).result, ).result else: result = fhe.ApplyLookupTableEintOp(resulting_type, pred, lut).result @@ -978,3 +981,10 @@ class NodeConverter: result = fhe.ZeroTensorOp(resulting_type).result return result + + def _create_constant(self, mlir_type: Type, mlir_attribute: Attribute): + result = self.constant_cache.get((mlir_type, mlir_attribute)) + if result is None: + result = arith.ConstantOp(mlir_type, mlir_attribute) + self.constant_cache[(mlir_type, mlir_attribute)] = result + return result diff --git a/tests/mlir/test_graph_converter.py b/tests/mlir/test_graph_converter.py index 8546a5132..f6db931f3 100644 --- a/tests/mlir/test_graph_converter.py +++ b/tests/mlir/test_graph_converter.py @@ -412,4 +412,43 @@ def test_graph_converter_bad_convert( helpers.check_str(expected_message, str(excinfo.value)) +@pytest.mark.parametrize( + "function,inputset,expected_mlir", + [ + pytest.param( + lambda x: 1 + cnp.LookupTable([4, 1, 2, 3])[x] + cnp.LookupTable([4, 1, 2, 3])[x + 1], + range(3), + """ + +module { + func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> { + %c1_i4 = arith.constant 1 : i4 + %cst = arith.constant dense<[4, 1, 2, 3, 3, 3, 3, 3]> : tensor<8xi64> + %0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<3> + %1 = "FHE.add_eint_int"(%arg0, %c1_i4) : (!FHE.eint<3>, i4) -> !FHE.eint<3> + %2 = "FHE.add_eint_int"(%0, %c1_i4) : (!FHE.eint<3>, i4) -> !FHE.eint<3> + %3 = "FHE.apply_lookup_table"(%1, %cst) : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<3> + %4 = "FHE.add_eint"(%2, %3) : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3> + return %4 : !FHE.eint<3> + } +} + + """, # noqa: E501 + # Notice that there is only a single 1 and a single table cst above + ), + ], +) +def test_constant_cache(function, inputset, expected_mlir, helpers): + """ + Test caching MLIR constants. + """ + + configuration = helpers.configuration() + + compiler = cnp.Compiler(function, {"x": "encrypted"}) + circuit = compiler.compile(inputset, configuration) + + helpers.check_str(expected_mlir, circuit.mlir) + + # pylint: enable=line-too-long