mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
dev(ir): add get_table function to ArbitraryFunction node
This commit is contained in:
committed by
Ayoub Benaissa
parent
74f0c9600e
commit
dbda93639b
@@ -9,6 +9,7 @@ from ..data_types.dtypes_helpers import (
|
||||
get_base_value_for_python_constant_data,
|
||||
mix_scalar_values_determine_holding_dtype,
|
||||
)
|
||||
from ..data_types.integers import Integer
|
||||
from ..values import BaseValue, ClearValue, EncryptedValue, TensorValue
|
||||
|
||||
IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func"
|
||||
@@ -228,6 +229,30 @@ class ArbitraryFunction(IntermediateNode):
|
||||
def label(self) -> str:
|
||||
return self.op_name
|
||||
|
||||
def get_table(self) -> List[int]:
|
||||
"""Function to get the table for the current input value of this ArbitraryFunction.
|
||||
|
||||
Returns:
|
||||
List[int]: The table.
|
||||
"""
|
||||
# Check the input is an unsigned integer to be able to build a table
|
||||
assert isinstance(
|
||||
self.inputs[0].data_type, Integer
|
||||
), "get_table only works for an unsigned Integer input"
|
||||
assert not self.inputs[
|
||||
0
|
||||
].data_type.is_signed, "get_table only works for an unsigned Integer input"
|
||||
|
||||
min_input_range = self.inputs[0].data_type.min_value()
|
||||
max_input_range = self.inputs[0].data_type.max_value() + 1
|
||||
|
||||
table = [
|
||||
int(self.evaluate({0: input_value}))
|
||||
for input_value in range(min_input_range, max_input_range)
|
||||
]
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def default_dot_evaluation_function(lhs: Any, rhs: Any) -> Any:
|
||||
"""Default python dot implementation for 1D iterable arrays.
|
||||
|
||||
@@ -45,6 +45,8 @@ def test_lookup_table_encrypted_lookup(test_helpers):
|
||||
x = EncryptedValue(Integer(2, is_signed=False))
|
||||
op_graph = tracing.trace_numpy_function(f, {"x": x})
|
||||
|
||||
assert op_graph.output_nodes[0].get_table() == [3, 6, 0, 2]
|
||||
|
||||
ref_graph = nx.MultiDiGraph()
|
||||
# Here is the ASCII drawing of the expected graph:
|
||||
# (x) - (TLU)
|
||||
|
||||
Reference in New Issue
Block a user