mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: get_printable_graph is more precise
get_printable_graph prints the constant if the ArbitraryFunction has a baked constant closes #584
This commit is contained in:
committed by
Benoit Chevallier
parent
286dda79b2
commit
dedbde93d0
@@ -22,6 +22,23 @@ def output_data_type_to_string(node):
|
||||
return ", ".join([str(o) for o in node.outputs])
|
||||
|
||||
|
||||
def shorten_a_constant(constant_data: str):
|
||||
"""Return a constant (if small) or an extra of the constant (if too large).
|
||||
|
||||
Args:
|
||||
constant (str): The constant we want to shorten
|
||||
|
||||
Returns:
|
||||
str: a string to represent the constant
|
||||
"""
|
||||
|
||||
content = str(constant_data).replace("\n", "")
|
||||
# if content is longer than 25 chars, only show the first and the last 10 chars of it
|
||||
# 25 is selected using the spaces available before data type information
|
||||
short_content = f"{content[:10]} ... {content[-10:]}" if len(content) > 25 else content
|
||||
return short_content
|
||||
|
||||
|
||||
def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
|
||||
"""Return a string representing a graph.
|
||||
|
||||
@@ -52,10 +69,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
|
||||
if isinstance(node, Input):
|
||||
what_to_print = node.input_name
|
||||
elif isinstance(node, Constant):
|
||||
content = str(node.constant_data).replace("\n", "")
|
||||
# if content is longer than 25 chars, only show the first and the last 10 chars of it
|
||||
# 25 is selected using the spaces available before data type information
|
||||
to_show = f"{content[:10]} ... {content[-10:]}" if len(content) > 25 else content
|
||||
to_show = shorten_a_constant(node.constant_data)
|
||||
what_to_print = f"Constant({to_show})"
|
||||
else:
|
||||
|
||||
@@ -81,15 +95,34 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
|
||||
list_of_arg_name.sort()
|
||||
custom_assert([x[0] for x in list_of_arg_name] == list(range(len(list_of_arg_name))))
|
||||
|
||||
prefix_to_add_to_what_to_print = ""
|
||||
suffix_to_add_to_what_to_print = ""
|
||||
|
||||
# Print constant that may be in the UnivariateFunction. For the moment, it considers
|
||||
# there is a single constant maximally and that there is 2 inputs maximally
|
||||
if isinstance(node, UnivariateFunction) and "baked_constant" in node.op_kwargs:
|
||||
baked_constant = node.op_kwargs["baked_constant"]
|
||||
if node.op_attributes["in_which_input_is_constant"] == 0:
|
||||
prefix_to_add_to_what_to_print = f"{shorten_a_constant(baked_constant)}, "
|
||||
else:
|
||||
custom_assert(
|
||||
node.op_attributes["in_which_input_is_constant"] == 1,
|
||||
"'in_which_input_is_constant' should be a key of node.op_attributes",
|
||||
)
|
||||
suffix_to_add_to_what_to_print = f", {shorten_a_constant(baked_constant)}"
|
||||
|
||||
# Then, just print the predecessors in the right order
|
||||
what_to_print += ", ".join(["%" + x[1] for x in list_of_arg_name]) + ")"
|
||||
what_to_print += prefix_to_add_to_what_to_print
|
||||
what_to_print += ", ".join(["%" + x[1] for x in list_of_arg_name])
|
||||
what_to_print += suffix_to_add_to_what_to_print
|
||||
what_to_print += ")"
|
||||
|
||||
# This code doesn't work with more than a single output
|
||||
new_line = f"%{i} = {what_to_print}"
|
||||
|
||||
# Manage datatypes
|
||||
if show_data_types:
|
||||
new_line = f"{new_line: <40s} # {output_data_type_to_string(node)}"
|
||||
new_line = f"{new_line: <50s} # {output_data_type_to_string(node)}"
|
||||
|
||||
returned_str += f"{new_line}\n"
|
||||
|
||||
|
||||
@@ -360,11 +360,11 @@ def test_small_inputset_treat_warnings_as_errors():
|
||||
(4,),
|
||||
# Remark that, when you do the dot of tensors of 4 values between 0 and 3,
|
||||
# you can get a maximal value of 4*3*3 = 36, ie something on 6 bits
|
||||
"%0 = x "
|
||||
"%0 = x "
|
||||
"# EncryptedTensor<Integer<unsigned, 6 bits>, shape=(4,)>"
|
||||
"\n%1 = y "
|
||||
"\n%1 = y "
|
||||
"# EncryptedTensor<Integer<unsigned, 6 bits>, shape=(4,)>"
|
||||
"\n%2 = Dot(%0, %1) "
|
||||
"\n%2 = Dot(%0, %1) "
|
||||
"# EncryptedScalar<Integer<unsigned, 6 bits>>"
|
||||
"\nreturn(%2)\n",
|
||||
),
|
||||
|
||||
@@ -112,6 +112,14 @@ def issue_130_c(x, y):
|
||||
issue_130_c,
|
||||
"%0 = Constant(1)\n%1 = x\n%2 = Sub(%0, %1)\nreturn(%2, %2)\n",
|
||||
),
|
||||
(
|
||||
lambda x, y: numpy.arctan2(x, 42) + y,
|
||||
"%0 = y\n%1 = x\n%2 = np.arctan2(%1, 42)\n%3 = Add(%2, %0)\nreturn(%3)\n",
|
||||
),
|
||||
(
|
||||
lambda x, y: numpy.arctan2(43, x) + y,
|
||||
"%0 = y\n%1 = x\n%2 = np.arctan2(43, %1)\n%3 = Add(%2, %0)\nreturn(%3)\n",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
@@ -221,9 +229,12 @@ def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
|
||||
EncryptedScalar(Integer(64, is_signed=False)),
|
||||
EncryptedScalar(Integer(32, is_signed=True)),
|
||||
),
|
||||
"%0 = x # EncryptedScalar<Integer<unsigned, 64 bits>>"
|
||||
"\n%1 = y # EncryptedScalar<Integer<signed, 32 bits>>"
|
||||
"\n%2 = Add(%0, %1) # EncryptedScalar<Integer<signed, 65 bits>>"
|
||||
"%0 = x "
|
||||
"# EncryptedScalar<Integer<unsigned, 64 bits>>"
|
||||
"\n%1 = y "
|
||||
" # EncryptedScalar<Integer<signed, 32 bits>>"
|
||||
"\n%2 = Add(%0, %1) "
|
||||
" # EncryptedScalar<Integer<signed, 65 bits>>"
|
||||
"\nreturn(%2)\n",
|
||||
),
|
||||
(
|
||||
@@ -232,11 +243,11 @@ def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
|
||||
EncryptedScalar(Integer(17, is_signed=False)),
|
||||
EncryptedScalar(Integer(23, is_signed=False)),
|
||||
),
|
||||
"%0 = x "
|
||||
"%0 = x "
|
||||
"# EncryptedScalar<Integer<unsigned, 17 bits>>"
|
||||
"\n%1 = y "
|
||||
"\n%1 = y "
|
||||
"# EncryptedScalar<Integer<unsigned, 23 bits>>"
|
||||
"\n%2 = Mul(%0, %1) "
|
||||
"\n%2 = Mul(%0, %1) "
|
||||
"# EncryptedScalar<Integer<unsigned, 23 bits>>"
|
||||
"\nreturn(%2)\n",
|
||||
),
|
||||
@@ -262,37 +273,37 @@ def test_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
|
||||
(
|
||||
lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x],
|
||||
{"x": EncryptedScalar(Integer(2, is_signed=False))},
|
||||
"%0 = x "
|
||||
"%0 = x "
|
||||
"# EncryptedScalar<Integer<unsigned, 2 bits>>"
|
||||
"\n%1 = TLU(%0) "
|
||||
"\n%1 = TLU(%0) "
|
||||
"# EncryptedScalar<Integer<unsigned, 4 bits>>"
|
||||
"\nreturn(%1)\n",
|
||||
),
|
||||
(
|
||||
lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4],
|
||||
{"x": EncryptedScalar(Integer(2, is_signed=False))},
|
||||
"%0 = x "
|
||||
"%0 = x "
|
||||
"# EncryptedScalar<Integer<unsigned, 2 bits>>"
|
||||
"\n%1 = Constant(4) "
|
||||
"\n%1 = Constant(4) "
|
||||
"# ClearScalar<Integer<unsigned, 3 bits>>"
|
||||
"\n%2 = Add(%0, %1) "
|
||||
"\n%2 = Add(%0, %1) "
|
||||
"# EncryptedScalar<Integer<unsigned, 3 bits>>"
|
||||
"\n%3 = TLU(%2) "
|
||||
"\n%3 = TLU(%2) "
|
||||
"# EncryptedScalar<Integer<unsigned, 2 bits>>"
|
||||
"\nreturn(%3)\n",
|
||||
),
|
||||
(
|
||||
lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[LOOKUP_TABLE_FROM_3B_TO_2B[x + 4]],
|
||||
{"x": EncryptedScalar(Integer(2, is_signed=False))},
|
||||
"%0 = x "
|
||||
"%0 = x "
|
||||
"# EncryptedScalar<Integer<unsigned, 2 bits>>"
|
||||
"\n%1 = Constant(4) "
|
||||
"\n%1 = Constant(4) "
|
||||
"# ClearScalar<Integer<unsigned, 3 bits>>"
|
||||
"\n%2 = Add(%0, %1) "
|
||||
"\n%2 = Add(%0, %1) "
|
||||
"# EncryptedScalar<Integer<unsigned, 3 bits>>"
|
||||
"\n%3 = TLU(%2) "
|
||||
"\n%3 = TLU(%2) "
|
||||
"# EncryptedScalar<Integer<unsigned, 2 bits>>"
|
||||
"\n%4 = TLU(%3) "
|
||||
"\n%4 = TLU(%3) "
|
||||
"# EncryptedScalar<Integer<unsigned, 4 bits>>"
|
||||
"\nreturn(%4)\n",
|
||||
),
|
||||
@@ -311,3 +322,31 @@ def test_print_with_show_data_types_with_direct_tlu(lambda_f, params, ref_graph_
|
||||
f"==================\nExpected \n{ref_graph_str}"
|
||||
f"==================\n"
|
||||
)
|
||||
|
||||
|
||||
def test_numpy_long_constant():
|
||||
"Test get_printable_graph with long constant"
|
||||
|
||||
def all_explicit_operations(x):
|
||||
intermediate = numpy.add(x, numpy.arange(100).reshape(10, 10))
|
||||
intermediate = numpy.subtract(intermediate, numpy.arange(10).reshape(1, 10))
|
||||
intermediate = numpy.arctan2(numpy.arange(10, 20).reshape(1, 10), intermediate)
|
||||
intermediate = numpy.arctan2(numpy.arange(100, 200).reshape(10, 10), intermediate)
|
||||
return intermediate
|
||||
|
||||
op_graph = tracing.trace_numpy_function(
|
||||
all_explicit_operations, {"x": EncryptedTensor(Integer(32, True), shape=(10, 10))}
|
||||
)
|
||||
|
||||
expected = """
|
||||
%0 = Constant([[0 1 2 3 4 5 6 7 8 9]]) # ClearTensor<Integer<unsigned, 4 bits>, shape=(1, 10)>
|
||||
%1 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
|
||||
%2 = Constant([[ 0 1 2 ... 97 98 99]]) # ClearTensor<Integer<unsigned, 7 bits>, shape=(10, 10)>
|
||||
%3 = Add(%1, %2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
|
||||
%4 = Sub(%3, %0) # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
|
||||
%5 = np.arctan2([[10 11 12 ... 17 18 19]], %4) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
|
||||
%6 = np.arctan2([[100 101 ... 198 199]], %5) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
|
||||
return(%6)
|
||||
""".lstrip() # noqa: E501
|
||||
|
||||
assert get_printable_graph(op_graph, show_data_types=True) == expected
|
||||
|
||||
@@ -188,21 +188,21 @@ def test_numpy_tracing_tensors():
|
||||
)
|
||||
|
||||
expected = """
|
||||
%0 = Constant([[2 1] [1 2]]) # ClearTensor<Integer<unsigned, 2 bits>, shape=(2, 2)>
|
||||
%1 = Constant([[1 2] [2 1]]) # ClearTensor<Integer<unsigned, 2 bits>, shape=(2, 2)>
|
||||
%2 = Constant([[10 20] [30 40]]) # ClearTensor<Integer<unsigned, 6 bits>, shape=(2, 2)>
|
||||
%3 = Constant([[100 200] [300 400]]) # ClearTensor<Integer<unsigned, 9 bits>, shape=(2, 2)>
|
||||
%4 = Constant([[5 6] [7 8]]) # ClearTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>
|
||||
%5 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%6 = Constant([[1 2] [3 4]]) # ClearTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
|
||||
%7 = Add(%5, %6) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%8 = Add(%4, %7) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%9 = Sub(%3, %8) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%10 = Sub(%9, %2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%11 = Mul(%10, %1) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%12 = Mul(%0, %11) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%0 = Constant([[2 1] [1 2]]) # ClearTensor<Integer<unsigned, 2 bits>, shape=(2, 2)>
|
||||
%1 = Constant([[1 2] [2 1]]) # ClearTensor<Integer<unsigned, 2 bits>, shape=(2, 2)>
|
||||
%2 = Constant([[10 20] [30 40]]) # ClearTensor<Integer<unsigned, 6 bits>, shape=(2, 2)>
|
||||
%3 = Constant([[100 200] [300 400]]) # ClearTensor<Integer<unsigned, 9 bits>, shape=(2, 2)>
|
||||
%4 = Constant([[5 6] [7 8]]) # ClearTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>
|
||||
%5 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%6 = Constant([[1 2] [3 4]]) # ClearTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
|
||||
%7 = Add(%5, %6) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%8 = Add(%4, %7) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%9 = Sub(%3, %8) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%10 = Sub(%9, %2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%11 = Mul(%10, %1) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%12 = Mul(%0, %11) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
return(%12)
|
||||
""".lstrip()
|
||||
""".lstrip() # noqa: E501
|
||||
|
||||
assert get_printable_graph(op_graph, show_data_types=True) == expected
|
||||
|
||||
@@ -227,21 +227,21 @@ def test_numpy_explicit_tracing_tensors():
|
||||
)
|
||||
|
||||
expected = """
|
||||
%0 = Constant([[2 1] [1 2]]) # ClearTensor<Integer<unsigned, 2 bits>, shape=(2, 2)>
|
||||
%1 = Constant([[1 2] [2 1]]) # ClearTensor<Integer<unsigned, 2 bits>, shape=(2, 2)>
|
||||
%2 = Constant([[10 20] [30 40]]) # ClearTensor<Integer<unsigned, 6 bits>, shape=(2, 2)>
|
||||
%3 = Constant([[100 200] [300 400]]) # ClearTensor<Integer<unsigned, 9 bits>, shape=(2, 2)>
|
||||
%4 = Constant([[5 6] [7 8]]) # ClearTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>
|
||||
%5 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%6 = Constant([[1 2] [3 4]]) # ClearTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
|
||||
%7 = Add(%5, %6) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%8 = Add(%4, %7) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%9 = Sub(%3, %8) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%10 = Sub(%9, %2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%11 = Mul(%10, %1) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%12 = Mul(%0, %11) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%0 = Constant([[2 1] [1 2]]) # ClearTensor<Integer<unsigned, 2 bits>, shape=(2, 2)>
|
||||
%1 = Constant([[1 2] [2 1]]) # ClearTensor<Integer<unsigned, 2 bits>, shape=(2, 2)>
|
||||
%2 = Constant([[10 20] [30 40]]) # ClearTensor<Integer<unsigned, 6 bits>, shape=(2, 2)>
|
||||
%3 = Constant([[100 200] [300 400]]) # ClearTensor<Integer<unsigned, 9 bits>, shape=(2, 2)>
|
||||
%4 = Constant([[5 6] [7 8]]) # ClearTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>
|
||||
%5 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%6 = Constant([[1 2] [3 4]]) # ClearTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
|
||||
%7 = Add(%5, %6) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%8 = Add(%4, %7) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%9 = Sub(%3, %8) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%10 = Sub(%9, %2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%11 = Mul(%10, %1) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%12 = Mul(%0, %11) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
return(%12)
|
||||
""".lstrip()
|
||||
""".lstrip() # noqa: E501
|
||||
|
||||
assert get_printable_graph(op_graph, show_data_types=True) == expected
|
||||
|
||||
|
||||
Reference in New Issue
Block a user