diff --git a/concrete/common/mlir/converters.py b/concrete/common/mlir/converters.py
index 915b9dc4c..fe691c1d6 100644
--- a/concrete/common/mlir/converters.py
+++ b/concrete/common/mlir/converters.py
@@ -181,7 +181,18 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info):
     x_node = preds[0]
     x = ir_to_mlir_node[x_node]
 
-    table = additional_conversion_info["tables"][node]
+    tables = additional_conversion_info["tables"][node]
+
+    # TODO: #559 adapt the code to support multi TLUs
+    # This cannot be reached today as compilation fails if the intermediate values are not all
+    # scalars
+    if len(tables) > 1:  # pragma: no cover
+        raise RuntimeError(
+            "MLIR conversion currently does not support multiple test vectors for LUT"
+        )
+
+    table = tables[0][0]
+
     out_dtype = cast(Integer, node.outputs[0].dtype)
     # Create table
     dense_elem = DenseElementsAttr.get(np.array(table, dtype=np.uint64), context=ctx)
diff --git a/concrete/numpy/np_mlir_converter.py b/concrete/numpy/np_mlir_converter.py
index 320a76a4c..8e2848fc7 100644
--- a/concrete/numpy/np_mlir_converter.py
+++ b/concrete/numpy/np_mlir_converter.py
@@ -1,14 +1,71 @@
 """Numpy-specific MLIR converter."""
 
-from typing import Any, Dict
+import math
+from collections import defaultdict
+from itertools import product
+from typing import Any, DefaultDict, Dict, List, Tuple
 
 import numpy
 
+from ..common.debugging import assert_true
 from ..common.mlir.mlir_converter import MLIRConverter
 from ..common.operator_graph import OPGraph
 from ..common.representation.intermediate import UnivariateFunction
 
 
+class HashableNPArray:
+    """Class to easily manipulate numpy arrays for hashing.
+
+    Note that the hash behavior won't work if the array is modified after being hashed, as it will
+    have been hashed to a certain value and the new array content will be hashed to a different one.
+    """
+
+    array: numpy.ndarray
+
+    def __init__(self, array: numpy.ndarray) -> None:
+        self.array = array
+
+    def __hash__(self) -> int:
+        return hash(self.array.tobytes())
+
+    def __eq__(self, other: object) -> bool:
+        return isinstance(other, HashableNPArray) and numpy.array_equal(self.array, other.array)
+
+
+def generate_deduplicated_tables(
+    node: UnivariateFunction,
+) -> Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]:
+    """Deduplicate the tables for the different cells of a tensor if needed.
+
+    Args:
+        node (UnivariateFunction): the node for which to deduplicate the table
+
+    Returns:
+        Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]: A tuple containing tuples whose
+            first element is a table and the second element is a list of tuples indicating which
+            cells in the tensor will use that table.
+    """
+    # This is the tensor containing the tables for each cell of the tensor for node
+    node_complete_table = numpy.concatenate(
+        tuple(numpy.expand_dims(array, -1) for array in node.get_table()), axis=-1
+    )
+
+    all_cells_idx = product(*tuple(range(max_val) for max_val in node_complete_table.shape[:-1]))
+    tables_to_cell_idx: DefaultDict[HashableNPArray, List[Tuple[int, ...]]] = defaultdict(list)
+    idx: Tuple[int, ...]
+    all_idx_set = set()
+    for idx in all_cells_idx:
+        hashable_array = HashableNPArray(node_complete_table[idx])
+        tables_to_cell_idx[hashable_array].append(idx)
+        all_idx_set.add(idx)
+
+    assert_true(len(all_idx_set) == math.prod(node_complete_table.shape[:-1]))
+
+    return tuple(
+        (hashable_array.array, indices) for hashable_array, indices in tables_to_cell_idx.items()
+    )
+
+
 class NPMLIRConverter(MLIRConverter):
     """Numpy-specific MLIR converter."""
 
@@ -28,7 +85,7 @@ class NPMLIRConverter(MLIRConverter):
         # Disable numpy warnings during conversion to avoid issues during TLU generation
         with numpy.errstate(all="ignore"):
             additional_conversion_info["tables"] = {
-                node: node.get_table()
+                node: generate_deduplicated_tables(node)
                 for node in op_graph.graph.nodes()
                 if isinstance(node, UnivariateFunction)
             }
diff --git a/tests/numpy/test_np_mlir_converter.py b/tests/numpy/test_np_mlir_converter.py
new file mode 100644
index 000000000..9edeed8e8
--- /dev/null
+++ b/tests/numpy/test_np_mlir_converter.py
@@ -0,0 +1,106 @@
+"""Test file for numpy mlir converter"""
+
+import math
+
+import numpy
+import pytest
+
+import concrete.numpy as hnp
+from concrete.common.representation.intermediate import UnivariateFunction
+from concrete.numpy.np_mlir_converter import generate_deduplicated_tables
+
+
+def multi_tlu_func(x, cst):
+    """Multi TLU function"""
+    y = x + cst
+    return y.astype(numpy.int32)
+
+
+RESNET_BIGGEST_SHAPE = (64, 112, 112)
+RESNET_BIGGEST_SIZE = math.prod(RESNET_BIGGEST_SHAPE)
+
+
+@pytest.mark.parametrize(
+    "function,expected_number_of_tables",
+    [
+        (
+            lambda x: multi_tlu_func(x, numpy.zeros(RESNET_BIGGEST_SHAPE, dtype=numpy.float64)),
+            1,
+        ),
+        (
+            lambda x: multi_tlu_func(
+                x,
+                numpy.arange(RESNET_BIGGEST_SIZE, dtype=numpy.float64).reshape(
+                    RESNET_BIGGEST_SHAPE
+                ),
+            ),
+            RESNET_BIGGEST_SIZE,
+        ),
+    ],
+)
+def test_generate_deduplicated_tables(
+    function, expected_number_of_tables, default_compilation_configuration
+):
+    """Test function for generate_deduplicated_tables"""
+    op_graph = hnp.compile_numpy_function_into_op_graph(
+        function,
+        {"x": hnp.EncryptedTensor(hnp.Integer(7, False), RESNET_BIGGEST_SHAPE)},
+        ((i * numpy.ones(RESNET_BIGGEST_SHAPE, dtype=numpy.int32),) for i in range(128)),
+        default_compilation_configuration,
+    )
+
+    univariate_function_nodes = [
+        node for node in op_graph.graph.nodes() if isinstance(node, UnivariateFunction)
+    ]
+
+    assert len(univariate_function_nodes) == 1
+
+    tlu_node = univariate_function_nodes[0]
+
+    deduplication_result = generate_deduplicated_tables(tlu_node)
+
+    assert len(deduplication_result) == expected_number_of_tables
+
+
+def test_deduplicated_tables_correctness(default_compilation_configuration):
+    """Check the deduplicated tables are the expected ones"""
+
+    tensor_shape = (2, 2)
+
+    op_graph = hnp.compile_numpy_function_into_op_graph(
+        lambda x: multi_tlu_func(x, numpy.arange(4, dtype=numpy.float64).reshape(tensor_shape)),
+        {"x": hnp.EncryptedTensor(hnp.Integer(2, False), tensor_shape)},
+        ((i * numpy.ones(tensor_shape, dtype=numpy.int32),) for i in range(4)),
+        default_compilation_configuration,
+    )
+
+    univariate_function_nodes = [
+        node for node in op_graph.graph.nodes() if isinstance(node, UnivariateFunction)
+    ]
+
+    assert len(univariate_function_nodes) == 1
+
+    tlu_node = univariate_function_nodes[0]
+
+    deduplication_result = generate_deduplicated_tables(tlu_node)
+
+    expected_result = tuple(
+        (
+            numpy.arange(i, 4 + i, dtype=numpy.int32),
+            [
+                numpy.unravel_index(i, tensor_shape),
+            ],
+        )
+        for i in range(4)
+    )
+
+    assert len(deduplication_result) == len(expected_result)
+    for computed_array, computed_idx in deduplication_result:
+        for expected_array, expected_idx in expected_result:
+            if numpy.array_equal(computed_array, expected_array) and computed_idx == expected_idx:
+                break
+        else:
+            raise AssertionError(
+                f"Could not find {(computed_array, computed_idx)} "
+                f"in expected_result: {expected_result}"
+            )
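For readers skimming the patch, here is a minimal standalone sketch of the deduplication scheme the new `generate_deduplicated_tables` helper applies: the per-cell tables are grouped by hashing their raw bytes, so only one LUT per distinct table has to be lowered to MLIR. The `HashableArray` and `deduplicate_tables` names below are illustrative only and not part of `concrete`; it is a sketch of the idea, not the library code.

```python
from collections import defaultdict
from itertools import product
from typing import DefaultDict, List, Tuple

import numpy


class HashableArray:
    """Wrap an ndarray so it can be used as a dict key (do not mutate it after hashing)."""

    def __init__(self, array: numpy.ndarray) -> None:
        self.array = array

    def __hash__(self) -> int:
        return hash(self.array.tobytes())

    def __eq__(self, other: object) -> bool:
        return isinstance(other, HashableArray) and numpy.array_equal(self.array, other.array)


def deduplicate_tables(
    per_cell_tables: numpy.ndarray,
) -> Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]:
    """Group cells sharing a table; input has shape (*tensor_shape, table_size)."""
    groups: DefaultDict[HashableArray, List[Tuple[int, ...]]] = defaultdict(list)
    # Iterate over every cell index of the tensor (all axes except the last, table axis)
    for idx in product(*(range(dim) for dim in per_cell_tables.shape[:-1])):
        groups[HashableArray(per_cell_tables[idx])].append(idx)
    return tuple((key.array, cells) for key, cells in groups.items())


# All four cells of a (2, 2) tensor share the same 8-entry table, so one table remains.
tables = numpy.tile(numpy.arange(8), (2, 2, 1))
result = deduplicate_tables(tables)
assert len(result) == 1 and len(result[0][1]) == 4
```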