dev(ir): add get_table function to ArbitraryFunction node

This commit is contained in:
Arthur Meyre
2021-08-30 16:16:35 +02:00
committed by Ayoub Benaissa
parent 74f0c9600e
commit dbda93639b
2 changed files with 27 additions and 0 deletions

View File

@@ -9,6 +9,7 @@ from ..data_types.dtypes_helpers import (
get_base_value_for_python_constant_data,
mix_scalar_values_determine_holding_dtype,
)
from ..data_types.integers import Integer
from ..values import BaseValue, ClearValue, EncryptedValue, TensorValue
IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func"
@@ -228,6 +229,30 @@ class ArbitraryFunction(IntermediateNode):
def label(self) -> str:
return self.op_name
def get_table(self) -> List[int]:
"""Function to get the table for the current input value of this ArbitraryFunction.
Returns:
List[int]: The table.
"""
# Check the input is an unsigned integer to be able to build a table
assert isinstance(
self.inputs[0].data_type, Integer
), "get_table only works for an unsigned Integer input"
assert not self.inputs[
0
].data_type.is_signed, "get_table only works for an unsigned Integer input"
min_input_range = self.inputs[0].data_type.min_value()
max_input_range = self.inputs[0].data_type.max_value() + 1
table = [
int(self.evaluate({0: input_value}))
for input_value in range(min_input_range, max_input_range)
]
return table
def default_dot_evaluation_function(lhs: Any, rhs: Any) -> Any:
"""Default python dot implementation for 1D iterable arrays.

View File

@@ -45,6 +45,8 @@ def test_lookup_table_encrypted_lookup(test_helpers):
x = EncryptedValue(Integer(2, is_signed=False))
op_graph = tracing.trace_numpy_function(f, {"x": x})
assert op_graph.output_nodes[0].get_table() == [3, 6, 0, 2]
ref_graph = nx.MultiDiGraph()
# Here is the ASCII drawing of the expected graph:
# (x) - (TLU)