diff --git a/concrete/common/debugging/printing.py b/concrete/common/debugging/printing.py index fc367a2c1..0b0444e7f 100644 --- a/concrete/common/debugging/printing.py +++ b/concrete/common/debugging/printing.py @@ -22,6 +22,23 @@ def output_data_type_to_string(node): return ", ".join([str(o) for o in node.outputs]) +def shorten_a_constant(constant_data: str): + """Return a constant (if small) or an extra of the constant (if too large). + + Args: + constant (str): The constant we want to shorten + + Returns: + str: a string to represent the constant + """ + + content = str(constant_data).replace("\n", "") + # if content is longer than 25 chars, only show the first and the last 10 chars of it + # 25 is selected using the spaces available before data type information + short_content = f"{content[:10]} ... {content[-10:]}" if len(content) > 25 else content + return short_content + + def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: """Return a string representing a graph. @@ -52,10 +69,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: if isinstance(node, Input): what_to_print = node.input_name elif isinstance(node, Constant): - content = str(node.constant_data).replace("\n", "") - # if content is longer than 25 chars, only show the first and the last 10 chars of it - # 25 is selected using the spaces available before data type information - to_show = f"{content[:10]} ... {content[-10:]}" if len(content) > 25 else content + to_show = shorten_a_constant(node.constant_data) what_to_print = f"Constant({to_show})" else: @@ -81,15 +95,34 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: list_of_arg_name.sort() custom_assert([x[0] for x in list_of_arg_name] == list(range(len(list_of_arg_name)))) + prefix_to_add_to_what_to_print = "" + suffix_to_add_to_what_to_print = "" + + # Print constant that may be in the UnivariateFunction. For the moment, it considers + # there is a single constant maximally and that there is 2 inputs maximally + if isinstance(node, UnivariateFunction) and "baked_constant" in node.op_kwargs: + baked_constant = node.op_kwargs["baked_constant"] + if node.op_attributes["in_which_input_is_constant"] == 0: + prefix_to_add_to_what_to_print = f"{shorten_a_constant(baked_constant)}, " + else: + custom_assert( + node.op_attributes["in_which_input_is_constant"] == 1, + "'in_which_input_is_constant' should be a key of node.op_attributes", + ) + suffix_to_add_to_what_to_print = f", {shorten_a_constant(baked_constant)}" + # Then, just print the predecessors in the right order - what_to_print += ", ".join(["%" + x[1] for x in list_of_arg_name]) + ")" + what_to_print += prefix_to_add_to_what_to_print + what_to_print += ", ".join(["%" + x[1] for x in list_of_arg_name]) + what_to_print += suffix_to_add_to_what_to_print + what_to_print += ")" # This code doesn't work with more than a single output new_line = f"%{i} = {what_to_print}" # Manage datatypes if show_data_types: - new_line = f"{new_line: <40s} # {output_data_type_to_string(node)}" + new_line = f"{new_line: <50s} # {output_data_type_to_string(node)}" returned_str += f"{new_line}\n" diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index f0c5aa948..eb155c776 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -360,11 +360,11 @@ def test_small_inputset_treat_warnings_as_errors(): (4,), # Remark that, when you do the dot of tensors of 4 values between 0 and 3, # you can get a maximal value of 4*3*3 = 36, ie something on 6 bits - "%0 = x " + "%0 = x " "# EncryptedTensor, shape=(4,)>" - "\n%1 = y " + "\n%1 = y " "# EncryptedTensor, shape=(4,)>" - "\n%2 = Dot(%0, %1) " + "\n%2 = Dot(%0, %1) " "# EncryptedScalar>" "\nreturn(%2)\n", ), diff --git a/tests/numpy/test_debugging.py b/tests/numpy/test_debugging.py index b19ab4aff..267973841 100644 --- a/tests/numpy/test_debugging.py +++ b/tests/numpy/test_debugging.py @@ -112,6 +112,14 @@ def issue_130_c(x, y): issue_130_c, "%0 = Constant(1)\n%1 = x\n%2 = Sub(%0, %1)\nreturn(%2, %2)\n", ), + ( + lambda x, y: numpy.arctan2(x, 42) + y, + "%0 = y\n%1 = x\n%2 = np.arctan2(%1, 42)\n%3 = Add(%2, %0)\nreturn(%3)\n", + ), + ( + lambda x, y: numpy.arctan2(43, x) + y, + "%0 = y\n%1 = x\n%2 = np.arctan2(43, %1)\n%3 = Add(%2, %0)\nreturn(%3)\n", + ), ], ) @pytest.mark.parametrize( @@ -221,9 +229,12 @@ def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str): EncryptedScalar(Integer(64, is_signed=False)), EncryptedScalar(Integer(32, is_signed=True)), ), - "%0 = x # EncryptedScalar>" - "\n%1 = y # EncryptedScalar>" - "\n%2 = Add(%0, %1) # EncryptedScalar>" + "%0 = x " + "# EncryptedScalar>" + "\n%1 = y " + " # EncryptedScalar>" + "\n%2 = Add(%0, %1) " + " # EncryptedScalar>" "\nreturn(%2)\n", ), ( @@ -232,11 +243,11 @@ def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str): EncryptedScalar(Integer(17, is_signed=False)), EncryptedScalar(Integer(23, is_signed=False)), ), - "%0 = x " + "%0 = x " "# EncryptedScalar>" - "\n%1 = y " + "\n%1 = y " "# EncryptedScalar>" - "\n%2 = Mul(%0, %1) " + "\n%2 = Mul(%0, %1) " "# EncryptedScalar>" "\nreturn(%2)\n", ), @@ -262,37 +273,37 @@ def test_print_with_show_data_types(lambda_f, x_y, ref_graph_str): ( lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x], {"x": EncryptedScalar(Integer(2, is_signed=False))}, - "%0 = x " + "%0 = x " "# EncryptedScalar>" - "\n%1 = TLU(%0) " + "\n%1 = TLU(%0) " "# EncryptedScalar>" "\nreturn(%1)\n", ), ( lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4], {"x": EncryptedScalar(Integer(2, is_signed=False))}, - "%0 = x " + "%0 = x " "# EncryptedScalar>" - "\n%1 = Constant(4) " + "\n%1 = Constant(4) " "# ClearScalar>" - "\n%2 = Add(%0, %1) " + "\n%2 = Add(%0, %1) " "# EncryptedScalar>" - "\n%3 = TLU(%2) " + "\n%3 = TLU(%2) " "# EncryptedScalar>" "\nreturn(%3)\n", ), ( lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[LOOKUP_TABLE_FROM_3B_TO_2B[x + 4]], {"x": EncryptedScalar(Integer(2, is_signed=False))}, - "%0 = x " + "%0 = x " "# EncryptedScalar>" - "\n%1 = Constant(4) " + "\n%1 = Constant(4) " "# ClearScalar>" - "\n%2 = Add(%0, %1) " + "\n%2 = Add(%0, %1) " "# EncryptedScalar>" - "\n%3 = TLU(%2) " + "\n%3 = TLU(%2) " "# EncryptedScalar>" - "\n%4 = TLU(%3) " + "\n%4 = TLU(%3) " "# EncryptedScalar>" "\nreturn(%4)\n", ), @@ -311,3 +322,31 @@ def test_print_with_show_data_types_with_direct_tlu(lambda_f, params, ref_graph_ f"==================\nExpected \n{ref_graph_str}" f"==================\n" ) + + +def test_numpy_long_constant(): + "Test get_printable_graph with long constant" + + def all_explicit_operations(x): + intermediate = numpy.add(x, numpy.arange(100).reshape(10, 10)) + intermediate = numpy.subtract(intermediate, numpy.arange(10).reshape(1, 10)) + intermediate = numpy.arctan2(numpy.arange(10, 20).reshape(1, 10), intermediate) + intermediate = numpy.arctan2(numpy.arange(100, 200).reshape(10, 10), intermediate) + return intermediate + + op_graph = tracing.trace_numpy_function( + all_explicit_operations, {"x": EncryptedTensor(Integer(32, True), shape=(10, 10))} + ) + + expected = """ +%0 = Constant([[0 1 2 3 4 5 6 7 8 9]]) # ClearTensor, shape=(1, 10)> +%1 = x # EncryptedTensor, shape=(10, 10)> +%2 = Constant([[ 0 1 2 ... 97 98 99]]) # ClearTensor, shape=(10, 10)> +%3 = Add(%1, %2) # EncryptedTensor, shape=(10, 10)> +%4 = Sub(%3, %0) # EncryptedTensor, shape=(10, 10)> +%5 = np.arctan2([[10 11 12 ... 17 18 19]], %4) # EncryptedTensor, shape=(10, 10)> +%6 = np.arctan2([[100 101 ... 198 199]], %5) # EncryptedTensor, shape=(10, 10)> +return(%6) +""".lstrip() # noqa: E501 + + assert get_printable_graph(op_graph, show_data_types=True) == expected diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index e07f73fb9..c7fa35a3e 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -188,21 +188,21 @@ def test_numpy_tracing_tensors(): ) expected = """ -%0 = Constant([[2 1] [1 2]]) # ClearTensor, shape=(2, 2)> -%1 = Constant([[1 2] [2 1]]) # ClearTensor, shape=(2, 2)> -%2 = Constant([[10 20] [30 40]]) # ClearTensor, shape=(2, 2)> -%3 = Constant([[100 200] [300 400]]) # ClearTensor, shape=(2, 2)> -%4 = Constant([[5 6] [7 8]]) # ClearTensor, shape=(2, 2)> -%5 = x # EncryptedTensor, shape=(2, 2)> -%6 = Constant([[1 2] [3 4]]) # ClearTensor, shape=(2, 2)> -%7 = Add(%5, %6) # EncryptedTensor, shape=(2, 2)> -%8 = Add(%4, %7) # EncryptedTensor, shape=(2, 2)> -%9 = Sub(%3, %8) # EncryptedTensor, shape=(2, 2)> -%10 = Sub(%9, %2) # EncryptedTensor, shape=(2, 2)> -%11 = Mul(%10, %1) # EncryptedTensor, shape=(2, 2)> -%12 = Mul(%0, %11) # EncryptedTensor, shape=(2, 2)> +%0 = Constant([[2 1] [1 2]]) # ClearTensor, shape=(2, 2)> +%1 = Constant([[1 2] [2 1]]) # ClearTensor, shape=(2, 2)> +%2 = Constant([[10 20] [30 40]]) # ClearTensor, shape=(2, 2)> +%3 = Constant([[100 200] [300 400]]) # ClearTensor, shape=(2, 2)> +%4 = Constant([[5 6] [7 8]]) # ClearTensor, shape=(2, 2)> +%5 = x # EncryptedTensor, shape=(2, 2)> +%6 = Constant([[1 2] [3 4]]) # ClearTensor, shape=(2, 2)> +%7 = Add(%5, %6) # EncryptedTensor, shape=(2, 2)> +%8 = Add(%4, %7) # EncryptedTensor, shape=(2, 2)> +%9 = Sub(%3, %8) # EncryptedTensor, shape=(2, 2)> +%10 = Sub(%9, %2) # EncryptedTensor, shape=(2, 2)> +%11 = Mul(%10, %1) # EncryptedTensor, shape=(2, 2)> +%12 = Mul(%0, %11) # EncryptedTensor, shape=(2, 2)> return(%12) -""".lstrip() +""".lstrip() # noqa: E501 assert get_printable_graph(op_graph, show_data_types=True) == expected @@ -227,21 +227,21 @@ def test_numpy_explicit_tracing_tensors(): ) expected = """ -%0 = Constant([[2 1] [1 2]]) # ClearTensor, shape=(2, 2)> -%1 = Constant([[1 2] [2 1]]) # ClearTensor, shape=(2, 2)> -%2 = Constant([[10 20] [30 40]]) # ClearTensor, shape=(2, 2)> -%3 = Constant([[100 200] [300 400]]) # ClearTensor, shape=(2, 2)> -%4 = Constant([[5 6] [7 8]]) # ClearTensor, shape=(2, 2)> -%5 = x # EncryptedTensor, shape=(2, 2)> -%6 = Constant([[1 2] [3 4]]) # ClearTensor, shape=(2, 2)> -%7 = Add(%5, %6) # EncryptedTensor, shape=(2, 2)> -%8 = Add(%4, %7) # EncryptedTensor, shape=(2, 2)> -%9 = Sub(%3, %8) # EncryptedTensor, shape=(2, 2)> -%10 = Sub(%9, %2) # EncryptedTensor, shape=(2, 2)> -%11 = Mul(%10, %1) # EncryptedTensor, shape=(2, 2)> -%12 = Mul(%0, %11) # EncryptedTensor, shape=(2, 2)> +%0 = Constant([[2 1] [1 2]]) # ClearTensor, shape=(2, 2)> +%1 = Constant([[1 2] [2 1]]) # ClearTensor, shape=(2, 2)> +%2 = Constant([[10 20] [30 40]]) # ClearTensor, shape=(2, 2)> +%3 = Constant([[100 200] [300 400]]) # ClearTensor, shape=(2, 2)> +%4 = Constant([[5 6] [7 8]]) # ClearTensor, shape=(2, 2)> +%5 = x # EncryptedTensor, shape=(2, 2)> +%6 = Constant([[1 2] [3 4]]) # ClearTensor, shape=(2, 2)> +%7 = Add(%5, %6) # EncryptedTensor, shape=(2, 2)> +%8 = Add(%4, %7) # EncryptedTensor, shape=(2, 2)> +%9 = Sub(%3, %8) # EncryptedTensor, shape=(2, 2)> +%10 = Sub(%9, %2) # EncryptedTensor, shape=(2, 2)> +%11 = Mul(%10, %1) # EncryptedTensor, shape=(2, 2)> +%12 = Mul(%0, %11) # EncryptedTensor, shape=(2, 2)> return(%12) -""".lstrip() +""".lstrip() # noqa: E501 assert get_printable_graph(op_graph, show_data_types=True) == expected