feat: make get_printable_graph give correct info for np.dot

closes #204
This commit is contained in:
Benoit Chevallier-Mames
2021-09-03 09:47:34 +02:00
committed by Benoit Chevallier
parent 6b6aa7ee4e
commit 150d33ba48
6 changed files with 53 additions and 23 deletions

View File

@@ -18,7 +18,7 @@ def output_data_type_to_string(node):
str: a string representing the datatypes of the outputs of the node
"""
return ", ".join([str(o.data_type) for o in node.outputs])
return ", ".join([str(o) for o in node.outputs])
def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
@@ -43,6 +43,11 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
for node in nx.topological_sort(graph):
# This code doesn't work with more than a single output. For more outputs,
# we would need to change the way the destination are created: currently,
# they only are done by incrementing i
assert len(node.outputs) == 1
if isinstance(node, ir.Input):
what_to_print = node.input_name
elif isinstance(node, ir.Constant):
@@ -74,6 +79,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
# Then, just print the predecessors in the right order
what_to_print += ", ".join([x[1] for x in list_of_arg_name]) + ")"
# This code doesn't work with more than a single output
new_line = f"%{i} = {what_to_print}"
# Manage datatypes

View File

@@ -17,8 +17,7 @@ class BaseValue(ABC):
self._is_encrypted = is_encrypted
def __repr__(self) -> str: # pragma: no cover
encrypted_str = "Encrypted" if self._is_encrypted else "Clear"
return f"{encrypted_str}{self.__class__.__name__}<{self.data_type!r}>"
return str(self)
@abstractmethod
def __eq__(self, other: object) -> bool:

View File

@@ -10,6 +10,10 @@ class ScalarValue(BaseValue):
def __eq__(self, other: object) -> bool:
return BaseValue.__eq__(self, other)
def __str__(self) -> str: # pragma: no cover
encrypted_str = "Encrypted" if self._is_encrypted else "Clear"
return f"{encrypted_str}Scalar<{self.data_type!r}>"
def make_clear_scalar(data_type: BaseDataType) -> ScalarValue:
"""Helper to create a clear ScalarValue.

View File

@@ -35,6 +35,10 @@ class TensorValue(BaseValue):
and super().__eq__(other)
)
def __str__(self) -> str:
encrypted_str = "Encrypted" if self._is_encrypted else "Clear"
return f"{encrypted_str}Tensor<{str(self.data_type)}, shape={self.shape}>"
@property
def shape(self) -> Tuple[int, ...]:
"""The TensorValue shape property.

View File

@@ -213,9 +213,12 @@ def test_fail_compile(function, input_ranges, list_of_arg_names):
(4,),
# Remark that, when you do the dot of tensors of 4 values between 0 and 3,
# you can get a maximal value of 4*3*3 = 36, ie something on 6 bits
"%0 = x # Integer<unsigned, 6 bits>"
"\n%1 = y # Integer<unsigned, 6 bits>"
"\n%2 = Dot(0, 1) # Integer<unsigned, 6 bits>"
"%0 = x "
"# EncryptedTensor<Integer<unsigned, 6 bits>, shape=(4,)>"
"\n%1 = y "
"# EncryptedTensor<Integer<unsigned, 6 bits>, shape=(4,)>"
"\n%2 = Dot(0, 1) "
"# EncryptedScalar<Integer<unsigned, 6 bits>>"
"\nreturn(%2)\n",
),
# pylint: enable=unnecessary-lambda

View File

@@ -221,9 +221,9 @@ def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
EncryptedScalar(Integer(64, is_signed=False)),
EncryptedScalar(Integer(32, is_signed=True)),
),
"%0 = x # Integer<unsigned, 64 bits>"
"\n%1 = y # Integer<signed, 32 bits>"
"\n%2 = Add(0, 1) # Integer<signed, 65 bits>"
"%0 = x # EncryptedScalar<Integer<unsigned, 64 bits>>"
"\n%1 = y # EncryptedScalar<Integer<signed, 32 bits>>"
"\n%2 = Add(0, 1) # EncryptedScalar<Integer<signed, 65 bits>>"
"\nreturn(%2)\n",
),
(
@@ -232,9 +232,12 @@ def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
EncryptedScalar(Integer(17, is_signed=False)),
EncryptedScalar(Integer(23, is_signed=False)),
),
"%0 = x # Integer<unsigned, 17 bits>"
"\n%1 = y # Integer<unsigned, 23 bits>"
"\n%2 = Mul(0, 1) # Integer<unsigned, 23 bits>"
"%0 = x "
"# EncryptedScalar<Integer<unsigned, 17 bits>>"
"\n%1 = y "
"# EncryptedScalar<Integer<unsigned, 23 bits>>"
"\n%2 = Mul(0, 1) "
"# EncryptedScalar<Integer<unsigned, 23 bits>>"
"\nreturn(%2)\n",
),
],
@@ -259,27 +262,38 @@ def test_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
(
lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x],
{"x": EncryptedScalar(Integer(2, is_signed=False))},
"%0 = x # Integer<unsigned, 2 bits>"
"\n%1 = TLU(0) # Integer<unsigned, 4 bits>"
"%0 = x "
"# EncryptedScalar<Integer<unsigned, 2 bits>>"
"\n%1 = TLU(0) "
"# EncryptedScalar<Integer<unsigned, 4 bits>>"
"\nreturn(%1)\n",
),
(
lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4],
{"x": EncryptedScalar(Integer(2, is_signed=False))},
"%0 = x # Integer<unsigned, 2 bits>"
"\n%1 = Constant(4) # Integer<unsigned, 3 bits>"
"\n%2 = Add(0, 1) # Integer<unsigned, 3 bits>"
"\n%3 = TLU(2) # Integer<unsigned, 2 bits>"
"%0 = x "
"# EncryptedScalar<Integer<unsigned, 2 bits>>"
"\n%1 = Constant(4) "
"# ClearScalar<Integer<unsigned, 3 bits>>"
"\n%2 = Add(0, 1) "
"# EncryptedScalar<Integer<unsigned, 3 bits>>"
"\n%3 = TLU(2) "
"# EncryptedScalar<Integer<unsigned, 2 bits>>"
"\nreturn(%3)\n",
),
(
lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[LOOKUP_TABLE_FROM_3B_TO_2B[x + 4]],
{"x": EncryptedScalar(Integer(2, is_signed=False))},
"%0 = x # Integer<unsigned, 2 bits>"
"\n%1 = Constant(4) # Integer<unsigned, 3 bits>"
"\n%2 = Add(0, 1) # Integer<unsigned, 3 bits>"
"\n%3 = TLU(2) # Integer<unsigned, 2 bits>"
"\n%4 = TLU(3) # Integer<unsigned, 4 bits>"
"%0 = x "
"# EncryptedScalar<Integer<unsigned, 2 bits>>"
"\n%1 = Constant(4) "
"# ClearScalar<Integer<unsigned, 3 bits>>"
"\n%2 = Add(0, 1) "
"# EncryptedScalar<Integer<unsigned, 3 bits>>"
"\n%3 = TLU(2) "
"# EncryptedScalar<Integer<unsigned, 2 bits>>"
"\n%4 = TLU(3) "
"# EncryptedScalar<Integer<unsigned, 4 bits>>"
"\nreturn(%4)\n",
),
],