Mirror of https://github.com/zama-ai/concrete.git, synced 2026-02-08 19:44:57 -05:00
feat: add table deduplication to NPMLIRConverter
closes #560 closes #561
@@ -181,7 +181,18 @@ def apply_lut(node, preds, ir_to_mlir_node, ctx, additional_conversion_info):
     x_node = preds[0]
     x = ir_to_mlir_node[x_node]
-    table = additional_conversion_info["tables"][node]
+    tables = additional_conversion_info["tables"][node]
+
+    # TODO: #559 adapt the code to support multi TLUs
+    # This cannot be reached today as compilation fails if the intermediate values are not all
+    # scalars
+    if len(tables) > 1:  # pragma: no cover
+        raise RuntimeError(
+            "MLIR conversion currently does not support multiple test vectors for LUT"
+        )
+
+    table = tables[0][0]
+
     out_dtype = cast(Integer, node.outputs[0].dtype)
     # Create table
     dense_elem = DenseElementsAttr.get(np.array(table, dtype=np.uint64), context=ctx)
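For readers skimming the hunk above: after this change, the per-node entry in additional_conversion_info["tables"] is no longer a bare table but a tuple of (table, cell_indices) pairs, which is why the single-TLU path reads tables[0][0]. Below is a minimal sketch of that shape; the values are purely illustrative, not taken from the codebase.

    import numpy as np

    # Illustrative stand-in for additional_conversion_info["tables"][node] after
    # deduplication: one lookup table shared by every cell of a 2x2 tensor.
    tables = (
        (np.array([0, 1, 2, 3], dtype=np.uint64), [(0, 0), (0, 1), (1, 0), (1, 1)]),
    )

    assert len(tables) == 1  # the only case the converter accepts today (see #559)
    table = tables[0][0]     # the table itself; the cell indices are not needed yet
    print(table)             # [0 1 2 3]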
concrete/numpy/np_mlir_converter.py
@@ -1,14 +1,71 @@
 """Numpy-specific MLIR converter."""

-from typing import Any, Dict
+import math
+from collections import defaultdict
+from itertools import product
+from typing import Any, DefaultDict, Dict, List, Tuple

 import numpy

+from ..common.debugging import assert_true
 from ..common.mlir.mlir_converter import MLIRConverter
 from ..common.operator_graph import OPGraph
 from ..common.representation.intermediate import UnivariateFunction


+class HashableNPArray:
+    """Class to easily manipulate numpy arrays for hashing.
+
+    Note that the hash behavior won't work if the array is modified after being hashed, as it
+    will have been hashed to a certain value and the new array content will be hashed to a
+    different one.
+    """
+
+    array: numpy.ndarray
+
+    def __init__(self, array: numpy.ndarray) -> None:
+        self.array = array
+
+    def __hash__(self) -> int:
+        return hash(self.array.tobytes())
+
+    def __eq__(self, other: object) -> bool:
+        return isinstance(other, HashableNPArray) and numpy.array_equal(self.array, other.array)
+
+
+def generate_deduplicated_tables(
+    node: UnivariateFunction,
+) -> Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]:
+    """Deduplicate the tables for the different cells of a tensor if needed.
+
+    Args:
+        node (UnivariateFunction): the node for which to deduplicate the table
+
+    Returns:
+        Tuple[Tuple[numpy.ndarray, List[Tuple[int, ...]]], ...]: A tuple containing tuples whose
+            first element is a table and the second element is a list of tuples indicating which
+            cells in the tensor will use that table.
+    """
+    # This is the tensor containing the tables for each cell of the tensor for node
+    node_complete_table = numpy.concatenate(
+        tuple(numpy.expand_dims(array, -1) for array in node.get_table()), axis=-1
+    )
+
+    all_cells_idx = product(*tuple(range(max_val) for max_val in node_complete_table.shape[:-1]))
+    tables_to_cell_idx: DefaultDict[HashableNPArray, List[Tuple[int, ...]]] = defaultdict(list)
+    idx: Tuple[int, ...]
+    all_idx_set = set()
+    for idx in all_cells_idx:
+        hashable_array = HashableNPArray(node_complete_table[idx])
+        tables_to_cell_idx[hashable_array].append(idx)
+        all_idx_set.add(idx)
+
+    assert_true(len(all_idx_set) == math.prod(node_complete_table.shape[:-1]))
+
+    return tuple(
+        (hashable_array.array, indices) for hashable_array, indices in tables_to_cell_idx.items()
+    )
+
+
 class NPMLIRConverter(MLIRConverter):
     """Numpy-specific MLIR converter."""

@@ -28,7 +85,7 @@ class NPMLIRConverter(MLIRConverter):
         # Disable numpy warnings during conversion to avoid issues during TLU generation
         with numpy.errstate(all="ignore"):
             additional_conversion_info["tables"] = {
-                node: node.get_table()
+                node: generate_deduplicated_tables(node)
                 for node in op_graph.graph.nodes()
                 if isinstance(node, UnivariateFunction)
             }
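To make the deduplication pattern above concrete, here is a self-contained sketch (plain numpy only, no concrete imports; array contents and names are illustrative) of the same idea that HashableNPArray and generate_deduplicated_tables implement: hashing an array's raw bytes lets a defaultdict group the tensor cells that share an identical lookup table.

    from collections import defaultdict
    from itertools import product

    import numpy

    # Per-cell tables for a 2x2 tensor: cells (0, 0) and (1, 1) share a table.
    cell_tables = numpy.array(
        [[[0, 1, 2, 3], [1, 2, 3, 4]], [[2, 3, 4, 5], [0, 1, 2, 3]]]
    )

    groups = defaultdict(list)
    for idx in product(*(range(n) for n in cell_tables.shape[:-1])):
        # Hash the raw bytes, as HashableNPArray does; identical tables collide.
        groups[cell_tables[idx].tobytes()].append(idx)

    deduplicated = tuple(
        (numpy.frombuffer(key, dtype=cell_tables.dtype), idxs)
        for key, idxs in groups.items()
    )
    assert len(deduplicated) == 3  # 4 cells, but only 3 distinct tables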
tests/numpy/test_np_mlir_converter.py (new file, 106 lines)
@@ -0,0 +1,106 @@
+"""Test file for numpy mlir converter"""
+
+import math
+
+import numpy
+import pytest
+
+import concrete.numpy as hnp
+from concrete.common.representation.intermediate import UnivariateFunction
+from concrete.numpy.np_mlir_converter import generate_deduplicated_tables
+
+
+def multi_tlu_func(x, cst):
+    """Multi TLU function"""
+    y = x + cst
+    return y.astype(numpy.int32)
+
+
+RESNET_BIGGEST_SHAPE = (64, 112, 112)
+RESNET_BIGGEST_SIZE = math.prod(RESNET_BIGGEST_SHAPE)
+
+
+@pytest.mark.parametrize(
+    "function,expected_number_of_tables",
+    [
+        (
+            lambda x: multi_tlu_func(x, numpy.zeros(RESNET_BIGGEST_SHAPE, dtype=numpy.float64)),
+            1,
+        ),
+        (
+            lambda x: multi_tlu_func(
+                x,
+                numpy.arange(RESNET_BIGGEST_SIZE, dtype=numpy.float64).reshape(
+                    RESNET_BIGGEST_SHAPE
+                ),
+            ),
+            RESNET_BIGGEST_SIZE,
+        ),
+    ],
+)
+def test_generate_deduplicated_tables(
+    function, expected_number_of_tables, default_compilation_configuration
+):
+    """Test function for generate_deduplicated_tables"""
+    op_graph = hnp.compile_numpy_function_into_op_graph(
+        function,
+        {"x": hnp.EncryptedTensor(hnp.Integer(7, False), RESNET_BIGGEST_SHAPE)},
+        ((i * numpy.ones(RESNET_BIGGEST_SHAPE, dtype=numpy.int32),) for i in range(128)),
+        default_compilation_configuration,
+    )
+
+    univariate_function_nodes = [
+        node for node in op_graph.graph.nodes() if isinstance(node, UnivariateFunction)
+    ]
+
+    assert len(univariate_function_nodes) == 1
+
+    tlu_node = univariate_function_nodes[0]
+
+    deduplication_result = generate_deduplicated_tables(tlu_node)
+
+    assert len(deduplication_result) == expected_number_of_tables
+
+
+def test_deduplicated_tables_correctness(default_compilation_configuration):
+    """Check the deduplicated tables are the expected ones"""
+
+    tensor_shape = (2, 2)
+
+    op_graph = hnp.compile_numpy_function_into_op_graph(
+        lambda x: multi_tlu_func(x, numpy.arange(4, dtype=numpy.float64).reshape(tensor_shape)),
+        {"x": hnp.EncryptedTensor(hnp.Integer(2, False), tensor_shape)},
+        ((i * numpy.ones(tensor_shape, dtype=numpy.int32),) for i in range(4)),
+        default_compilation_configuration,
+    )
+
+    univariate_function_nodes = [
+        node for node in op_graph.graph.nodes() if isinstance(node, UnivariateFunction)
+    ]
+
+    assert len(univariate_function_nodes) == 1
+
+    tlu_node = univariate_function_nodes[0]
+
+    deduplication_result = generate_deduplicated_tables(tlu_node)
+
+    expected_result = tuple(
+        (
+            numpy.arange(i, 4 + i, dtype=numpy.int32),
+            [
+                numpy.unravel_index(i, tensor_shape),
+            ],
+        )
+        for i in range(4)
+    )
+
+    assert len(deduplication_result) == len(expected_result)
+    for computed_array, computed_idx in deduplication_result:
+        for expected_array, expected_idx in expected_result:
+            if numpy.array_equal(computed_array, expected_array) and computed_idx == expected_idx:
+                break
+        else:
+            raise AssertionError(
+                f"Could not find {(computed_array, computed_idx)} "
+                f"in expected_result: {expected_result}"
+            )
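A closing note on the correctness test above: multi_tlu_func computes (x + cst).astype(numpy.int32), and with a 2-bit unsigned input x the table for a cell whose constant is i enumerates the outputs for x = 0..3, which is exactly numpy.arange(i, 4 + i). A quick standalone check of that arithmetic, independent of the library:

    import numpy

    for i in range(4):
        x = numpy.arange(4)  # every value a 2-bit unsigned input can take
        table = (x + i).astype(numpy.int32)
        assert numpy.array_equal(table, numpy.arange(i, 4 + i, dtype=numpy.int32))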