From dbda93639b86eabaf9c83bc112121ca80ab4e9df Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Mon, 30 Aug 2021 16:16:35 +0200 Subject: [PATCH] dev(ir): add get_table function to ArbitraryFunction node --- hdk/common/representation/intermediate.py | 25 +++++++++++++++++++++++ tests/common/extensions/test_table.py | 2 ++ 2 files changed, 27 insertions(+) diff --git a/hdk/common/representation/intermediate.py b/hdk/common/representation/intermediate.py index bbe6530c5..d9114c4a5 100644 --- a/hdk/common/representation/intermediate.py +++ b/hdk/common/representation/intermediate.py @@ -9,6 +9,7 @@ from ..data_types.dtypes_helpers import ( get_base_value_for_python_constant_data, mix_scalar_values_determine_holding_dtype, ) +from ..data_types.integers import Integer from ..values import BaseValue, ClearValue, EncryptedValue, TensorValue IR_MIX_VALUES_FUNC_ARG_NAME = "mix_values_func" @@ -228,6 +229,30 @@ class ArbitraryFunction(IntermediateNode): def label(self) -> str: return self.op_name + def get_table(self) -> List[int]: + """Function to get the table for the current input value of this ArbitraryFunction. + + Returns: + List[int]: The table. + """ + # Check the input is an unsigned integer to be able to build a table + assert isinstance( + self.inputs[0].data_type, Integer + ), "get_table only works for an unsigned Integer input" + assert not self.inputs[ + 0 + ].data_type.is_signed, "get_table only works for an unsigned Integer input" + + min_input_range = self.inputs[0].data_type.min_value() + max_input_range = self.inputs[0].data_type.max_value() + 1 + + table = [ + int(self.evaluate({0: input_value})) + for input_value in range(min_input_range, max_input_range) + ] + + return table + def default_dot_evaluation_function(lhs: Any, rhs: Any) -> Any: """Default python dot implementation for 1D iterable arrays. diff --git a/tests/common/extensions/test_table.py b/tests/common/extensions/test_table.py index f936367d1..7de22abe8 100644 --- a/tests/common/extensions/test_table.py +++ b/tests/common/extensions/test_table.py @@ -45,6 +45,8 @@ def test_lookup_table_encrypted_lookup(test_helpers): x = EncryptedValue(Integer(2, is_signed=False)) op_graph = tracing.trace_numpy_function(f, {"x": x}) + assert op_graph.output_nodes[0].get_table() == [3, 6, 0, 2] + ref_graph = nx.MultiDiGraph() # Here is the ASCII drawing of the expected graph: # (x) - (TLU)