mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
Feat/user friendly arbitrary function name 144 (#149)
* feat: let the dev give a useful name for ArbitraryFunction might be useful to debug or understand what happens closes #144 * feat: let the dev give a useful name for ArbitraryFunction might be useful to debug or understand what happens closes #144 * feat: let the dev give a useful name for ArbitraryFunction might be useful to debug or understand what happens closes #144 * feat: let the dev give a useful name for ArbitraryFunction might be useful to debug or understand what happens closes #144 * feat: let the dev give a useful name for ArbitraryFunction might be useful to debug or understand what happens closes #144 * feat: let the dev give a useful name for ArbitraryFunction might be useful to debug or understand what happens closes #144 * feat: let the dev give a useful name for ArbitraryFunction might be useful to debug or understand what happens closes #144 Co-authored-by: Benoit Chevallier-Mames <benoitchevalliermames@zama.ai>
This commit is contained in:
committed by
GitHub
parent
5961d1630e
commit
3245d3e673
@@ -13,7 +13,8 @@ IR_NODE_COLOR_MAPPING = {
|
||||
ir.Add: "red",
|
||||
ir.Sub: "yellow",
|
||||
ir.Mul: "green",
|
||||
ir.ArbitraryFunction: "orange",
|
||||
"ArbitraryFunction": "orange",
|
||||
"TLU": "grey",
|
||||
"output": "magenta",
|
||||
}
|
||||
|
||||
@@ -115,6 +116,8 @@ def draw_graph(
|
||||
def get_color(node):
|
||||
if node in set_of_nodes_which_are_outputs:
|
||||
return IR_NODE_COLOR_MAPPING["output"]
|
||||
if isinstance(node, ir.ArbitraryFunction):
|
||||
return IR_NODE_COLOR_MAPPING[node.op_name]
|
||||
return IR_NODE_COLOR_MAPPING[type(node)]
|
||||
|
||||
color_map = [get_color(node) for node in graph.nodes()]
|
||||
@@ -127,6 +130,8 @@ def draw_graph(
|
||||
return node.input_name
|
||||
if isinstance(node, ir.ConstantInput):
|
||||
return str(node.constant_data)
|
||||
if isinstance(node, ir.ArbitraryFunction):
|
||||
return node.op_name
|
||||
return node.__class__.__name__
|
||||
|
||||
label_dict = {node: get_proper_name(node) for node in graph.nodes()}
|
||||
@@ -209,7 +214,7 @@ def draw_graph(
|
||||
plt.show(block=block_until_user_closes_graph)
|
||||
|
||||
|
||||
def data_type_to_string(node):
|
||||
def output_data_type_to_string(node):
|
||||
"""Return the datatypes of the outputs of the node.
|
||||
|
||||
Args:
|
||||
@@ -249,7 +254,13 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
|
||||
elif isinstance(node, ir.ConstantInput):
|
||||
what_to_print = f"ConstantInput({node.constant_data})"
|
||||
else:
|
||||
what_to_print = node.__class__.__name__ + "("
|
||||
|
||||
base_name = node.__class__.__name__
|
||||
|
||||
if isinstance(node, ir.ArbitraryFunction):
|
||||
base_name = node.op_name
|
||||
|
||||
what_to_print = base_name + "("
|
||||
|
||||
# Find all the names of the current predecessors of the node
|
||||
list_of_arg_name = []
|
||||
@@ -273,7 +284,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
|
||||
|
||||
# Manage datatypes
|
||||
if show_data_types:
|
||||
new_line = f"{new_line: <40s} # {data_type_to_string(node)}"
|
||||
new_line = f"{new_line: <40s} # {output_data_type_to_string(node)}"
|
||||
|
||||
returned_str += f"\n{new_line}"
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ class LookupTable:
|
||||
arbitrary_func=LookupTable._checked_indexing,
|
||||
output_dtype=self.output_dtype,
|
||||
op_kwargs={"table": deepcopy(self.table)},
|
||||
op_name="TLU",
|
||||
)
|
||||
return key.__class__(
|
||||
inputs=[key],
|
||||
|
||||
@@ -186,12 +186,14 @@ class ArbitraryFunction(IntermediateNode):
|
||||
arbitrary_func: Optional[Callable]
|
||||
op_args: Tuple[Any, ...]
|
||||
op_kwargs: Dict[str, Any]
|
||||
op_name: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_base_value: BaseValue,
|
||||
arbitrary_func: Callable,
|
||||
output_dtype: BaseDataType,
|
||||
op_name: Optional[str] = None,
|
||||
op_args: Optional[Tuple[Any, ...]] = None,
|
||||
op_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
@@ -202,6 +204,7 @@ class ArbitraryFunction(IntermediateNode):
|
||||
self.op_kwargs = deepcopy(op_kwargs) if op_kwargs is not None else {}
|
||||
# TLU/PBS has an encrypted output
|
||||
self.outputs = [EncryptedValue(output_dtype)]
|
||||
self.op_name = op_name if op_name is not None else self.__class__.__name__
|
||||
|
||||
def evaluate(self, inputs: Mapping[int, Any]) -> Any:
|
||||
# This is the continuation of the mypy bug workaround
|
||||
@@ -215,5 +218,6 @@ class ArbitraryFunction(IntermediateNode):
|
||||
isinstance(other, ArbitraryFunction)
|
||||
and self.op_args == other.op_args
|
||||
and self.op_kwargs == other.op_kwargs
|
||||
and self.op_name == other.op_name
|
||||
and super().is_equivalent_to(other)
|
||||
)
|
||||
|
||||
@@ -57,6 +57,7 @@ def test_lookup_table_encrypted_lookup(test_helpers):
|
||||
arbitrary_func=lambda x, table: table[x],
|
||||
output_dtype=table.output_dtype,
|
||||
op_kwargs={"table": deepcopy(table.table)},
|
||||
op_name="TLU",
|
||||
)
|
||||
ref_graph.add_node(output_arbitrary_function, content=output_arbitrary_function)
|
||||
|
||||
@@ -93,6 +94,7 @@ def test_lookup_table_encrypted_and_plain_lookup(test_helpers):
|
||||
arbitrary_func=lambda x, table: table[x],
|
||||
output_dtype=table.output_dtype,
|
||||
op_kwargs={"table": deepcopy(table.table)},
|
||||
op_name="TLU",
|
||||
)
|
||||
ref_graph.add_node(intermediate_arbitrary_function, content=intermediate_arbitrary_function)
|
||||
|
||||
|
||||
@@ -153,16 +153,12 @@ def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
|
||||
(
|
||||
lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x],
|
||||
{"x": EncryptedValue(Integer(2, is_signed=False))},
|
||||
"\n%0 = x\n%1 = ArbitraryFunction(0)\nreturn(%1)",
|
||||
"\n%0 = x\n%1 = TLU(0)\nreturn(%1)",
|
||||
),
|
||||
(
|
||||
lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4],
|
||||
{"x": EncryptedValue(Integer(2, is_signed=False))},
|
||||
"\n%0 = x"
|
||||
"\n%1 = ConstantInput(4)"
|
||||
"\n%2 = Add(0, 1)"
|
||||
"\n%3 = ArbitraryFunction(2)"
|
||||
"\nreturn(%3)",
|
||||
"\n%0 = x\n%1 = ConstantInput(4)\n%2 = Add(0, 1)\n%3 = TLU(2)\nreturn(%3)",
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -230,7 +226,7 @@ def test_hnumpy_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
|
||||
lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x],
|
||||
{"x": EncryptedValue(Integer(2, is_signed=False))},
|
||||
"\n%0 = x # Integer<unsigned, 2 bits>"
|
||||
"\n%1 = ArbitraryFunction(0) # Integer<unsigned, 4 bits>"
|
||||
"\n%1 = TLU(0) # Integer<unsigned, 4 bits>"
|
||||
"\nreturn(%1)",
|
||||
),
|
||||
(
|
||||
@@ -239,7 +235,7 @@ def test_hnumpy_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
|
||||
"\n%0 = x # Integer<unsigned, 2 bits>"
|
||||
"\n%1 = ConstantInput(4) # Integer<unsigned, 3 bits>"
|
||||
"\n%2 = Add(0, 1) # Integer<unsigned, 3 bits>"
|
||||
"\n%3 = ArbitraryFunction(2) # Integer<unsigned, 2 bits>"
|
||||
"\n%3 = TLU(2) # Integer<unsigned, 2 bits>"
|
||||
"\nreturn(%3)",
|
||||
),
|
||||
(
|
||||
@@ -248,8 +244,8 @@ def test_hnumpy_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
|
||||
"\n%0 = x # Integer<unsigned, 2 bits>"
|
||||
"\n%1 = ConstantInput(4) # Integer<unsigned, 3 bits>"
|
||||
"\n%2 = Add(0, 1) # Integer<unsigned, 3 bits>"
|
||||
"\n%3 = ArbitraryFunction(2) # Integer<unsigned, 2 bits>"
|
||||
"\n%4 = ArbitraryFunction(3) # Integer<unsigned, 4 bits>"
|
||||
"\n%3 = TLU(2) # Integer<unsigned, 2 bits>"
|
||||
"\n%4 = TLU(3) # Integer<unsigned, 4 bits>"
|
||||
"\nreturn(%4)",
|
||||
),
|
||||
],
|
||||
|
||||
Reference in New Issue
Block a user