feat: add table deduplication to NPMLIRConverter

closes #560
closes #561
This commit is contained in:
Arthur Meyre
2021-10-19 16:14:46 +02:00
parent 76d6f1e1f1
commit fbfaeb2b17
3 changed files with 177 additions and 3 deletions

View File

@@ -181,7 +181,18 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info):
x_node = preds[0]
x = ir_to_mlir_node[x_node]
table = additional_conversion_info["tables"][node]
tables = additional_conversion_info["tables"][node]
# TODO: #559 adapt the code to support multi TLUs
# This cannot be reached today as compilation fails if the intermediate values are not all
# scalars
if len(tables) > 1: # pragma: no cover
raise RuntimeError(
"MLIR conversion currently does not support multiple test vectors for LUT"
)
table = tables[0][0]
out_dtype = cast(Integer, node.outputs[0].dtype)
# Create table
dense_elem = DenseElementsAttr.get(np.array(table, dtype=np.uint64), context=ctx)

View File

@@ -1,14 +1,71 @@
"""Numpy-specific MLIR converter."""
from typing import Any, Dict
import math
from collections import defaultdict
from itertools import product
from typing import Any, DefaultDict, Dict, List, Tuple
import numpy
from ..common.debugging import assert_true
from ..common.mlir.mlir_converter import MLIRConverter
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import UnivariateFunction
class HashableNPArray:
    """Thin wrapper that gives a numpy array value-based hashing semantics.

    Two wrappers compare equal iff their arrays hold equal contents, and the
    hash is derived from the raw bytes of the array. Do not mutate the wrapped
    array after it has been hashed: the previously computed hash would no
    longer correspond to the new content.
    """

    # The wrapped numpy array; treat as immutable once hashed.
    array: numpy.ndarray

    def __init__(self, array: numpy.ndarray) -> None:
        self.array = array

    def __hash__(self) -> int:
        # Hash the underlying byte buffer so equal contents hash equally.
        return hash(self.array.tobytes())

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, HashableNPArray):
            return False
        return numpy.array_equal(self.array, other.array)
def generate_deduplicated_tables(
    node: UnivariateFunction,
) -> Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]:
    """Group the cells of a tensor by the lookup table they share.

    Args:
        node (UnivariateFunction): the node whose per-cell tables are deduplicated

    Returns:
        Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]: one entry per
        distinct table; each entry pairs the table with the list of cell
        indices of the tensor that use it.
    """
    # Stack the per-cell tables along a new trailing axis, so that indexing with
    # a cell index yields the complete lookup table of that cell.
    stacked_tables = numpy.concatenate(
        tuple(numpy.expand_dims(array, -1) for array in node.get_table()), axis=-1
    )

    cell_shape = stacked_tables.shape[:-1]
    tables_to_cell_idx: DefaultDict[HashableNPArray, List[Tuple[int, ...]]] = defaultdict(list)
    cell_index: Tuple[int, ...]
    seen_indices = set()
    for cell_index in product(*(range(dim) for dim in cell_shape)):
        # HashableNPArray makes equal tables collide in the dict, grouping the
        # cell indices that share a table.
        tables_to_cell_idx[HashableNPArray(stacked_tables[cell_index])].append(cell_index)
        seen_indices.add(cell_index)

    # Sanity check: every cell of the tensor was visited exactly once.
    assert_true(len(seen_indices) == math.prod(cell_shape))

    return tuple(
        (wrapper.array, cell_indices) for wrapper, cell_indices in tables_to_cell_idx.items()
    )
class NPMLIRConverter(MLIRConverter):
"""Numpy-specific MLIR converter."""
@@ -28,7 +85,7 @@ class NPMLIRConverter(MLIRConverter):
# Disable numpy warnings during conversion to avoid issues during TLU generation
with numpy.errstate(all="ignore"):
additional_conversion_info["tables"] = {
node: node.get_table()
node: generate_deduplicated_tables(node)
for node in op_graph.graph.nodes()
if isinstance(node, UnivariateFunction)
}

View File

@@ -0,0 +1,106 @@
"""Test file for numpy mlir converter"""
import math
import numpy
import pytest
import concrete.numpy as hnp
from concrete.common.representation.intermediate import UnivariateFunction
from concrete.numpy.np_mlir_converter import generate_deduplicated_tables
def multi_tlu_func(x, cst):
    """Add a (tensor) constant and cast to int32, forcing a (multi) TLU."""
    return (x + cst).astype(numpy.int32)
# Presumably the shape of the biggest intermediate tensor of a ResNet model
# (per the name — TODO confirm); used as a large stress-case input shape below.
RESNET_BIGGEST_SHAPE = (64, 112, 112)
# Number of cells in a tensor of that shape (64 * 112 * 112 = 802816).
RESNET_BIGGEST_SIZE = math.prod(RESNET_BIGGEST_SHAPE)
@pytest.mark.parametrize(
    "function,expected_number_of_tables",
    [
        # Constant zero offset: every cell shares the same table.
        (
            lambda x: multi_tlu_func(x, numpy.zeros(RESNET_BIGGEST_SHAPE, dtype=numpy.float64)),
            1,
        ),
        # Distinct offset per cell: every cell gets its own table.
        (
            lambda x: multi_tlu_func(
                x,
                numpy.arange(RESNET_BIGGEST_SIZE, dtype=numpy.float64).reshape(
                    RESNET_BIGGEST_SHAPE
                ),
            ),
            RESNET_BIGGEST_SIZE,
        ),
    ],
)
def test_generate_deduplicated_tables(
    function, expected_number_of_tables, default_compilation_configuration
):
    """Check that generate_deduplicated_tables finds the expected number of tables."""
    inputset = ((i * numpy.ones(RESNET_BIGGEST_SHAPE, dtype=numpy.int32),) for i in range(128))
    op_graph = hnp.compile_numpy_function_into_op_graph(
        function,
        {"x": hnp.EncryptedTensor(hnp.Integer(7, False), RESNET_BIGGEST_SHAPE)},
        inputset,
        default_compilation_configuration,
    )

    tlu_nodes = [
        candidate
        for candidate in op_graph.graph.nodes()
        if isinstance(candidate, UnivariateFunction)
    ]
    # The compiled graph is expected to contain exactly one TLU node.
    assert len(tlu_nodes) == 1

    (tlu_node,) = tlu_nodes
    assert len(generate_deduplicated_tables(tlu_node)) == expected_number_of_tables
def test_deduplicated_tables_correctness(default_compilation_configuration):
    """Verify that deduplication yields exactly the expected (table, cells) pairs."""
    tensor_shape = (2, 2)

    # Adding arange(4) gives every cell a distinct offset, hence a distinct table.
    op_graph = hnp.compile_numpy_function_into_op_graph(
        lambda x: multi_tlu_func(x, numpy.arange(4, dtype=numpy.float64).reshape(tensor_shape)),
        {"x": hnp.EncryptedTensor(hnp.Integer(2, False), tensor_shape)},
        ((i * numpy.ones(tensor_shape, dtype=numpy.int32),) for i in range(4)),
        default_compilation_configuration,
    )

    tlu_nodes = [
        candidate
        for candidate in op_graph.graph.nodes()
        if isinstance(candidate, UnivariateFunction)
    ]
    # The compiled graph is expected to contain exactly one TLU node.
    assert len(tlu_nodes) == 1

    deduplication_result = generate_deduplicated_tables(tlu_nodes[0])

    # The cell unraveled from flat index i holds the table [i, i + 1, i + 2, i + 3].
    expected_result = tuple(
        (
            numpy.arange(i, 4 + i, dtype=numpy.int32),
            [
                numpy.unravel_index(i, tensor_shape),
            ],
        )
        for i in range(4)
    )
    assert len(deduplication_result) == len(expected_result)

    # Ordering of the deduplicated entries is unspecified, so match each
    # computed entry against the expected set instead of comparing positionally.
    for computed_array, computed_idx in deduplication_result:
        found = any(
            numpy.array_equal(computed_array, expected_array) and computed_idx == expected_idx
            for expected_array, expected_idx in expected_result
        )
        if not found:
            raise AssertionError(
                f"Could not find {(computed_array, computed_idx)} "
                f"in expected_result: {expected_result}"
            )