mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(extensions): add and test support for negative direct table lookups
This commit is contained in:
@@ -62,7 +62,7 @@ class LookupTable:
|
||||
|
||||
@staticmethod
|
||||
def _check_index_out_of_range(x, table):
|
||||
if x < 0 or x >= len(table):
|
||||
if not -len(table) <= x < len(table):
|
||||
raise ValueError(
|
||||
f"Lookup table with {len(table)} entries cannot be indexed with {x} "
|
||||
f"(you should check your inputset)",
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
"""Utilities for MLIR conversion."""
|
||||
from typing import Dict, List, Optional, cast
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from ..data_types import Integer
|
||||
from ..data_types.dtypes_helpers import (
|
||||
value_is_clear_scalar_integer,
|
||||
value_is_clear_tensor_integer,
|
||||
@@ -16,7 +15,7 @@ from ..debugging import format_operation_graph
|
||||
from ..debugging.custom_assert import assert_not_reached, assert_true
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation import intermediate
|
||||
from ..representation.intermediate import GenericFunction, IntermediateNode
|
||||
from ..representation.intermediate import IntermediateNode
|
||||
|
||||
# TODO: should come from compiler, through an API, #402
|
||||
ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 7
|
||||
@@ -210,24 +209,3 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
|
||||
)
|
||||
|
||||
_set_all_bit_width(op_graph, max_bit_width)
|
||||
|
||||
|
||||
def extend_direct_lookup_tables(op_graph: OPGraph):
|
||||
"""Extend direct lookup tables to the maximum length the input bit width can support.
|
||||
|
||||
Args:
|
||||
op_graph: graph to update lookup tables for
|
||||
"""
|
||||
for node in op_graph.graph.nodes:
|
||||
if isinstance(node, GenericFunction) and node.op_name == "TLU":
|
||||
table = node.op_kwargs["table"]
|
||||
bit_width = cast(Integer, node.inputs[0].dtype).bit_width
|
||||
expected_length = 2 ** bit_width
|
||||
|
||||
# TODO: remove no cover once the table length workaround is removed
|
||||
# (https://github.com/zama-ai/concretefhe-internal/issues/359)
|
||||
if len(table) > expected_length: # pragma: no cover
|
||||
node.op_kwargs["table"] = table[:expected_length]
|
||||
else:
|
||||
repeat = expected_length // len(table)
|
||||
node.op_kwargs["table"] = (table * repeat)[:expected_length]
|
||||
|
||||
@@ -17,7 +17,6 @@ from ..common.debugging.custom_assert import assert_true
|
||||
from ..common.fhe_circuit import FHECircuit
|
||||
from ..common.mlir.utils import (
|
||||
check_graph_values_compatibility_with_mlir,
|
||||
extend_direct_lookup_tables,
|
||||
update_bit_width_for_mlir,
|
||||
)
|
||||
from ..common.operator_graph import OPGraph
|
||||
@@ -601,9 +600,6 @@ def prepare_op_graph_for_mlir(op_graph: OPGraph):
|
||||
# Update bit_width for MLIR
|
||||
update_bit_width_for_mlir(op_graph)
|
||||
|
||||
# TODO: workaround extend LUT #359
|
||||
extend_direct_lookup_tables(op_graph)
|
||||
|
||||
# HACK
|
||||
# TODO: remove this ugly hack when https://github.com/zama-ai/concretefhe-internal/issues/1001
|
||||
# is done
|
||||
|
||||
@@ -50,6 +50,42 @@ def identity_lut_generator(n):
|
||||
return lambda x: LookupTable(list(range(2 ** n)))[x]
|
||||
|
||||
|
||||
def negative_identity_smaller_lut_generator(n):
|
||||
"""Test negative lookup table"""
|
||||
|
||||
table = LookupTable(range(2 ** (n - 1)))
|
||||
offset = 2 ** (n - 1)
|
||||
|
||||
return (lambda x: table[x + (-offset)]), table
|
||||
|
||||
|
||||
def negative_identity_lut_generator(n):
|
||||
"""Test negative lookup table (bigger than bit-width)"""
|
||||
|
||||
table = LookupTable(range(2 ** n))
|
||||
offset = 2 ** (n - 1)
|
||||
|
||||
return (lambda x: table[x + (-offset)]), table
|
||||
|
||||
|
||||
def negative_identity_bigger_lut_generator(n):
|
||||
"""Test negative lookup table (bigger than bit-width)"""
|
||||
|
||||
table = LookupTable(range(2 ** (n + 1)))
|
||||
offset = 2 ** (n - 1)
|
||||
|
||||
return (lambda x: table[x + (-offset)]), table
|
||||
|
||||
|
||||
def weird_lut(n):
|
||||
"""A weird lookup table to test an edge case related to negative indexing"""
|
||||
|
||||
table = LookupTable([0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 4, 5, 6, 7])
|
||||
offset = 2 ** (n - 1)
|
||||
|
||||
return (lambda x: table[x + (-offset)]), table
|
||||
|
||||
|
||||
def random_lut_1b(x):
|
||||
"""1-bit random table lookup"""
|
||||
|
||||
@@ -1375,6 +1411,45 @@ def test_compile_and_run_lut_correctness(
|
||||
check_is_good_execution(compiler_engine, function, args)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,table,bit_width",
|
||||
[
|
||||
pytest.param(*negative_identity_smaller_lut_generator(n), n, id=f"smaller ({n}-bit)")
|
||||
for n in range(1, 8)
|
||||
]
|
||||
+ [
|
||||
pytest.param(*negative_identity_lut_generator(n), n, id=f"normal ({n}-bit)")
|
||||
for n in range(1, 8)
|
||||
]
|
||||
+ [
|
||||
pytest.param(*negative_identity_bigger_lut_generator(n), n, id=f"bigger ({n}-bit)")
|
||||
for n in range(1, 7)
|
||||
]
|
||||
+ [
|
||||
pytest.param(*weird_lut(3), 3, id="weird"),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_negative_lut_correctness(
|
||||
function,
|
||||
table,
|
||||
bit_width,
|
||||
default_compilation_configuration,
|
||||
):
|
||||
"""Test correctness when running a compiled function with LUT using negative values"""
|
||||
|
||||
circuit = compile_numpy_function(
|
||||
function,
|
||||
{"x": EncryptedScalar(UnsignedInteger(bit_width))},
|
||||
range(2 ** bit_width),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
offset = 2 ** (bit_width - 1)
|
||||
for value in range(-offset, offset):
|
||||
assert table[value] == function(value + offset)
|
||||
check_is_good_execution(circuit, function, [value + offset])
|
||||
|
||||
|
||||
def test_compile_and_run_multi_lut_correctness(default_compilation_configuration):
|
||||
"""Test correctness of results when running a compiled function with Multi LUT"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user