Feat/user friendly arbitrary function name 144 (#149)

* feat: let the dev give a useful name for ArbitraryFunction

might be useful to debug or understand what happens
closes #144

* feat: let the dev give a useful name for ArbitraryFunction

might be useful to debug or understand what happens
closes #144

* feat: let the dev give a useful name for ArbitraryFunction

might be useful to debug or understand what happens
closes #144

* feat: let the dev give a useful name for ArbitraryFunction

might be useful to debug or understand what happens
closes #144

* feat: let the dev give a useful name for ArbitraryFunction

might be useful to debug or understand what happens
closes #144

* feat: let the dev give a useful name for ArbitraryFunction

might be useful to debug or understand what happens
closes #144

* feat: let the dev give a useful name for ArbitraryFunction

might be useful to debug or understand what happens
closes #144

Co-authored-by: Benoit Chevallier-Mames <benoitchevalliermames@zama.ai>
This commit is contained in:
Benoit Chevallier
2021-08-13 18:10:35 +02:00
committed by GitHub
parent 5961d1630e
commit 3245d3e673
5 changed files with 28 additions and 14 deletions

View File

@@ -13,7 +13,8 @@ IR_NODE_COLOR_MAPPING = {
ir.Add: "red",
ir.Sub: "yellow",
ir.Mul: "green",
ir.ArbitraryFunction: "orange",
"ArbitraryFunction": "orange",
"TLU": "grey",
"output": "magenta",
}
@@ -115,6 +116,8 @@ def draw_graph(
def get_color(node):
if node in set_of_nodes_which_are_outputs:
return IR_NODE_COLOR_MAPPING["output"]
if isinstance(node, ir.ArbitraryFunction):
return IR_NODE_COLOR_MAPPING[node.op_name]
return IR_NODE_COLOR_MAPPING[type(node)]
color_map = [get_color(node) for node in graph.nodes()]
@@ -127,6 +130,8 @@ def draw_graph(
return node.input_name
if isinstance(node, ir.ConstantInput):
return str(node.constant_data)
if isinstance(node, ir.ArbitraryFunction):
return node.op_name
return node.__class__.__name__
label_dict = {node: get_proper_name(node) for node in graph.nodes()}
@@ -209,7 +214,7 @@ def draw_graph(
plt.show(block=block_until_user_closes_graph)
def data_type_to_string(node):
def output_data_type_to_string(node):
"""Return the datatypes of the outputs of the node.
Args:
@@ -249,7 +254,13 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
elif isinstance(node, ir.ConstantInput):
what_to_print = f"ConstantInput({node.constant_data})"
else:
what_to_print = node.__class__.__name__ + "("
base_name = node.__class__.__name__
if isinstance(node, ir.ArbitraryFunction):
base_name = node.op_name
what_to_print = base_name + "("
# Find all the names of the current predecessors of the node
list_of_arg_name = []
@@ -273,7 +284,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
# Manage datatypes
if show_data_types:
new_line = f"{new_line: <40s} # {data_type_to_string(node)}"
new_line = f"{new_line: <40s} # {output_data_type_to_string(node)}"
returned_str += f"\n{new_line}"

View File

@@ -40,6 +40,7 @@ class LookupTable:
arbitrary_func=LookupTable._checked_indexing,
output_dtype=self.output_dtype,
op_kwargs={"table": deepcopy(self.table)},
op_name="TLU",
)
return key.__class__(
inputs=[key],

View File

@@ -186,12 +186,14 @@ class ArbitraryFunction(IntermediateNode):
arbitrary_func: Optional[Callable]
op_args: Tuple[Any, ...]
op_kwargs: Dict[str, Any]
op_name: str
def __init__(
self,
input_base_value: BaseValue,
arbitrary_func: Callable,
output_dtype: BaseDataType,
op_name: Optional[str] = None,
op_args: Optional[Tuple[Any, ...]] = None,
op_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
@@ -202,6 +204,7 @@ class ArbitraryFunction(IntermediateNode):
self.op_kwargs = deepcopy(op_kwargs) if op_kwargs is not None else {}
# TLU/PBS has an encrypted output
self.outputs = [EncryptedValue(output_dtype)]
self.op_name = op_name if op_name is not None else self.__class__.__name__
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
# This is the continuation of the mypy bug workaround
@@ -215,5 +218,6 @@ class ArbitraryFunction(IntermediateNode):
isinstance(other, ArbitraryFunction)
and self.op_args == other.op_args
and self.op_kwargs == other.op_kwargs
and self.op_name == other.op_name
and super().is_equivalent_to(other)
)

View File

@@ -57,6 +57,7 @@ def test_lookup_table_encrypted_lookup(test_helpers):
arbitrary_func=lambda x, table: table[x],
output_dtype=table.output_dtype,
op_kwargs={"table": deepcopy(table.table)},
op_name="TLU",
)
ref_graph.add_node(output_arbitrary_function, content=output_arbitrary_function)
@@ -93,6 +94,7 @@ def test_lookup_table_encrypted_and_plain_lookup(test_helpers):
arbitrary_func=lambda x, table: table[x],
output_dtype=table.output_dtype,
op_kwargs={"table": deepcopy(table.table)},
op_name="TLU",
)
ref_graph.add_node(intermediate_arbitrary_function, content=intermediate_arbitrary_function)

View File

@@ -153,16 +153,12 @@ def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
(
lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x],
{"x": EncryptedValue(Integer(2, is_signed=False))},
"\n%0 = x\n%1 = ArbitraryFunction(0)\nreturn(%1)",
"\n%0 = x\n%1 = TLU(0)\nreturn(%1)",
),
(
lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4],
{"x": EncryptedValue(Integer(2, is_signed=False))},
"\n%0 = x"
"\n%1 = ConstantInput(4)"
"\n%2 = Add(0, 1)"
"\n%3 = ArbitraryFunction(2)"
"\nreturn(%3)",
"\n%0 = x\n%1 = ConstantInput(4)\n%2 = Add(0, 1)\n%3 = TLU(2)\nreturn(%3)",
),
],
)
@@ -230,7 +226,7 @@ def test_hnumpy_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x],
{"x": EncryptedValue(Integer(2, is_signed=False))},
"\n%0 = x # Integer<unsigned, 2 bits>"
"\n%1 = ArbitraryFunction(0) # Integer<unsigned, 4 bits>"
"\n%1 = TLU(0) # Integer<unsigned, 4 bits>"
"\nreturn(%1)",
),
(
@@ -239,7 +235,7 @@ def test_hnumpy_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
"\n%0 = x # Integer<unsigned, 2 bits>"
"\n%1 = ConstantInput(4) # Integer<unsigned, 3 bits>"
"\n%2 = Add(0, 1) # Integer<unsigned, 3 bits>"
"\n%3 = ArbitraryFunction(2) # Integer<unsigned, 2 bits>"
"\n%3 = TLU(2) # Integer<unsigned, 2 bits>"
"\nreturn(%3)",
),
(
@@ -248,8 +244,8 @@ def test_hnumpy_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
"\n%0 = x # Integer<unsigned, 2 bits>"
"\n%1 = ConstantInput(4) # Integer<unsigned, 3 bits>"
"\n%2 = Add(0, 1) # Integer<unsigned, 3 bits>"
"\n%3 = ArbitraryFunction(2) # Integer<unsigned, 2 bits>"
"\n%4 = ArbitraryFunction(3) # Integer<unsigned, 4 bits>"
"\n%3 = TLU(2) # Integer<unsigned, 2 bits>"
"\n%4 = TLU(3) # Integer<unsigned, 4 bits>"
"\nreturn(%4)",
),
],