From 150d33ba4898fdfd6025c8a290bad3f8fe488e67 Mon Sep 17 00:00:00 2001 From: Benoit Chevallier-Mames Date: Fri, 3 Sep 2021 09:47:34 +0200 Subject: [PATCH] feat: make get_printable_graph give correct info for np.dot closes #204 --- hdk/common/debugging/printing.py | 8 +++++- hdk/common/values/base.py | 3 +- hdk/common/values/scalars.py | 4 +++ hdk/common/values/tensors.py | 4 +++ tests/numpy/test_compile.py | 9 ++++-- tests/numpy/test_debugging.py | 48 +++++++++++++++++++++----------- 6 files changed, 53 insertions(+), 23 deletions(-) diff --git a/hdk/common/debugging/printing.py b/hdk/common/debugging/printing.py index 1c721472a..8cce142b1 100644 --- a/hdk/common/debugging/printing.py +++ b/hdk/common/debugging/printing.py @@ -18,7 +18,7 @@ def output_data_type_to_string(node): str: a string representing the datatypes of the outputs of the node """ - return ", ".join([str(o.data_type) for o in node.outputs]) + return ", ".join([str(o) for o in node.outputs]) def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: @@ -43,6 +43,11 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: for node in nx.topological_sort(graph): + # This code doesn't work with more than a single output. For more outputs, + # we would need to change the way the destination are created: currently, + # they only are done by incrementing i + assert len(node.outputs) == 1 + if isinstance(node, ir.Input): what_to_print = node.input_name elif isinstance(node, ir.Constant): @@ -74,6 +79,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: # Then, just print the predecessors in the right order what_to_print += ", ".join([x[1] for x in list_of_arg_name]) + ")" + # This code doesn't work with more than a single output new_line = f"%{i} = {what_to_print}" # Manage datatypes diff --git a/hdk/common/values/base.py b/hdk/common/values/base.py index c4ffd16dd..39fd770ef 100644 --- a/hdk/common/values/base.py +++ b/hdk/common/values/base.py @@ -17,8 +17,7 @@ class BaseValue(ABC): self._is_encrypted = is_encrypted def __repr__(self) -> str: # pragma: no cover - encrypted_str = "Encrypted" if self._is_encrypted else "Clear" - return f"{encrypted_str}{self.__class__.__name__}<{self.data_type!r}>" + return str(self) @abstractmethod def __eq__(self, other: object) -> bool: diff --git a/hdk/common/values/scalars.py b/hdk/common/values/scalars.py index c408b1637..38a9a87e4 100644 --- a/hdk/common/values/scalars.py +++ b/hdk/common/values/scalars.py @@ -10,6 +10,10 @@ class ScalarValue(BaseValue): def __eq__(self, other: object) -> bool: return BaseValue.__eq__(self, other) + def __str__(self) -> str: # pragma: no cover + encrypted_str = "Encrypted" if self._is_encrypted else "Clear" + return f"{encrypted_str}Scalar<{self.data_type!r}>" + def make_clear_scalar(data_type: BaseDataType) -> ScalarValue: """Helper to create a clear ScalarValue. diff --git a/hdk/common/values/tensors.py b/hdk/common/values/tensors.py index 1f4b9effa..966f7c8e7 100644 --- a/hdk/common/values/tensors.py +++ b/hdk/common/values/tensors.py @@ -35,6 +35,10 @@ class TensorValue(BaseValue): and super().__eq__(other) ) + def __str__(self) -> str: + encrypted_str = "Encrypted" if self._is_encrypted else "Clear" + return f"{encrypted_str}Tensor<{str(self.data_type)}, shape={self.shape}>" + @property def shape(self) -> Tuple[int, ...]: """The TensorValue shape property. diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 33db2540b..447445f0c 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -213,9 +213,12 @@ def test_fail_compile(function, input_ranges, list_of_arg_names): (4,), # Remark that, when you do the dot of tensors of 4 values between 0 and 3, # you can get a maximal value of 4*3*3 = 36, ie something on 6 bits - "%0 = x # Integer" - "\n%1 = y # Integer" - "\n%2 = Dot(0, 1) # Integer" + "%0 = x " + "# EncryptedTensor, shape=(4,)>" + "\n%1 = y " + "# EncryptedTensor, shape=(4,)>" + "\n%2 = Dot(0, 1) " + "# EncryptedScalar>" "\nreturn(%2)\n", ), # pylint: enable=unnecessary-lambda diff --git a/tests/numpy/test_debugging.py b/tests/numpy/test_debugging.py index d6099a42d..2b791c56e 100644 --- a/tests/numpy/test_debugging.py +++ b/tests/numpy/test_debugging.py @@ -221,9 +221,9 @@ def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str): EncryptedScalar(Integer(64, is_signed=False)), EncryptedScalar(Integer(32, is_signed=True)), ), - "%0 = x # Integer" - "\n%1 = y # Integer" - "\n%2 = Add(0, 1) # Integer" + "%0 = x # EncryptedScalar>" + "\n%1 = y # EncryptedScalar>" + "\n%2 = Add(0, 1) # EncryptedScalar>" "\nreturn(%2)\n", ), ( @@ -232,9 +232,12 @@ def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str): EncryptedScalar(Integer(17, is_signed=False)), EncryptedScalar(Integer(23, is_signed=False)), ), - "%0 = x # Integer" - "\n%1 = y # Integer" - "\n%2 = Mul(0, 1) # Integer" + "%0 = x " + "# EncryptedScalar>" + "\n%1 = y " + "# EncryptedScalar>" + "\n%2 = Mul(0, 1) " + "# EncryptedScalar>" "\nreturn(%2)\n", ), ], @@ -259,27 +262,38 @@ def test_print_with_show_data_types(lambda_f, x_y, ref_graph_str): ( lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x], {"x": EncryptedScalar(Integer(2, is_signed=False))}, - "%0 = x # Integer" - "\n%1 = TLU(0) # Integer" + "%0 = x " + "# EncryptedScalar>" + "\n%1 = TLU(0) " + "# EncryptedScalar>" "\nreturn(%1)\n", ), ( lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4], {"x": EncryptedScalar(Integer(2, is_signed=False))}, - "%0 = x # Integer" - "\n%1 = Constant(4) # Integer" - "\n%2 = Add(0, 1) # Integer" - "\n%3 = TLU(2) # Integer" + "%0 = x " + "# EncryptedScalar>" + "\n%1 = Constant(4) " + "# ClearScalar>" + "\n%2 = Add(0, 1) " + "# EncryptedScalar>" + "\n%3 = TLU(2) " + "# EncryptedScalar>" "\nreturn(%3)\n", ), ( lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[LOOKUP_TABLE_FROM_3B_TO_2B[x + 4]], {"x": EncryptedScalar(Integer(2, is_signed=False))}, - "%0 = x # Integer" - "\n%1 = Constant(4) # Integer" - "\n%2 = Add(0, 1) # Integer" - "\n%3 = TLU(2) # Integer" - "\n%4 = TLU(3) # Integer" + "%0 = x " + "# EncryptedScalar>" + "\n%1 = Constant(4) " + "# ClearScalar>" + "\n%2 = Add(0, 1) " + "# EncryptedScalar>" + "\n%3 = TLU(2) " + "# EncryptedScalar>" + "\n%4 = TLU(3) " + "# EncryptedScalar>" "\nreturn(%4)\n", ), ],