diff --git a/hdk/common/debugging/draw_graph.py b/hdk/common/debugging/draw_graph.py index 63455f560..abb8a13aa 100644 --- a/hdk/common/debugging/draw_graph.py +++ b/hdk/common/debugging/draw_graph.py @@ -13,7 +13,8 @@ IR_NODE_COLOR_MAPPING = { ir.Add: "red", ir.Sub: "yellow", ir.Mul: "green", - ir.ArbitraryFunction: "orange", + "ArbitraryFunction": "orange", + "TLU": "grey", "output": "magenta", } @@ -115,6 +116,8 @@ def draw_graph( def get_color(node): if node in set_of_nodes_which_are_outputs: return IR_NODE_COLOR_MAPPING["output"] + if isinstance(node, ir.ArbitraryFunction): + return IR_NODE_COLOR_MAPPING[node.op_name] return IR_NODE_COLOR_MAPPING[type(node)] color_map = [get_color(node) for node in graph.nodes()] @@ -127,6 +130,8 @@ def draw_graph( return node.input_name if isinstance(node, ir.ConstantInput): return str(node.constant_data) + if isinstance(node, ir.ArbitraryFunction): + return node.op_name return node.__class__.__name__ label_dict = {node: get_proper_name(node) for node in graph.nodes()} @@ -209,7 +214,7 @@ def draw_graph( plt.show(block=block_until_user_closes_graph) -def data_type_to_string(node): +def output_data_type_to_string(node): """Return the datatypes of the outputs of the node. Args: @@ -249,7 +254,13 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: elif isinstance(node, ir.ConstantInput): what_to_print = f"ConstantInput({node.constant_data})" else: - what_to_print = node.__class__.__name__ + "(" + + base_name = node.__class__.__name__ + + if isinstance(node, ir.ArbitraryFunction): + base_name = node.op_name + + what_to_print = base_name + "(" # Find all the names of the current predecessors of the node list_of_arg_name = [] @@ -273,7 +284,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: # Manage datatypes if show_data_types: - new_line = f"{new_line: <40s} # {data_type_to_string(node)}" + new_line = f"{new_line: <40s} # {output_data_type_to_string(node)}" returned_str += f"\n{new_line}" diff --git a/hdk/common/extensions/table.py b/hdk/common/extensions/table.py index 0ef1e0041..74845799d 100644 --- a/hdk/common/extensions/table.py +++ b/hdk/common/extensions/table.py @@ -40,6 +40,7 @@ class LookupTable: arbitrary_func=LookupTable._checked_indexing, output_dtype=self.output_dtype, op_kwargs={"table": deepcopy(self.table)}, + op_name="TLU", ) return key.__class__( inputs=[key], diff --git a/hdk/common/representation/intermediate.py b/hdk/common/representation/intermediate.py index 2ad4e1bc4..0c98d8eb6 100644 --- a/hdk/common/representation/intermediate.py +++ b/hdk/common/representation/intermediate.py @@ -186,12 +186,14 @@ class ArbitraryFunction(IntermediateNode): arbitrary_func: Optional[Callable] op_args: Tuple[Any, ...] op_kwargs: Dict[str, Any] + op_name: str def __init__( self, input_base_value: BaseValue, arbitrary_func: Callable, output_dtype: BaseDataType, + op_name: Optional[str] = None, op_args: Optional[Tuple[Any, ...]] = None, op_kwargs: Optional[Dict[str, Any]] = None, ) -> None: @@ -202,6 +204,7 @@ class ArbitraryFunction(IntermediateNode): self.op_kwargs = deepcopy(op_kwargs) if op_kwargs is not None else {} # TLU/PBS has an encrypted output self.outputs = [EncryptedValue(output_dtype)] + self.op_name = op_name if op_name is not None else self.__class__.__name__ def evaluate(self, inputs: Mapping[int, Any]) -> Any: # This is the continuation of the mypy bug workaround @@ -215,5 +218,6 @@ class ArbitraryFunction(IntermediateNode): isinstance(other, ArbitraryFunction) and self.op_args == other.op_args and self.op_kwargs == other.op_kwargs + and self.op_name == other.op_name and super().is_equivalent_to(other) ) diff --git a/tests/common/extensions/test_table.py b/tests/common/extensions/test_table.py index 530263871..656426dc9 100644 --- a/tests/common/extensions/test_table.py +++ b/tests/common/extensions/test_table.py @@ -57,6 +57,7 @@ def test_lookup_table_encrypted_lookup(test_helpers): arbitrary_func=lambda x, table: table[x], output_dtype=table.output_dtype, op_kwargs={"table": deepcopy(table.table)}, + op_name="TLU", ) ref_graph.add_node(output_arbitrary_function, content=output_arbitrary_function) @@ -93,6 +94,7 @@ def test_lookup_table_encrypted_and_plain_lookup(test_helpers): arbitrary_func=lambda x, table: table[x], output_dtype=table.output_dtype, op_kwargs={"table": deepcopy(table.table)}, + op_name="TLU", ) ref_graph.add_node(intermediate_arbitrary_function, content=intermediate_arbitrary_function) diff --git a/tests/hnumpy/test_debugging.py b/tests/hnumpy/test_debugging.py index 9e5972a1b..080ad56e3 100644 --- a/tests/hnumpy/test_debugging.py +++ b/tests/hnumpy/test_debugging.py @@ -153,16 +153,12 @@ def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y): ( lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x], {"x": EncryptedValue(Integer(2, is_signed=False))}, - "\n%0 = x\n%1 = ArbitraryFunction(0)\nreturn(%1)", + "\n%0 = x\n%1 = TLU(0)\nreturn(%1)", ), ( lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4], {"x": EncryptedValue(Integer(2, is_signed=False))}, - "\n%0 = x" - "\n%1 = ConstantInput(4)" - "\n%2 = Add(0, 1)" - "\n%3 = ArbitraryFunction(2)" - "\nreturn(%3)", + "\n%0 = x\n%1 = ConstantInput(4)\n%2 = Add(0, 1)\n%3 = TLU(2)\nreturn(%3)", ), ], ) @@ -230,7 +226,7 @@ def test_hnumpy_print_with_show_data_types(lambda_f, x_y, ref_graph_str): lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x], {"x": EncryptedValue(Integer(2, is_signed=False))}, "\n%0 = x # Integer" - "\n%1 = ArbitraryFunction(0) # Integer" + "\n%1 = TLU(0) # Integer" "\nreturn(%1)", ), ( @@ -239,7 +235,7 @@ def test_hnumpy_print_with_show_data_types(lambda_f, x_y, ref_graph_str): "\n%0 = x # Integer" "\n%1 = ConstantInput(4) # Integer" "\n%2 = Add(0, 1) # Integer" - "\n%3 = ArbitraryFunction(2) # Integer" + "\n%3 = TLU(2) # Integer" "\nreturn(%3)", ), ( @@ -248,8 +244,8 @@ def test_hnumpy_print_with_show_data_types(lambda_f, x_y, ref_graph_str): "\n%0 = x # Integer" "\n%1 = ConstantInput(4) # Integer" "\n%2 = Add(0, 1) # Integer" - "\n%3 = ArbitraryFunction(2) # Integer" - "\n%4 = ArbitraryFunction(3) # Integer" + "\n%3 = TLU(2) # Integer" + "\n%4 = TLU(3) # Integer" "\nreturn(%4)", ), ],