mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
* feat: let the dev give a useful name for ArbitraryFunction might be useful to debug or understand what happens closes #144 * feat: let the dev give a useful name for ArbitraryFunction might be useful to debug or understand what happens closes #144 * feat: let the dev give a useful name for ArbitraryFunction might be useful to debug or understand what happens closes #144 * feat: let the dev give a useful name for ArbitraryFunction might be useful to debug or understand what happens closes #144 * feat: let the dev give a useful name for ArbitraryFunction might be useful to debug or understand what happens closes #144 * feat: let the dev give a useful name for ArbitraryFunction might be useful to debug or understand what happens closes #144 * feat: let the dev give a useful name for ArbitraryFunction might be useful to debug or understand what happens closes #144 Co-authored-by: Benoit Chevallier-Mames <benoitchevalliermames@zama.ai>
65 lines
2.2 KiB
Python
65 lines
2.2 KiB
Python
"""This file contains a wrapper class for direct table lookups."""
|
|
|
|
from copy import deepcopy
|
|
from typing import Iterable, Tuple, Union
|
|
|
|
from ..common_helpers import is_a_power_of_2
|
|
from ..data_types.base import BaseDataType
|
|
from ..data_types.integers import make_integer_to_hold_ints
|
|
from ..representation import intermediate as ir
|
|
from ..tracing.base_tracer import BaseTracer
|
|
|
|
|
|
class LookupTable:
    """Class representing a lookup table.

    The entries are stored as a tuple. Indexing with a constant performs a
    bounds-checked lookup right away, while indexing with a `BaseTracer`
    records the lookup as an `ir.ArbitraryFunction` node so that the result is
    determined during the runtime.
    """

    # lookup table itself, has 2^N entries
    table: Tuple[int, ...]

    # type of the result of the lookup
    output_dtype: BaseDataType

    def __init__(self, table: Iterable[int]) -> None:
        """Create a lookup table from the given entries.

        Args:
            table (Iterable[int]): entries of the table, it is consumed once and
                stored as a tuple

        Raises:
            ValueError: if the number of entries is not accepted by `is_a_power_of_2`
        """
        # materialize first: `table` may be a one-shot iterable and we need its length
        table = tuple(table)

        if not is_a_power_of_2(len(table)):
            raise ValueError(
                f"Desired lookup table has inappropriate number of entries ({len(table)})"
            )

        self.table = table
        # output type is derived from the entries themselves
        self.output_dtype = make_integer_to_hold_ints(table, force_signed=False)

    def __getitem__(self, key: Union[int, BaseTracer]) -> Union[int, BaseTracer]:
        """Look up `key` in the table, either directly or symbolically.

        Args:
            key (Union[int, BaseTracer]): constant index, or tracer standing for an
                index only known during the runtime

        Returns:
            Union[int, BaseTracer]: the table entry when `key` is a constant, or a new
                tracer (of the same class as `key`) wrapping the lookup node otherwise

        Raises:
            ValueError: if a constant `key` is outside of the table
        """
        # if a tracer is used for indexing,
        # we need to create an `ArbitraryFunction` node
        # because the result will be determined during the runtime
        if isinstance(key, BaseTracer):
            traced_computation = ir.ArbitraryFunction(
                input_base_value=key.output,
                arbitrary_func=LookupTable._checked_indexing,
                output_dtype=self.output_dtype,
                # the node gets its own copy of the table, passed as the `table` kwarg
                op_kwargs={"table": deepcopy(self.table)},
                op_name="TLU",
            )
            return key.__class__(
                inputs=[key],
                traced_computation=traced_computation,
                output_index=0,
            )

        # if not, it means table is indexed with a constant
        # thus, the result of the lookup is a constant
        # so, we can propagate it directly
        return LookupTable._checked_indexing(key, self.table)

    @staticmethod
    def _checked_indexing(x: int, table: Tuple[int, ...]) -> int:
        """Return `table[x]` after checking that `x` is a valid index.

        Negative indices are rejected on purpose instead of wrapping around
        like regular python indexing would.

        Args:
            x (int): index to look up
            table (Tuple[int, ...]): entries to index into

        Returns:
            int: the entry stored at position `x`

        Raises:
            ValueError: if `x` is negative or not smaller than the number of entries
        """
        if x < 0 or x >= len(table):
            raise ValueError(
                f"Lookup table with {len(table)} entries cannot be indexed with {x} "
                "(you should check your dataset)",
            )

        return table[x]
|