feat(extensions): add and test support for negative direct table lookups

This commit is contained in:
Umut
2021-11-30 16:10:59 +03:00
parent 8b0a793cda
commit 1b5785e058
4 changed files with 78 additions and 29 deletions

View File

@@ -62,7 +62,7 @@ class LookupTable:
@staticmethod
def _check_index_out_of_range(x, table):
if x < 0 or x >= len(table):
if not -len(table) <= x < len(table):
raise ValueError(
f"Lookup table with {len(table)} entries cannot be indexed with {x} "
f"(you should check your inputset)",

View File

@@ -1,9 +1,8 @@
"""Utilities for MLIR conversion."""
from typing import Dict, List, Optional, cast
from typing import Dict, List, Optional
import networkx as nx
from ..data_types import Integer
from ..data_types.dtypes_helpers import (
value_is_clear_scalar_integer,
value_is_clear_tensor_integer,
@@ -16,7 +15,7 @@ from ..debugging import format_operation_graph
from ..debugging.custom_assert import assert_not_reached, assert_true
from ..operator_graph import OPGraph
from ..representation import intermediate
from ..representation.intermediate import GenericFunction, IntermediateNode
from ..representation.intermediate import IntermediateNode
# TODO: should come from compiler, through an API, #402
ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 7
@@ -210,24 +209,3 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
)
_set_all_bit_width(op_graph, max_bit_width)
def extend_direct_lookup_tables(op_graph: OPGraph):
"""Extend direct lookup tables to the maximum length the input bit width can support.
Args:
op_graph: graph to update lookup tables for
"""
for node in op_graph.graph.nodes:
if isinstance(node, GenericFunction) and node.op_name == "TLU":
table = node.op_kwargs["table"]
bit_width = cast(Integer, node.inputs[0].dtype).bit_width
expected_length = 2 ** bit_width
# TODO: remove no cover once the table length workaround is removed
# (https://github.com/zama-ai/concretefhe-internal/issues/359)
if len(table) > expected_length: # pragma: no cover
node.op_kwargs["table"] = table[:expected_length]
else:
repeat = expected_length // len(table)
node.op_kwargs["table"] = (table * repeat)[:expected_length]

View File

@@ -17,7 +17,6 @@ from ..common.debugging.custom_assert import assert_true
from ..common.fhe_circuit import FHECircuit
from ..common.mlir.utils import (
check_graph_values_compatibility_with_mlir,
extend_direct_lookup_tables,
update_bit_width_for_mlir,
)
from ..common.operator_graph import OPGraph
@@ -601,9 +600,6 @@ def prepare_op_graph_for_mlir(op_graph: OPGraph):
# Update bit_width for MLIR
update_bit_width_for_mlir(op_graph)
# TODO: workaround extend LUT #359
extend_direct_lookup_tables(op_graph)
# HACK
# TODO: remove this ugly hack when https://github.com/zama-ai/concretefhe-internal/issues/1001
# is done

View File

@@ -50,6 +50,42 @@ def identity_lut_generator(n):
return lambda x: LookupTable(list(range(2 ** n)))[x]
def negative_identity_smaller_lut_generator(n):
"""Test negative lookup table"""
table = LookupTable(range(2 ** (n - 1)))
offset = 2 ** (n - 1)
return (lambda x: table[x + (-offset)]), table
def negative_identity_lut_generator(n):
"""Test negative lookup table (bigger than bit-width)"""
table = LookupTable(range(2 ** n))
offset = 2 ** (n - 1)
return (lambda x: table[x + (-offset)]), table
def negative_identity_bigger_lut_generator(n):
"""Test negative lookup table (bigger than bit-width)"""
table = LookupTable(range(2 ** (n + 1)))
offset = 2 ** (n - 1)
return (lambda x: table[x + (-offset)]), table
def weird_lut(n):
"""A weird lookup table to test an edge case related to negative indexing"""
table = LookupTable([0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 4, 5, 6, 7])
offset = 2 ** (n - 1)
return (lambda x: table[x + (-offset)]), table
def random_lut_1b(x):
"""1-bit random table lookup"""
@@ -1375,6 +1411,45 @@ def test_compile_and_run_lut_correctness(
check_is_good_execution(compiler_engine, function, args)
@pytest.mark.parametrize(
"function,table,bit_width",
[
pytest.param(*negative_identity_smaller_lut_generator(n), n, id=f"smaller ({n}-bit)")
for n in range(1, 8)
]
+ [
pytest.param(*negative_identity_lut_generator(n), n, id=f"normal ({n}-bit)")
for n in range(1, 8)
]
+ [
pytest.param(*negative_identity_bigger_lut_generator(n), n, id=f"bigger ({n}-bit)")
for n in range(1, 7)
]
+ [
pytest.param(*weird_lut(3), 3, id="weird"),
],
)
def test_compile_and_run_negative_lut_correctness(
function,
table,
bit_width,
default_compilation_configuration,
):
"""Test correctness when running a compiled function with LUT using negative values"""
circuit = compile_numpy_function(
function,
{"x": EncryptedScalar(UnsignedInteger(bit_width))},
range(2 ** bit_width),
default_compilation_configuration,
)
offset = 2 ** (bit_width - 1)
for value in range(-offset, offset):
assert table[value] == function(value + offset)
check_is_good_execution(circuit, function, [value + offset])
def test_compile_and_run_multi_lut_correctness(default_compilation_configuration):
"""Test correctness of results when running a compiled function with Multi LUT"""