refactor(debugging): re-write graph formatting

This commit is contained in:
Umut
2021-11-10 16:49:17 +03:00
parent b449ddc360
commit 6fec590e65
25 changed files with 653 additions and 984 deletions

View File

@@ -85,7 +85,7 @@ class CompilationArtifacts:
"""
drawing = draw_graph(operation_graph)
textual_representation = format_operation_graph(operation_graph, show_data_types=True)
textual_representation = format_operation_graph(operation_graph)
self.drawings_of_operation_graphs[name] = drawing
self.textual_representations_of_operation_graphs[name] = textual_representation

View File

@@ -21,6 +21,9 @@ class Float(base.BaseDataType):
def __repr__(self) -> str:
return f"{self.__class__.__name__}<{self.bit_width} bits>"
def __str__(self) -> str:
return f"float{self.bit_width}"
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and self.bit_width == other.bit_width

View File

@@ -23,6 +23,9 @@ class Integer(base.BaseDataType):
signed_str = "signed" if self.is_signed else "unsigned"
return f"{self.__class__.__name__}<{signed_str}, {self.bit_width} bits>"
def __str__(self) -> str:
return f"{('int' if self.is_signed else 'uint')}{self.bit_width}"
def __eq__(self, other: object) -> bool:
return (
isinstance(other, self.__class__)

View File

@@ -1,4 +1,4 @@
"""Module for debugging."""
from .custom_assert import assert_true
from .drawing import draw_graph
from .printing import format_operation_graph
from .formatting import format_operation_graph

View File

@@ -84,7 +84,7 @@ def draw_graph(
attributes = {
node: {
"label": node.label(),
"label": node.text_for_drawing(),
"color": get_color(node, output_nodes),
"penwidth": 2, # double thickness for circles
"peripheries": 2 if node in output_nodes else 1, # double circle for output nodes

View File

@@ -0,0 +1,124 @@
"""Functions to format operation graphs for debugging purposes."""
from typing import Dict, List, Optional, Tuple
import networkx as nx
from ..debugging.custom_assert import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import IntermediateNode
def format_operation_graph(
    opgraph: OPGraph,
    maximum_constant_length: int = 25,
    highlighted_nodes: Optional[Dict[IntermediateNode, List[str]]] = None,
) -> str:
    """Format an operation graph.

    Args:
        opgraph (OPGraph):
            the operation graph to format

        maximum_constant_length (int):
            maximum length of formatted constants throughout the formatting

        highlighted_nodes (Optional[Dict[IntermediateNode, List[str]]]):
            the dict of nodes and their corresponding messages which will be highlighted

    Returns:
        str: formatted operation graph
    """

    assert_true(isinstance(opgraph, OPGraph))

    # (node, output_index) -> identifier
    # e.g., id_map[(node1, 0)] = 2 and id_map[(node1, 1)] = 3
    # means line for node1 is in this form (%2, %3) = node1.format(...)
    id_map: Dict[Tuple[IntermediateNode, int], int] = {}

    # lines that will be merged at the end
    lines: List[str] = []

    # type information to add to each line (for alignment, this is done after lines are determined)
    type_informations: List[str] = []

    # default highlighted nodes is empty
    highlighted_nodes = highlighted_nodes if highlighted_nodes is not None else {}

    # highlight information for lines, this is required because highlights are added to lines
    # after their type information is added and we only have line numbers, not nodes
    highlighted_lines: Dict[int, List[str]] = {}

    for node in nx.topological_sort(opgraph.graph):
        # assign a unique id to outputs of node
        assert_true(len(node.outputs) > 0)
        for i in range(len(node.outputs)):
            id_map[(node, i)] = len(id_map)

        # remember highlights of the node
        if node in highlighted_nodes:
            highlighted_lines[len(lines)] = highlighted_nodes[node]

        # extract predecessors and their ids
        predecessors = []
        for predecessor, output_idx in opgraph.get_ordered_inputs_of(node):
            predecessors.append(f"%{id_map[(predecessor, output_idx)]}")

        # start to build the line for the node
        line = ""

        # add output information to the line
        outputs = ", ".join(f"%{id_map[(node, i)]}" for i in range(len(node.outputs)))
        line += outputs if len(node.outputs) == 1 else f"({outputs})"

        # add node information to the line
        line += " = "
        line += node.text_for_formatting(predecessors, maximum_constant_length)

        # append line to list of lines
        lines.append(line)

        # remember type information of the node
        types = ", ".join(str(output) for output in node.outputs)
        type_informations.append(types if len(node.outputs) == 1 else f"({types})")

    # align = signs
    #
    # e.g.,
    #
    #  %1 = ...
    #  %2 = ...
    #  ...
    #  %8 = ...
    #  %9 = ...
    # %10 = ...
    # %11 = ...
    # ...
    longest_length_before_equals_sign = max(len(line.split("=")[0]) for line in lines)
    for i, line in enumerate(lines):
        length_before_equals_sign = len(line.split("=")[0])
        lines[i] = (" " * (longest_length_before_equals_sign - length_before_equals_sign)) + line

    # add type informations
    longest_line_length = max(len(line) for line in lines)
    for i, line in enumerate(lines):
        lines[i] += " " * (longest_line_length - len(line))
        lines[i] += f" # {type_informations[i]}"

    # add highlights (this is done in reverse to keep indices consistent)
    for i in reversed(range(len(lines))):
        if i in highlighted_lines:
            for j, message in enumerate(highlighted_lines[i]):
                highlight = "^" if j == 0 else " "
                lines.insert(i + 1 + j, f"{highlight * len(lines[i])} {message}")

    # add return information
    # (if there is a single return, it's in the form `return %id`)
    # (otherwise, it's in the form `return (%id1, %id2, ..., %idN)`)
    returns: List[str] = []
    for node in opgraph.output_nodes.values():
        outputs = ", ".join(f"%{id_map[(node, i)]}" for i in range(len(node.outputs)))
        returns.append(outputs if len(node.outputs) == 1 else f"({outputs})")

    # the parentheses around the conditional are required: without them, `+` binds
    # tighter than the conditional expression, so graphs with multiple return values
    # would emit `(%id1, %id2)` without the leading `return ` keyword
    lines.append("return " + (returns[0] if len(returns) == 1 else f"({', '.join(returns)})"))

    return "\n".join(lines)

View File

@@ -1,146 +0,0 @@
"""functions to print the different graphs we can generate in the package, eg to debug."""
from typing import Any, Dict, List, Optional
import networkx as nx
from ..debugging.custom_assert import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import (
Constant,
GenericFunction,
IndexConstant,
Input,
IntermediateNode,
)
def output_data_type_to_string(node):
    """Return the datatypes of the outputs of the node.

    Args:
        node: a graph node

    Returns:
        str: a string representing the datatypes of the outputs of the node
    """
    return ", ".join(str(output) for output in node.outputs)
def shorten_a_constant(constant_data: str):
    """Return the constant itself (if small) or an excerpt of it (if too large).

    Args:
        constant_data (str): the constant we want to shorten

    Returns:
        str: a string to represent the constant
    """
    flattened = str(constant_data).replace("\n", "")
    if len(flattened) <= 25:
        return flattened
    # content longer than 25 chars: only show the first and the last 10 chars of it
    # (25 is selected using the spaces available before data type information)
    return f"{flattened[:10]} ... {flattened[-10:]}"
def format_operation_graph(
    opgraph: OPGraph,
    show_data_types: bool = False,
    highlighted_nodes: Optional[Dict[IntermediateNode, List[str]]] = None,
) -> str:
    """Return a string representing a graph.

    Args:
        opgraph (OPGraph): The graph that we want to draw
        show_data_types (bool, optional): Whether or not showing data_types of nodes, eg to see
            their width. Defaults to False.
        highlighted_nodes (Optional[Dict[IntermediateNode, List[str]]], optional): The dict of nodes
            which will be highlighted and their corresponding messages. Defaults to None.

    Returns:
        str: a string to print or save in a file
    """
    assert_true(isinstance(opgraph, OPGraph))
    highlighted_nodes = highlighted_nodes if highlighted_nodes is not None else {}
    list_of_nodes_which_are_outputs = list(opgraph.output_nodes.values())
    graph = opgraph.graph
    returned_str = ""
    # `i` is the sequential id of the node being formatted; nodes are referenced as %i
    i = 0
    # node -> id assigned to it, so successors can reference their predecessors by %id
    map_table: Dict[Any, int] = {}
    for node in nx.topological_sort(graph):
        # TODO: #640
        # This code doesn't work with more than a single output. For more outputs,
        # we would need to change the way the destination are created: currently,
        # they only are done by incrementing i
        assert_true(len(node.outputs) == 1)
        if isinstance(node, Input):
            # inputs are shown by their parameter name instead of an operation
            what_to_print = node.input_name
        elif isinstance(node, Constant):
            # constants are shortened so long arrays don't blow up the line
            to_show = shorten_a_constant(node.constant_data)
            what_to_print = f"Constant({to_show})"
        else:
            base_name = node.__class__.__name__
            if isinstance(node, GenericFunction):
                base_name = node.op_name
            what_to_print = base_name + "("
            # Find all the names of the current predecessors of the node
            list_of_arg_name = []
            for pred, index_list in graph.pred[node].items():
                for index in index_list.values():
                    # Remark that we keep the index of the predecessor and its
                    # name, to print sources in the right order, which is
                    # important for eg non commutative operations
                    list_of_arg_name += [(index["input_idx"], str(map_table[pred]))]
            # Some checks, because the previous algorithm is not clear
            # (each input index must appear exactly once and indices must be 0..n-1)
            assert_true(len(list_of_arg_name) == len(set(x[0] for x in list_of_arg_name)))
            list_of_arg_name.sort()
            assert_true([x[0] for x in list_of_arg_name] == list(range(len(list_of_arg_name))))
            prefix_to_add_to_what_to_print = ""
            suffix_to_add_to_what_to_print = ""
            # Then, just print the predecessors in the right order
            what_to_print += prefix_to_add_to_what_to_print
            what_to_print += ", ".join(["%" + x[1] for x in list_of_arg_name])
            what_to_print += suffix_to_add_to_what_to_print
            # for IndexConstant, append the index part of the label (e.g. `[0:2]`);
            # the label contains a `value` placeholder which is stripped here
            what_to_print += (
                f"{node.label().replace('value', '')}" if isinstance(node, IndexConstant) else ""
            )
            what_to_print += ")"
        # This code doesn't work with more than a single output
        new_line = f"%{i} = {what_to_print}"
        # Manage datatypes
        if show_data_types:
            # pad the line to (at least) 50 chars so type comments are roughly aligned
            new_line = f"{new_line: <50s} # {output_data_type_to_string(node)}"
        returned_str += f"{new_line}\n"
        if node in highlighted_nodes:
            # underline the whole line with carets; extra messages go on
            # subsequent lines aligned under the first message
            new_line_len = len(new_line)
            message = f"\n{' ' * new_line_len}  ".join(highlighted_nodes[node])
            returned_str += f"{'^' * new_line_len} {message}\n"
        map_table[node] = i
        i += 1
    return_part = ", ".join(["%" + str(map_table[n]) for n in list_of_nodes_which_are_outputs])
    returned_str += f"return({return_part})\n"
    return returned_str

View File

@@ -30,6 +30,9 @@ class LookupTable:
self.table = table
self.output_dtype = make_integer_to_hold(table, force_signed=False)
def __repr__(self):
return str(list(self.table))
def __getitem__(self, key: Union[int, Iterable, BaseTracer]):
# if a tracer is used for indexing,
# we need to create an `GenericFunction` node

View File

@@ -21,7 +21,7 @@ class FHECircuit:
self.engine = engine
def __str__(self):
return format_operation_graph(self.opgraph, show_data_types=True)
return format_operation_graph(self.opgraph)
def draw(
self,

View File

@@ -0,0 +1,47 @@
"""Helpers for formatting functionality."""
from typing import Any, Dict, Hashable
import numpy
from ..debugging.custom_assert import assert_true
SPECIAL_OBJECT_MAPPING: Dict[Any, str] = {
numpy.float32: "float32",
numpy.float64: "float64",
numpy.int8: "int8",
numpy.int16: "int16",
numpy.int32: "int32",
numpy.int64: "int64",
numpy.uint8: "uint8",
numpy.uint16: "uint16",
numpy.uint32: "uint32",
numpy.uint64: "uint64",
}
def format_constant(constant: Any, maximum_length: int = 45) -> str:
    """Format a constant.

    Args:
        constant (Any): the constant to format
        maximum_length (int): maximum length of the resulting string

    Returns:
        str: the formatted constant
    """

    # some well-known objects (e.g., numpy dtypes) have a fixed, short spelling
    if isinstance(constant, Hashable) and constant in SPECIAL_OBJECT_MAPPING:
        return SPECIAL_OBJECT_MAPPING[constant]

    # the shortened form is `x ... y` where x and y are at least 1 character each,
    # plus the 5-character ` ... ` separator, so maximum_length cannot be below 7
    assert_true(maximum_length >= 7)

    flattened = str(constant).replace("\n", "")
    if len(flattened) <= maximum_length:
        return flattened

    # keep the head and the tail of the content, separated by ` ... `
    budget = maximum_length - 5
    head = budget // 2
    tail = budget - head
    return f"{flattened[:head]} ... {flattened[-tail:]}"

View File

@@ -96,9 +96,7 @@ def check_node_compatibility_with_mlir(
# e.g., `np.absolute is not supported for the time being`
return f"{node.op_name} is not supported for the time being"
else:
return (
f"{node.op_name} of kind {node.op_kind.value} is not supported for the time being"
)
return f"{node.op_name} is not supported for the time being"
elif isinstance(node, intermediate.Dot): # constraints for dot product
assert_true(len(inputs) == 2)
@@ -206,10 +204,8 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
raise RuntimeError(
f"max_bit_width of some nodes is too high for the current version of "
f"the compiler (maximum must be {ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB}) "
f"which is not compatible with:\n"
+ format_operation_graph(
op_graph, show_data_types=True, highlighted_nodes=offending_nodes
)
f"which is not compatible with:\n\n"
+ format_operation_graph(op_graph, highlighted_nodes=offending_nodes)
)
_set_all_bit_width(op_graph, max_bit_width)

View File

@@ -131,6 +131,24 @@ class OPGraph:
idx_to_pred.update((data["input_idx"], pred) for data in edge_data.values())
return [idx_to_pred[i] for i in range(len(idx_to_pred))]
def get_ordered_inputs_of(self, node: IntermediateNode) -> List[Tuple[IntermediateNode, int]]:
    """Get node inputs ordered by their indices.

    Args:
        node (IntermediateNode): the node for which we want the ordered inputs

    Returns:
        List[Tuple[IntermediateNode, int]]: the ordered list of inputs
    """
    # input_idx -> (predecessor node, predecessor output index)
    inputs_by_index: Dict[int, Tuple[IntermediateNode, int]] = {}
    for predecessor in self.graph.pred[node]:
        # multigraph: there can be several parallel edges between the same pair of nodes
        for data in self.graph.get_edge_data(predecessor, node).values():
            inputs_by_index[data["input_idx"]] = (predecessor, data["output_idx"])
    return [inputs_by_index[index] for index in range(len(inputs_by_index))]
def evaluate(self, inputs: Dict[int, Any]) -> Dict[IntermediateNode, Any]:
"""Evaluate a graph and get intermediate values for all nodes.

View File

@@ -133,10 +133,9 @@ def convert_float_subgraph_to_fused_node(
printable_graph = format_operation_graph(
float_subgraph_as_op_graph,
show_data_types=True,
highlighted_nodes=node_with_issues_for_fusing,
)
message = f"The following subgraph is not fusable:\n{printable_graph}"
message = f"The following subgraph is not fusable:\n\n{printable_graph}"
logger.warning(message)
return None

View File

@@ -16,6 +16,7 @@ from ..data_types.dtypes_helpers import (
from ..data_types.integers import Integer
from ..debugging.custom_assert import assert_true
from ..helpers import indexing_helpers
from ..helpers.formatting_helpers import format_constant
from ..helpers.python_helpers import catch, update_and_return_dict
from ..values import (
BaseValue,
@@ -64,6 +65,27 @@ class IntermediateNode(ABC):
self.outputs = [mix_values_func(self.inputs[0], self.inputs[1])]
def text_for_formatting(self, predecessors: List[str], _maximum_constant_length: int) -> str:
    """Get the formatted node (used in formatting opgraph).

    Args:
        predecessors (List[str]): predecessor names to this node
        _maximum_constant_length (int): desired maximum constant length

    Returns:
        str: the formatted node
    """
    # default formatting: lowercased class name applied to predecessors, e.g. `add(%0, %1)`
    operation = type(self).__name__.lower()
    arguments = ", ".join(predecessors)
    return f"{operation}({arguments})"
@abstractmethod
def text_for_drawing(self) -> str:
"""Get the label of the node (used in drawing opgraph).
Returns:
str: the label of the node
"""
@abstractmethod
def evaluate(self, inputs: Dict[int, Any]) -> Any:
"""Simulate what the represented computation would output for the given inputs.
@@ -93,15 +115,6 @@ class IntermediateNode(ABC):
"""
return cls.n_in() > 1
@abstractmethod
def label(self) -> str:
"""Get the label of the node.
Returns:
str: the label of the node
"""
class Add(IntermediateNode):
"""Addition between two values."""
@@ -110,12 +123,12 @@ class Add(IntermediateNode):
__init__ = IntermediateNode._init_binary
def text_for_drawing(self) -> str:
return "+"
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return inputs[0] + inputs[1]
def label(self) -> str:
return "+"
class Sub(IntermediateNode):
"""Subtraction between two values."""
@@ -124,12 +137,12 @@ class Sub(IntermediateNode):
__init__ = IntermediateNode._init_binary
def text_for_drawing(self) -> str:
return "-"
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return inputs[0] - inputs[1]
def label(self) -> str:
return "-"
class Mul(IntermediateNode):
"""Multiplication between two values."""
@@ -138,12 +151,12 @@ class Mul(IntermediateNode):
__init__ = IntermediateNode._init_binary
def text_for_drawing(self) -> str:
return "*"
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return inputs[0] * inputs[1]
def label(self) -> str:
return "*"
class Input(IntermediateNode):
"""Node representing an input of the program."""
@@ -164,12 +177,16 @@ class Input(IntermediateNode):
self.program_input_idx = program_input_idx
self.outputs = [deepcopy(self.inputs[0])]
def text_for_formatting(self, predecessors: List[str], _maximum_constant_length: int) -> str:
assert_true(len(predecessors) == 0)
return self.input_name
def text_for_drawing(self) -> str:
return self.input_name
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return inputs[0]
def label(self) -> str:
return self.input_name
class Constant(IntermediateNode):
"""Node representing a constant of the program."""
@@ -191,6 +208,13 @@ class Constant(IntermediateNode):
self._constant_data = constant_data
self.outputs = [base_value_class(is_encrypted=False)]
def text_for_formatting(self, predecessors: List[str], maximum_constant_length: int) -> str:
assert_true(len(predecessors) == 0)
return format_constant(self.constant_data, maximum_constant_length)
def text_for_drawing(self) -> str:
return format_constant(self.constant_data)
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return self.constant_data
@@ -203,9 +227,6 @@ class Constant(IntermediateNode):
"""
return self._constant_data
def label(self) -> str:
return str(self.constant_data)
class IndexConstant(IntermediateNode):
"""Node representing a constant indexing in the program.
@@ -242,19 +263,17 @@ class IndexConstant(IntermediateNode):
else ClearTensor(output_dtype, output_shape)
]
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return inputs[0][self.index]
def label(self) -> str:
"""Label of the node to show during drawings.
It can be used for some other places after `"value"` below is replaced by `""`.
This note will no longer be necessary after #707 is addressed.
"""
def text_for_formatting(self, predecessors: List[str], _maximum_constant_length: int) -> str:
assert_true(len(predecessors) == 1)
elements = [indexing_helpers.format_indexing_element(element) for element in self.index]
index = ", ".join(elements)
return f"value[{index}]"
return f"{predecessors[0]}[{index}]"
def text_for_drawing(self) -> str:
return self.text_for_formatting(["value"], 0) # 0 is unused
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return inputs[0][self.index]
def flood_replace_none_values(table: list):
@@ -335,15 +354,26 @@ class GenericFunction(IntermediateNode):
self.op_name = op_name if op_name is not None else self.__class__.__name__
def text_for_formatting(self, predecessors: List[str], maximum_constant_length: int) -> str:
all_args = deepcopy(predecessors)
all_args.extend(format_constant(value, maximum_constant_length) for value in self.op_args)
all_args.extend(
f"{name}={format_constant(value, maximum_constant_length)}"
for name, value in self.op_kwargs.items()
)
return f"{self.op_name}({', '.join(all_args)})"
def text_for_drawing(self) -> str:
return self.op_name
def evaluate(self, inputs: Dict[int, Any]) -> Any:
# This is the continuation of the mypy bug workaround
assert self.arbitrary_func is not None
ordered_inputs = [inputs[idx] for idx in range(len(inputs))]
return self.arbitrary_func(*ordered_inputs, *self.op_args, **self.op_kwargs)
def label(self) -> str:
return self.op_name
def get_table(self, ordered_preds: List[IntermediateNode]) -> List[Any]:
"""Get the table for the current input value of this GenericFunction.
@@ -466,14 +496,14 @@ class Dot(IntermediateNode):
self.outputs = [output_scalar_value(output_dtype)]
self.evaluation_function = delegate_evaluation_function
def text_for_drawing(self) -> str:
return "dot"
def evaluate(self, inputs: Dict[int, Any]) -> Any:
# This is the continuation of the mypy bug workaround
assert self.evaluation_function is not None
return self.evaluation_function(inputs[0], inputs[1])
def label(self) -> str:
return "dot"
class MatMul(IntermediateNode):
"""Return the node representing a matrix multiplication."""
@@ -513,8 +543,8 @@ class MatMul(IntermediateNode):
self.outputs = [output_value]
def text_for_drawing(self) -> str:
return "matmul"
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return inputs[0] @ inputs[1]
def label(self) -> str:
return "@"

View File

@@ -262,9 +262,7 @@ def prepare_op_graph_for_mlir(op_graph):
if offending_nodes is not None:
raise RuntimeError(
"function you are trying to compile isn't supported for MLIR lowering\n\n"
+ format_operation_graph(
op_graph, show_data_types=True, highlighted_nodes=offending_nodes
)
+ format_operation_graph(op_graph, highlighted_nodes=offending_nodes)
)
# Update bit_width for MLIR

View File

@@ -104,7 +104,7 @@ class NPTracer(BaseTracer):
output_value=generic_function_output_value,
op_kind="TLU",
op_kwargs={"dtype": normalized_numpy_dtype.type},
op_name=f"astype({normalized_numpy_dtype})",
op_name="astype",
)
output_tracer = self.__class__([self], traced_computation=traced_computation, output_idx=0)
return output_tracer
@@ -320,7 +320,7 @@ class NPTracer(BaseTracer):
output_value=generic_function_output_value,
op_kind="Memory",
op_kwargs=deepcopy(kwargs),
op_name="np.transpose",
op_name="transpose",
op_attributes={"fusable": transpose_is_fusable},
)
output_tracer = self.__class__(
@@ -367,7 +367,7 @@ class NPTracer(BaseTracer):
output_value=generic_function_output_value,
op_kind="Memory",
op_kwargs=deepcopy(kwargs),
op_name="np.ravel",
op_name="ravel",
op_attributes={"fusable": ravel_is_fusable},
)
output_tracer = self.__class__(
@@ -405,12 +405,9 @@ class NPTracer(BaseTracer):
assert_true(isinstance(first_arg_output, TensorValue))
first_arg_output = cast(TensorValue, first_arg_output)
newshape = deepcopy(arg1)
if isinstance(newshape, int):
# Make numpy.reshape(x, (170)) and numpy.reshape(x, 170) work, while classical form is
# numpy.reshape(x, (170,))
newshape = (newshape,)
# Make numpy.reshape(x, (170)) and numpy.reshape(x, 170) work,
# while classical form is numpy.reshape(x, (170,))
newshape = deepcopy(arg1) if not isinstance(arg1, int) else (arg1,)
# Check shape compatibility
assert_true(
@@ -435,7 +432,7 @@ class NPTracer(BaseTracer):
output_value=generic_function_output_value,
op_kind="Memory",
op_kwargs={"newshape": newshape},
op_name="np.reshape",
op_name="reshape",
op_attributes={"fusable": reshape_is_fusable},
)
output_tracer = self.__class__(
@@ -575,7 +572,7 @@ def _get_unary_fun(function: numpy.ufunc):
# dynamically
# pylint: disable=protected-access
return lambda *input_tracers, **kwargs: NPTracer._unary_operator(
function, f"np.{function.__name__}", *input_tracers, **kwargs
function, f"{function.__name__}", *input_tracers, **kwargs
)
# pylint: enable=protected-access
@@ -587,7 +584,7 @@ def _get_binary_fun(function: numpy.ufunc):
# dynamically
# pylint: disable=protected-access
return lambda *input_tracers, **kwargs: NPTracer._binary_operator(
function, f"np.{function.__name__}", *input_tracers, **kwargs
function, f"{function.__name__}", *input_tracers, **kwargs
)
# pylint: enable=protected-access

View File

@@ -328,11 +328,11 @@ def test_eval_op_graph_bounds_on_non_conformant_inputset_default(capsys):
captured = capsys.readouterr()
assert (
captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
"(expected EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `x` "
"but got EncryptedTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
"(expected EncryptedTensor<uint2, shape=(3,)> for parameter `x` "
"but got EncryptedTensor<uint2, shape=(4,)> which is not compatible)\n"
"Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
"(expected ClearTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `y` "
"but got ClearTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
"(expected ClearTensor<uint2, shape=(3,)> for parameter `y` "
"but got ClearTensor<uint2, shape=(4,)> which is not compatible)\n"
)
@@ -365,14 +365,14 @@ def test_eval_op_graph_bounds_on_non_conformant_inputset_check_all(capsys):
captured = capsys.readouterr()
assert (
captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
"(expected EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `x` "
"but got EncryptedTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
"(expected EncryptedTensor<uint2, shape=(3,)> for parameter `x` "
"but got EncryptedTensor<uint2, shape=(4,)> which is not compatible)\n"
"Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
"(expected ClearTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `y` "
"but got ClearTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
"(expected ClearTensor<uint2, shape=(3,)> for parameter `y` "
"but got ClearTensor<uint2, shape=(4,)> which is not compatible)\n"
"Warning: Input #1 (0-indexed) is not coherent with the hinted parameters "
"(expected ClearTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `y` "
"but got ClearTensor<Integer<unsigned, 3 bits>, shape=(3,)> which is not compatible)\n"
"(expected ClearTensor<uint2, shape=(3,)> for parameter `y` "
"but got ClearTensor<uint3, shape=(3,)> which is not compatible)\n"
)
@@ -436,14 +436,14 @@ def test_eval_op_graph_bounds_on_non_conformant_numpy_inputset_check_all(capsys)
captured = capsys.readouterr()
assert (
captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
"(expected EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `x` "
"but got EncryptedTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
"(expected EncryptedTensor<uint2, shape=(3,)> for parameter `x` "
"but got EncryptedTensor<uint2, shape=(4,)> which is not compatible)\n"
"Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
"(expected ClearTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `y` "
"but got ClearTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
"(expected ClearTensor<uint2, shape=(3,)> for parameter `y` "
"but got ClearTensor<uint2, shape=(4,)> which is not compatible)\n"
"Warning: Input #1 (0-indexed) is not coherent with the hinted parameters "
"(expected ClearTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `y` "
"but got ClearTensor<Integer<unsigned, 3 bits>, shape=(3,)> which is not compatible)\n"
"(expected ClearTensor<uint2, shape=(3,)> for parameter `y` "
"but got ClearTensor<uint3, shape=(3,)> which is not compatible)\n"
)

View File

@@ -0,0 +1,78 @@
"""Test file for formatting"""
from concrete.common.data_types.integers import Integer
from concrete.common.debugging import format_operation_graph
from concrete.common.values import EncryptedScalar
from concrete.numpy.compile import compile_numpy_function_into_op_graph
def test_format_operation_graph_with_multiple_edges(default_compilation_configuration):
    """Test format_operation_graph with multiple edges"""

    def function(x):
        # `x + x` creates two edges between the same pair of nodes,
        # exercising the formatter's multi-edge handling
        return x + x

    opgraph = compile_numpy_function_into_op_graph(
        function,
        {"x": EncryptedScalar(Integer(4, True))},
        [(i,) for i in range(0, 10)],
        default_compilation_configuration,
    )
    formatted_graph = format_operation_graph(opgraph)
    # both operands of `add` should reference the same id (%0)
    assert (
        formatted_graph
        == """
%0 = x # EncryptedScalar<uint4>
%1 = add(%0, %0) # EncryptedScalar<uint5>
return %1
""".strip()
    )
def test_format_operation_graph_with_offending_nodes(default_compilation_configuration):
    """Test format_operation_graph with offending nodes"""

    def function(x):
        return x + 42

    opgraph = compile_numpy_function_into_op_graph(
        function,
        {"x": EncryptedScalar(Integer(7, True))},
        [(i,) for i in range(-5, 5)],
        default_compilation_configuration,
    )

    # a single highlighted node with a single message:
    # the highlighted line is underlined with carets followed by the message
    highlighted_nodes = {opgraph.input_nodes[0]: ["foo"]}
    formatted_graph = format_operation_graph(opgraph, highlighted_nodes=highlighted_nodes).strip()
    assert (
        formatted_graph
        == """
%0 = x # EncryptedScalar<int4>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo
%1 = 42 # ClearScalar<uint6>
%2 = add(%0, %1) # EncryptedScalar<uint6>
return %2
""".strip()
    )

    # multiple highlighted nodes, one of them with multiple messages:
    # extra messages are printed on their own lines, aligned under the first
    highlighted_nodes = {opgraph.input_nodes[0]: ["foo"], opgraph.output_nodes[0]: ["bar", "baz"]}
    formatted_graph = format_operation_graph(opgraph, highlighted_nodes=highlighted_nodes).strip()
    assert (
        formatted_graph
        == """
%0 = x # EncryptedScalar<int4>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo
%1 = 42 # ClearScalar<uint6>
%2 = add(%0, %1) # EncryptedScalar<uint6>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ bar
baz
""".strip()
    )

View File

@@ -1,94 +0,0 @@
"""Test file for printing"""
from concrete.common.data_types.integers import Integer
from concrete.common.debugging import format_operation_graph
from concrete.common.values import EncryptedScalar
from concrete.numpy.compile import compile_numpy_function_into_op_graph
def test_format_operation_graph_with_offending_nodes(default_compilation_configuration):
    """Test format_operation_graph with offending nodes"""

    def function(x):
        return x + 42

    opgraph = compile_numpy_function_into_op_graph(
        function,
        {"x": EncryptedScalar(Integer(7, True))},
        [(i,) for i in range(-5, 5)],
        default_compilation_configuration,
    )

    # single highlighted node, single message; check both with and without data types
    highlighted_nodes = {opgraph.input_nodes[0]: ["foo"]}
    without_types = format_operation_graph(
        opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes
    ).strip()
    with_types = format_operation_graph(
        opgraph, show_data_types=True, highlighted_nodes=highlighted_nodes
    ).strip()
    assert (
        without_types
        == """
%0 = x
^^^^^^ foo
%1 = Constant(42)
%2 = Add(%0, %1)
return(%2)
""".strip()
    )
    assert (
        with_types
        == """
%0 = x # EncryptedScalar<Integer<signed, 4 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo
%1 = Constant(42) # ClearScalar<Integer<unsigned, 6 bits>>
%2 = Add(%0, %1) # EncryptedScalar<Integer<unsigned, 6 bits>>
return(%2)
""".strip()
    )

    # multiple highlighted nodes, one with multiple messages;
    # extra messages appear on continuation lines under the caret line
    highlighted_nodes = {opgraph.input_nodes[0]: ["foo"], opgraph.output_nodes[0]: ["bar", "baz"]}
    without_types = format_operation_graph(
        opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes
    ).strip()
    with_types = format_operation_graph(
        opgraph, show_data_types=True, highlighted_nodes=highlighted_nodes
    ).strip()
    assert (
        without_types
        == """
%0 = x
^^^^^^ foo
%1 = Constant(42)
%2 = Add(%0, %1)
^^^^^^^^^^^^^^^^ bar
baz
return(%2)
""".strip()
    )
    assert (
        with_types
        == """
%0 = x # EncryptedScalar<Integer<signed, 4 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo
%1 = Constant(42) # ClearScalar<Integer<unsigned, 6 bits>>
%2 = Add(%0, %1) # EncryptedScalar<Integer<unsigned, 6 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ bar
baz
return(%2)
""".strip()
    )

View File

@@ -157,88 +157,118 @@ def get_func_params_int32(func, scalar=True):
no_fuse_unhandled,
False,
get_func_params_int32(no_fuse_unhandled),
"""The following subgraph is not fusable:
%0 = x # EncryptedScalar<Integer<signed, 32 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing)
%1 = Constant(0.7) # ClearScalar<Float<64 bits>>
%2 = y # EncryptedScalar<Integer<signed, 32 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing)
%3 = Constant(1.3) # ClearScalar<Float<64 bits>>
%4 = Add(%0, %1) # EncryptedScalar<Float<64 bits>>
%5 = Add(%2, %3) # EncryptedScalar<Float<64 bits>>
%6 = Add(%4, %5) # EncryptedScalar<Float<64 bits>>
%7 = astype(int32)(%6) # EncryptedScalar<Integer<signed, 32 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ cannot fuse here as the subgraph has 2 variable inputs
return(%7)""", # noqa: E501 # pylint: disable=line-too-long
"""
The following subgraph is not fusable:
%0 = x # EncryptedScalar<int32>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing)
%1 = 0.7 # ClearScalar<float64>
%2 = y # EncryptedScalar<int32>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing)
%3 = 1.3 # ClearScalar<float64>
%4 = add(%0, %1) # EncryptedScalar<float64>
%5 = add(%2, %3) # EncryptedScalar<float64>
%6 = add(%4, %5) # EncryptedScalar<float64>
%7 = astype(%6, dtype=int32) # EncryptedScalar<int32>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ cannot fuse here as the subgraph has 2 variable inputs
return %7
""".strip(), # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_unhandled",
),
pytest.param(
no_fuse_dot,
False,
{"x": EncryptedTensor(Integer(32, True), (10,))},
"""The following subgraph is not fusable:
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10,)
%1 = Constant([1.33 1.33 ... 1.33 1.33]) # ClearTensor<Float<64 bits>, shape=(10,)>
%2 = Dot(%0, %1) # EncryptedScalar<Float<64 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,)
%3 = astype(int32)(%2) # EncryptedScalar<Integer<signed, 32 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,)
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
"""
The following subgraph is not fusable:
%0 = x # EncryptedTensor<int32, shape=(10,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10,)
%1 = [1.33 1.33 ... 1.33 1.33] # ClearTensor<float64, shape=(10,)>
%2 = dot(%0, %1) # EncryptedScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,)
%3 = astype(%2, dtype=int32) # EncryptedScalar<int32>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,)
return %3
""".strip(), # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_dot",
),
pytest.param(
ravel_cases,
False,
{"x": EncryptedTensor(Integer(32, True), (10, 20))},
"""The following subgraph is not fusable:
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 20)>
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(10, 20)>
%2 = np.ravel(%1) # EncryptedTensor<Float<64 bits>, shape=(200,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(200,)>
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
"""
The following subgraph is not fusable:
%0 = x # EncryptedTensor<int32, shape=(10, 20)>
%1 = astype(%0, dtype=float64) # EncryptedTensor<float64, shape=(10, 20)>
%2 = ravel(%1) # EncryptedTensor<float64, shape=(200,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(%2, dtype=int32) # EncryptedTensor<int32, shape=(200,)>
return %3
""".strip(), # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_explicitely_ravel",
),
pytest.param(
transpose_cases,
False,
{"x": EncryptedTensor(Integer(32, True), (10, 20))},
"""The following subgraph is not fusable:
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 20)>
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(10, 20)>
%2 = np.transpose(%1) # EncryptedTensor<Float<64 bits>, shape=(20, 10)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(20, 10)>
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
"""
The following subgraph is not fusable:
%0 = x # EncryptedTensor<int32, shape=(10, 20)>
%1 = astype(%0, dtype=float64) # EncryptedTensor<float64, shape=(10, 20)>
%2 = transpose(%1) # EncryptedTensor<float64, shape=(20, 10)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(%2, dtype=int32) # EncryptedTensor<int32, shape=(20, 10)>
return %3
""".strip(), # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_explicitely_transpose",
),
pytest.param(
lambda x: reshape_cases(x, (20, 10)),
False,
{"x": EncryptedTensor(Integer(32, True), (10, 20))},
"""The following subgraph is not fusable:
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 20)>
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(10, 20)>
%2 = np.reshape(%1) # EncryptedTensor<Float<64 bits>, shape=(20, 10)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(20, 10)>
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
"""
The following subgraph is not fusable:
%0 = x # EncryptedTensor<int32, shape=(10, 20)>
%1 = astype(%0, dtype=float64) # EncryptedTensor<float64, shape=(10, 20)>
%2 = reshape(%1, newshape=(20, 10)) # EncryptedTensor<float64, shape=(20, 10)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
%3 = astype(%2, dtype=int32) # EncryptedTensor<int32, shape=(20, 10)>
return %3
""".strip(), # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_explicitely_reshape",
),
pytest.param(
no_fuse_big_constant_3_10_10,
False,
{"x": EncryptedTensor(Integer(32, True), (10, 10))},
"""The following subgraph is not fusable:
%0 = Constant([[[1. 1. 1 ... . 1. 1.]]]) # ClearTensor<Float<64 bits>, shape=(3, 10, 10)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this constant node has a bigger shape (3, 10, 10) than the subgraph's input: (10, 10)
%1 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10, 10)
%2 = astype(float64)(%1) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
%3 = Add(%2, %0) # EncryptedTensor<Float<64 bits>, shape=(3, 10, 10)>
%4 = astype(int32)(%3) # EncryptedTensor<Integer<signed, 32 bits>, shape=(3, 10, 10)>
return(%4)""", # noqa: E501 # pylint: disable=line-too-long
"""
The following subgraph is not fusable:
%0 = [[[1. 1. 1 ... . 1. 1.]]] # ClearTensor<float64, shape=(3, 10, 10)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this constant node has a bigger shape (3, 10, 10) than the subgraph's input: (10, 10)
%1 = x # EncryptedTensor<int32, shape=(10, 10)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10, 10)
%2 = astype(%1, dtype=float64) # EncryptedTensor<float64, shape=(10, 10)>
%3 = add(%2, %0) # EncryptedTensor<float64, shape=(3, 10, 10)>
%4 = astype(%3, dtype=int32) # EncryptedTensor<int32, shape=(3, 10, 10)>
return %4
""".strip(), # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_big_constant_3_10_10",
),
pytest.param(
@@ -322,7 +352,7 @@ def test_fuse_float_operations(
else:
assert fused_num_nodes == orig_num_nodes
captured = capfd.readouterr()
assert warning_message in remove_color_codes(captured.err)
assert warning_message in (output := remove_color_codes(captured.err)), output
for input_ in [0, 2, 42, 44]:
inputs = ()

View File

@@ -17,7 +17,7 @@ def test_circuit_str(default_compilation_configuration):
inputset = [(i,) for i in range(2 ** 3)]
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration)
assert str(circuit) == format_operation_graph(circuit.opgraph, show_data_types=True)
assert str(circuit) == format_operation_graph(circuit.opgraph)
def test_circuit_draw(default_compilation_configuration):

View File

@@ -495,7 +495,7 @@ def test_compile_function_multiple_outputs(
# when we have the converter, we can check the MLIR
draw_graph(op_graph, show=False)
str_of_the_graph = format_operation_graph(op_graph, show_data_types=True)
str_of_the_graph = format_operation_graph(op_graph)
print(f"\n{str_of_the_graph}\n")
@@ -960,7 +960,7 @@ def test_compile_function_with_direct_tlu(default_compilation_configuration):
default_compilation_configuration,
)
str_of_the_graph = format_operation_graph(op_graph, show_data_types=True)
str_of_the_graph = format_operation_graph(op_graph)
print(f"\n{str_of_the_graph}\n")
@@ -991,14 +991,16 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
[(i,) for i in range(8)],
(
"""
function you are trying to compile isn't supported for MLIR lowering
%0 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>
%1 = x # EncryptedScalar<Integer<unsigned, 3 bits>>
%2 = Sub(%0, %1) # EncryptedScalar<Integer<signed, 4 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported
return(%2)
""".lstrip() # noqa: E501
%0 = 1 # ClearScalar<uint1>
%1 = x # EncryptedScalar<uint3>
%2 = sub(%0, %1) # EncryptedScalar<int4>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported
return %2
""".strip() # noqa: E501
),
),
pytest.param(
@@ -1021,16 +1023,18 @@ return(%2)
],
(
"""
function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<Integer<signed, 2 bits>, shape=(1,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported
%1 = y # EncryptedTensor<Integer<signed, 2 bits>, shape=(1,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported
%2 = Dot(%0, %1) # EncryptedScalar<Integer<signed, 4 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer dot product is supported
return(%2)
""".lstrip() # noqa: E501
%0 = x # EncryptedTensor<int2, shape=(1,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported
%1 = y # EncryptedTensor<int2, shape=(1,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported
%2 = dot(%0, %1) # EncryptedScalar<int4>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer dot product is supported
return %2
""".strip() # noqa: E501
),
),
pytest.param(
@@ -1039,14 +1043,16 @@ return(%2)
[(numpy.random.randint(-4, 2 ** 2, size=(2, 2)),) for i in range(10)],
(
"""
function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<Integer<signed, 3 bits>, shape=(2, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported
%1 = IndexConstant(%0[0]) # EncryptedTensor<Integer<signed, 3 bits>, shape=(2,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ indexing is not supported for the time being
return(%1)
""".lstrip() # noqa: E501
%0 = x # EncryptedTensor<int3, shape=(2, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported
%1 = %0[0] # EncryptedTensor<int3, shape=(2,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ indexing is not supported for the time being
return %1
""".strip() # noqa: E501
),
),
pytest.param(
@@ -1055,28 +1061,30 @@ return(%1)
[(numpy.array(i), numpy.array(i)) for i in range(10)],
(
"""
function you are trying to compile isn't supported for MLIR lowering
%0 = Constant(1.5) # ClearScalar<Float<64 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
%1 = x # EncryptedScalar<Integer<unsigned, 4 bits>>
%2 = Constant(2.8) # ClearScalar<Float<64 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
%3 = y # EncryptedScalar<Integer<unsigned, 4 bits>>
%4 = Constant(9.3) # ClearScalar<Float<64 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
%5 = Add(%1, %2) # EncryptedScalar<Float<64 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer addition is supported
%6 = Add(%3, %4) # EncryptedScalar<Float<64 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer addition is supported
%7 = Sub(%5, %6) # EncryptedScalar<Float<64 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer subtraction is supported
%8 = Mul(%7, %0) # EncryptedScalar<Float<64 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer multiplication is supported
%9 = astype(int32)(%8) # EncryptedScalar<Integer<signed, 5 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ astype(int32) is not supported without fusing
return(%9)
""".lstrip() # noqa: E501
%0 = 1.5 # ClearScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
%1 = x # EncryptedScalar<uint4>
%2 = 2.8 # ClearScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
%3 = y # EncryptedScalar<uint4>
%4 = 9.3 # ClearScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
%5 = add(%1, %2) # EncryptedScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer addition is supported
%6 = add(%3, %4) # EncryptedScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer addition is supported
%7 = sub(%5, %6) # EncryptedScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer subtraction is supported
%8 = mul(%7, %0) # EncryptedScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer multiplication is supported
%9 = astype(%8, dtype=int32) # EncryptedScalar<int5>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ astype is not supported without fusing
return %9
""".strip() # noqa: E501
),
),
pytest.param(
@@ -1084,13 +1092,17 @@ return(%9)
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)],
(
"function you are trying to compile isn't supported for MLIR lowering\n"
"\n"
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 2)>\n" # noqa: E501
"%1 = Constant([[1 1 1] [1 1 1]]) # ClearTensor<Integer<unsigned, 1 bits>, shape=(2, 3)>\n" # noqa: E501
"%2 = MatMul(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(3, 3)>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being\n" # noqa: E501
"return(%2)\n"
"""
function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
%1 = [[1 1 1] [1 1 1]] # ClearTensor<uint1, shape=(2, 3)>
%2 = matmul(%0, %1) # EncryptedTensor<uint4, shape=(3, 3)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being
return %2
""".strip() # noqa: E501
),
),
pytest.param(
@@ -1098,13 +1110,15 @@ return(%9)
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)],
(
"function you are trying to compile isn't supported for MLIR lowering\n"
"\n"
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 2)>\n" # noqa: E501
"%1 = Constant([[1 1 1] [1 1 1]]) # ClearTensor<Integer<unsigned, 1 bits>, shape=(2, 3)>\n" # noqa: E501
"%2 = MatMul(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(3, 3)>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being\n" # noqa: E501
"return(%2)\n"
"""
function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
%1 = [[1 1 1] [1 1 1]] # ClearTensor<uint1, shape=(2, 3)>
%2 = matmul(%0, %1) # EncryptedTensor<uint4, shape=(3, 3)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being
return %2
""".strip() # noqa: E501
),
),
pytest.param(
@@ -1112,13 +1126,17 @@ return(%9)
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)],
(
"function you are trying to compile isn't supported for MLIR lowering\n"
"\n"
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 2)>\n" # noqa: E501
"%1 = Constant([[1 1 1] [1 1 1]]) # ClearTensor<Integer<unsigned, 1 bits>, shape=(2, 3)>\n" # noqa: E501
"%2 = MatMul(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(3, 3)>\n" # noqa: E501
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being\n" # noqa: E501
"return(%2)\n"
"""
function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
%1 = [[1 1 1] [1 1 1]] # ClearTensor<uint1, shape=(2, 3)>
%2 = matmul(%0, %1) # EncryptedTensor<uint4, shape=(3, 3)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being
return %2
""".strip() # noqa: E501
),
),
pytest.param(
@@ -1127,13 +1145,15 @@ return(%9)
[(numpy.random.randint(0, 2 ** 2, size=(3, 2)),) for _ in range(32)],
(
"""
function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3, 2)>
%1 = MultiTLU(%0) # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ direct multi table lookup is not supported for the time being
return(%1)
""".lstrip() # noqa: E501
%0 = x # EncryptedTensor<uint2, shape=(3, 2)>
%1 = MultiTLU(%0, input_shape=(3, 2), tables=[[[1, 2, 1 ... 1, 2, 0]]]) # EncryptedTensor<uint2, shape=(3, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ direct multi table lookup is not supported for the time being
return %1
""".strip() # noqa: E501
),
),
pytest.param(
@@ -1141,13 +1161,16 @@ return(%1)
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)],
(
"""function you are trying to compile isn't supported for MLIR lowering
"""
%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 2)>
%1 = np.transpose(%0) # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 3)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ np.transpose of kind Memory is not supported for the time being
return(%1)
""" # noqa: E501
function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
%1 = transpose(%0) # EncryptedTensor<uint3, shape=(2, 3)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ transpose is not supported for the time being
return %1
""".strip() # noqa: E501
),
),
pytest.param(
@@ -1155,13 +1178,16 @@ return(%1)
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)],
(
"""function you are trying to compile isn't supported for MLIR lowering
"""
%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 2)>
%1 = np.ravel(%0) # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(6,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ np.ravel of kind Memory is not supported for the time being
return(%1)
""" # noqa: E501
function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
%1 = ravel(%0) # EncryptedTensor<uint3, shape=(6,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ravel is not supported for the time being
return %1
""".strip() # noqa: E501
),
),
pytest.param(
@@ -1169,13 +1195,16 @@ return(%1)
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 4))},
[(numpy.random.randint(0, 2 ** 3, size=(3, 4)),) for i in range(10)],
(
"""function you are trying to compile isn't supported for MLIR lowering
"""
%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 4)>
%1 = np.reshape(%0) # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 6)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ np.reshape of kind Memory is not supported for the time being
return(%1)
""" # noqa: E501
function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<uint3, shape=(3, 4)>
%1 = reshape(%0, newshape=(2, 6)) # EncryptedTensor<uint3, shape=(2, 6)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ reshape is not supported for the time being
return %1
""".strip() # noqa: E501
),
),
],
@@ -1217,19 +1246,21 @@ def test_fail_with_intermediate_signed_values(default_compilation_configuration)
)
except RuntimeError as error:
match = """
function you are trying to compile isn't supported for MLIR lowering
%0 = y # EncryptedScalar<Integer<unsigned, 2 bits>>
%1 = Constant(10) # ClearScalar<Integer<unsigned, 4 bits>>
%2 = x # EncryptedScalar<Integer<unsigned, 2 bits>>
%3 = np.negative(%2) # EncryptedScalar<Integer<signed, 3 bits>>
%4 = Mul(%3, %1) # EncryptedScalar<Integer<signed, 6 bits>>
%5 = np.absolute(%4) # EncryptedScalar<Integer<unsigned, 5 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ np.absolute is not supported for the time being
%6 = astype(int32)(%5) # EncryptedScalar<Integer<unsigned, 5 bits>>
%7 = Add(%6, %0) # EncryptedScalar<Integer<unsigned, 6 bits>>
return(%7)
""".lstrip() # noqa: E501 # pylint: disable=line-too-long
%0 = y # EncryptedScalar<uint2>
%1 = 10 # ClearScalar<uint4>
%2 = x # EncryptedScalar<uint2>
%3 = negative(%2) # EncryptedScalar<int3>
%4 = mul(%3, %1) # EncryptedScalar<int6>
%5 = absolute(%4) # EncryptedScalar<uint5>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ absolute is not supported for the time being
%6 = astype(%5, dtype=int32) # EncryptedScalar<uint5>
%7 = add(%6, %0) # EncryptedScalar<uint6>
return %7
""".strip() # noqa: E501 # pylint: disable=line-too-long
assert str(error) == match
raise
@@ -1270,13 +1301,14 @@ def test_small_inputset_treat_warnings_as_errors():
(4,),
# Remark that, when you do the dot of tensors of 4 values between 0 and 3,
# you can get a maximal value of 4*3*3 = 36, ie something on 6 bits
"%0 = x "
"# EncryptedTensor<Integer<unsigned, 2 bits>, shape=(4,)>"
"\n%1 = y "
"# EncryptedTensor<Integer<unsigned, 2 bits>, shape=(4,)>"
"\n%2 = Dot(%0, %1) "
"# EncryptedScalar<Integer<unsigned, 6 bits>>"
"\nreturn(%2)\n",
"""
%0 = x # EncryptedTensor<uint2, shape=(4,)>
%1 = y # EncryptedTensor<uint2, shape=(4,)>
%2 = dot(%0, %1) # EncryptedScalar<uint6>
return %2
""".strip(),
),
],
)
@@ -1303,7 +1335,7 @@ def test_compile_function_with_dot(
data_gen(max_for_ij, repeat),
default_compilation_configuration,
)
str_of_the_graph = format_operation_graph(op_graph, show_data_types=True)
str_of_the_graph = format_operation_graph(op_graph)
assert str_of_the_graph == ref_graph_str, (
f"\n==================\nGot \n{str_of_the_graph}"
f"==================\nExpected \n{ref_graph_str}"
@@ -1373,13 +1405,16 @@ def test_compile_too_high_bitwidth(default_compilation_configuration):
assert (
str(excinfo.value)
== """
max_bit_width of some nodes is too high for the current version of the compiler (maximum must be 7) which is not compatible with:
%0 = x # EncryptedScalar<Integer<unsigned, 7 bits>>
%1 = y # EncryptedScalar<Integer<unsigned, 5 bits>>
%2 = Add(%0, %1) # EncryptedScalar<Integer<unsigned, 8 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 8 bits is not supported for the time being
return(%2)
""".lstrip() # noqa: E501 # pylint: disable=line-too-long
%0 = x # EncryptedScalar<uint7>
%1 = y # EncryptedScalar<uint5>
%2 = add(%0, %1) # EncryptedScalar<uint8>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 8 bits is not supported for the time being
return %2
""".strip() # noqa: E501 # pylint: disable=line-too-long
)
# Just ok
@@ -1418,14 +1453,16 @@ def test_failure_for_signed_output(default_compilation_configuration):
assert (
str(excinfo.value)
== """
function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedScalar<Integer<unsigned, 4 bits>>
%1 = Constant(-3) # ClearScalar<Integer<signed, 3 bits>>
%2 = Add(%0, %1) # EncryptedScalar<Integer<signed, 4 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported
return(%2)
""".lstrip() # noqa: E501 # pylint: disable=line-too-long
%0 = x # EncryptedScalar<uint4>
%1 = -3 # ClearScalar<int3>
%2 = add(%0, %1) # EncryptedScalar<int4>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported
return %2
""".strip() # noqa: E501 # pylint: disable=line-too-long
)

View File

@@ -368,8 +368,7 @@ def test_constant_indexing(
EncryptedScalar(UnsignedInteger(1)),
lambda x: x[0],
TypeError,
"Only tensors can be indexed "
"but you tried to index EncryptedScalar<Integer<unsigned, 1 bits>>",
"Only tensors can be indexed but you tried to index EncryptedScalar<uint1>",
),
pytest.param(
EncryptedTensor(UnsignedInteger(1), shape=(3,)),

View File

@@ -1,449 +0,0 @@
"""Test file for debugging functions"""
import numpy
import pytest
from concrete.common.data_types.integers import Integer
from concrete.common.debugging import draw_graph, format_operation_graph
from concrete.common.extensions.table import LookupTable
from concrete.common.values import ClearScalar, EncryptedScalar, EncryptedTensor
from concrete.numpy import tracing
# 2-bit -> 4-bit table: inputs 0..3 map to [9, 2, 4, 11] (max value 11 needs 4 bits)
LOOKUP_TABLE_FROM_2B_TO_4B = LookupTable([9, 2, 4, 11])
# 3-bit -> 2-bit table: inputs 0..7 map to [0, 1, 3, 2, 2, 3, 1, 0] (max value 3 fits 2 bits)
LOOKUP_TABLE_FROM_3B_TO_2B = LookupTable([0, 1, 3, 2, 2, 3, 1, 0])
def issue_130_a(x, y):  # pylint: disable=unused-argument
    """Test case derived from issue #130: returns the same intermediate twice."""
    shared = x + 1
    return (shared, shared)
def issue_130_b(x, y):  # pylint: disable=unused-argument
    """Test case derived from issue #130: returns the same intermediate twice."""
    shared = x - 1
    return (shared, shared)
def issue_130_c(x, y):  # pylint: disable=unused-argument
    """Test case derived from issue #130: returns the same intermediate twice."""
    shared = 1 - x
    return (shared, shared)
# Each case pairs a two-argument function with the exact text that
# format_operation_graph is expected to produce for its traced graph.
@pytest.mark.parametrize(
"lambda_f,ref_graph_str",
[
(lambda x, y: x + y, "%0 = x\n%1 = y\n%2 = Add(%0, %1)\nreturn(%2)\n"),
(lambda x, y: x - y, "%0 = x\n%1 = y\n%2 = Sub(%0, %1)\nreturn(%2)\n"),
(lambda x, y: x + x, "%0 = x\n%1 = Add(%0, %0)\nreturn(%1)\n"),
(
lambda x, y: x + x - y * y * y + x,
"%0 = x\n%1 = y\n%2 = Add(%0, %0)\n%3 = Mul(%1, %1)"
"\n%4 = Mul(%3, %1)\n%5 = Sub(%2, %4)\n%6 = Add(%5, %0)\nreturn(%6)\n",
),
# Cases with python constants on either side of the operator
(lambda x, y: x + 1, "%0 = x\n%1 = Constant(1)\n%2 = Add(%0, %1)\nreturn(%2)\n"),
(lambda x, y: 1 + x, "%0 = x\n%1 = Constant(1)\n%2 = Add(%0, %1)\nreturn(%2)\n"),
(lambda x, y: (-1) + x, "%0 = x\n%1 = Constant(-1)\n%2 = Add(%0, %1)\nreturn(%2)\n"),
(lambda x, y: 3 * x, "%0 = x\n%1 = Constant(3)\n%2 = Mul(%0, %1)\nreturn(%2)\n"),
(lambda x, y: x * 3, "%0 = x\n%1 = Constant(3)\n%2 = Mul(%0, %1)\nreturn(%2)\n"),
(lambda x, y: x * (-3), "%0 = x\n%1 = Constant(-3)\n%2 = Mul(%0, %1)\nreturn(%2)\n"),
(lambda x, y: x - 11, "%0 = x\n%1 = Constant(11)\n%2 = Sub(%0, %1)\nreturn(%2)\n"),
(lambda x, y: 11 - x, "%0 = Constant(11)\n%1 = x\n%2 = Sub(%0, %1)\nreturn(%2)\n"),
(lambda x, y: (-11) - x, "%0 = Constant(-11)\n%1 = x\n%2 = Sub(%0, %1)\nreturn(%2)\n"),
(
lambda x, y: x + 13 - y * (-21) * y + 44,
"%0 = Constant(44)"
"\n%1 = x"
"\n%2 = Constant(13)"
"\n%3 = y"
"\n%4 = Constant(-21)"
"\n%5 = Add(%1, %2)"
"\n%6 = Mul(%3, %4)"
"\n%7 = Mul(%6, %3)"
"\n%8 = Sub(%5, %7)"
"\n%9 = Add(%8, %0)"
"\nreturn(%9)\n",
),
# Multiple outputs
(
lambda x, y: (x + 1, x + y + 2),
"%0 = x"
"\n%1 = Constant(1)"
"\n%2 = Constant(2)"
"\n%3 = y"
"\n%4 = Add(%0, %1)"
"\n%5 = Add(%0, %3)"
"\n%6 = Add(%5, %2)"
"\nreturn(%4, %6)\n",
),
(
lambda x, y: (y, x),
"%0 = y\n%1 = x\nreturn(%0, %1)\n",
),
(
lambda x, y: (x, x + 1),
"%0 = x\n%1 = Constant(1)\n%2 = Add(%0, %1)\nreturn(%0, %2)\n",
),
(
lambda x, y: (x + 1, x + 1),
"%0 = x"
"\n%1 = Constant(1)"
"\n%2 = Constant(1)"
"\n%3 = Add(%0, %1)"
"\n%4 = Add(%0, %2)"
"\nreturn(%3, %4)\n",
),
# The issue #130 helpers return the same node twice: return(%2, %2)
(
issue_130_a,
"%0 = x\n%1 = Constant(1)\n%2 = Add(%0, %1)\nreturn(%2, %2)\n",
),
(
issue_130_b,
"%0 = x\n%1 = Constant(1)\n%2 = Sub(%0, %1)\nreturn(%2, %2)\n",
),
(
issue_130_c,
"%0 = Constant(1)\n%1 = x\n%2 = Sub(%0, %1)\nreturn(%2, %2)\n",
),
# Universal functions are rendered with an "np." prefix
(
lambda x, y: numpy.arctan2(x, 42) + y,
"""%0 = y
%1 = x
%2 = Constant(42)
%3 = np.arctan2(%1, %2)
%4 = Add(%3, %0)
return(%4)
""",
),
(
lambda x, y: numpy.arctan2(43, x) + y,
"""%0 = y
%1 = Constant(43)
%2 = x
%3 = np.arctan2(%1, %2)
%4 = Add(%3, %0)
return(%4)
""",
),
],
)
# Every case above runs for both value combinations; the expected textual
# graph is the same whether `y` is encrypted or clear.
@pytest.mark.parametrize(
"x_y",
[
pytest.param(
(
EncryptedScalar(Integer(64, is_signed=False)),
EncryptedScalar(Integer(64, is_signed=False)),
),
id="Encrypted uint",
),
pytest.param(
(
EncryptedScalar(Integer(64, is_signed=False)),
ClearScalar(Integer(64, is_signed=False)),
),
id="Clear uint",
),
],
)
def test_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
"Test format_operation_graph and draw_graph"
x, y = x_y
graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y})
# draw_graph is exercised for coverage only; nothing is asserted about it.
draw_graph(graph, show=False)
str_of_the_graph = format_operation_graph(graph)
# Exact string comparison; the message shows both renderings on mismatch.
assert str_of_the_graph == ref_graph_str, (
f"\n==================\nGot \n{str_of_the_graph}"
f"==================\nExpected \n{ref_graph_str}"
f"==================\n"
)
@pytest.mark.parametrize(
    "lambda_f,params,ref_graph_str",
    [
        (
            lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x],
            {"x": EncryptedScalar(Integer(2, is_signed=False))},
            "%0 = x\n%1 = TLU(%0)\nreturn(%1)\n",
        ),
        (
            lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4],
            {"x": EncryptedScalar(Integer(2, is_signed=False))},
            "%0 = x\n%1 = Constant(4)\n%2 = Add(%0, %1)\n%3 = TLU(%2)\nreturn(%3)\n",
        ),
    ],
)
def test_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph_str):
    """Check formatting and drawing of graphs containing a direct table lookup."""
    op_graph = tracing.trace_numpy_function(lambda_f, params)
    # Drawing has no asserted output; this only checks it does not raise.
    draw_graph(op_graph, show=False)
    formatted = format_operation_graph(op_graph)
    failure_message = (
        f"\n==================\nGot \n{formatted}"
        f"==================\nExpected \n{ref_graph_str}"
        f"==================\n"
    )
    assert formatted == ref_graph_str, failure_message
@pytest.mark.parametrize(
    "lambda_f,params,ref_graph_str",
    [
        (
            lambda x, y: numpy.dot(x, y),
            {
                "x": EncryptedTensor(Integer(2, is_signed=False), shape=(3,)),
                "y": EncryptedTensor(Integer(2, is_signed=False), shape=(3,)),
            },
            "%0 = x\n%1 = y\n%2 = Dot(%0, %1)\nreturn(%2)\n",
        ),
    ],
)
def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
    """Check formatting and drawing of graphs containing a dot product."""
    op_graph = tracing.trace_numpy_function(lambda_f, params)
    # Drawing has no asserted output; this only checks it does not raise.
    draw_graph(op_graph, show=False)
    formatted = format_operation_graph(op_graph)
    failure_message = (
        f"\n==================\nGot \n{formatted}"
        f"==================\nExpected \n{ref_graph_str}"
        f"==================\n"
    )
    assert formatted == ref_graph_str, failure_message
# pylint: disable=line-too-long
# Cases for generic functions (transpose / ravel / reshape).  The expected strings
# are whitespace-sensitive (comment columns align with the formatter's padding),
# hence the `# noqa: E501` on each long literal — do not reflow them.
@pytest.mark.parametrize(
    "lambda_f,params,ref_graph_str",
    [
        (
            lambda x: numpy.transpose(x),
            {
                "x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5)),
            },
            """
%0 = x                  # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3, 5)>
%1 = np.transpose(%0)   # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(5, 3)>
return(%1)
""".lstrip(),  # noqa: E501
        ),
        (
            lambda x: numpy.ravel(x),
            {
                "x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5)),
            },
            """
%0 = x              # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3, 5)>
%1 = np.ravel(%0)   # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(15,)>
return(%1)
""".lstrip(),  # noqa: E501
        ),
        (
            lambda x: numpy.reshape(x, (5, 3)),
            {
                "x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5)),
            },
            """
%0 = x                # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3, 5)>
%1 = np.reshape(%0)   # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(5, 3)>
return(%1)
""".lstrip(),  # noqa: E501
        ),
        (
            lambda x: numpy.reshape(x, (170,)),
            {
                "x": EncryptedTensor(Integer(2, is_signed=False), shape=(17, 10)),
            },
            """
%0 = x                # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(17, 10)>
%1 = np.reshape(%0)   # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(170,)>
return(%1)
""".lstrip(),  # noqa: E501
        ),
        # Same reshape target passed as a bare int, (170), instead of a 1-tuple —
        # presumably checks both spellings trace identically; verify against tracer.
        (
            lambda x: numpy.reshape(x, (170)),
            {
                "x": EncryptedTensor(Integer(2, is_signed=False), shape=(17, 10)),
            },
            """
%0 = x                # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(17, 10)>
%1 = np.reshape(%0)   # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(170,)>
return(%1)
""".lstrip(),  # noqa: E501
        ),
    ],
)
def test_print_and_draw_graph_with_generic_function(lambda_f, params, ref_graph_str):
    """Check formatting (with data types) and drawing of graphs built from generic functions."""
    op_graph = tracing.trace_numpy_function(lambda_f, params)
    # Drawing has no asserted output; this only checks it does not raise.
    draw_graph(op_graph, show=False)
    formatted = format_operation_graph(op_graph, show_data_types=True)
    failure_message = (
        f"\n==================\nGot \n{formatted}"
        f"==================\nExpected \n{ref_graph_str}"
        f"==================\n"
    )
    assert formatted == ref_graph_str, failure_message


# pylint: enable=line-too-long
# Remark that the bitwidths are not particularly correct (eg, a MUL of a 17b times 23b
# returning 23b), since they are replaced later by the real bitwidths computed on the
# inputset
# NOTE(review): the expected strings are split across adjacent literals; the padding
# inside the quotes aligns the `#` comment column and must not be reflowed.
@pytest.mark.parametrize(
    "lambda_f,x_y,ref_graph_str",
    [
        (
            lambda x, y: x + y,
            (
                EncryptedScalar(Integer(64, is_signed=False)),
                EncryptedScalar(Integer(32, is_signed=True)),
            ),
            "%0 = x             "
            "# EncryptedScalar<Integer<unsigned, 64 bits>>"
            "\n%1 = y            "
            " # EncryptedScalar<Integer<signed, 32 bits>>"
            "\n%2 = Add(%0, %1)  "
            " # EncryptedScalar<Integer<signed, 65 bits>>"
            "\nreturn(%2)\n",
        ),
        (
            lambda x, y: x * y,
            (
                EncryptedScalar(Integer(17, is_signed=False)),
                EncryptedScalar(Integer(23, is_signed=False)),
            ),
            "%0 = x             "
            "# EncryptedScalar<Integer<unsigned, 17 bits>>"
            "\n%1 = y             "
            "# EncryptedScalar<Integer<unsigned, 23 bits>>"
            "\n%2 = Mul(%0, %1)   "
            "# EncryptedScalar<Integer<unsigned, 23 bits>>"
            "\nreturn(%2)\n",
        ),
    ],
)
def test_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
    """Check format_operation_graph output when data types are requested."""
    x_value, y_value = x_y
    op_graph = tracing.trace_numpy_function(lambda_f, {"x": x_value, "y": y_value})
    formatted = format_operation_graph(op_graph, show_data_types=True)
    failure_message = (
        f"\n==================\nGot \n{formatted}"
        f"==================\nExpected \n{ref_graph_str}"
        f"==================\n"
    )
    assert formatted == ref_graph_str, failure_message
# NOTE(review): as in the previous parametrization, the expected strings are built
# from adjacent literals whose in-quote padding aligns the `#` column — keep as-is.
@pytest.mark.parametrize(
    "lambda_f,params,ref_graph_str",
    [
        # Single direct table lookup.
        (
            lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x],
            {"x": EncryptedScalar(Integer(2, is_signed=False))},
            "%0 = x         "
            "# EncryptedScalar<Integer<unsigned, 2 bits>>"
            "\n%1 = TLU(%0)   "
            "# EncryptedScalar<Integer<unsigned, 4 bits>>"
            "\nreturn(%1)\n",
        ),
        # Lookup applied to an arithmetic result.
        (
            lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4],
            {"x": EncryptedScalar(Integer(2, is_signed=False))},
            "%0 = x             "
            "# EncryptedScalar<Integer<unsigned, 2 bits>>"
            "\n%1 = Constant(4)   "
            "# ClearScalar<Integer<unsigned, 3 bits>>"
            "\n%2 = Add(%0, %1)   "
            "# EncryptedScalar<Integer<unsigned, 3 bits>>"
            "\n%3 = TLU(%2)       "
            "# EncryptedScalar<Integer<unsigned, 2 bits>>"
            "\nreturn(%3)\n",
        ),
        # Two chained lookups.
        (
            lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[LOOKUP_TABLE_FROM_3B_TO_2B[x + 4]],
            {"x": EncryptedScalar(Integer(2, is_signed=False))},
            "%0 = x             "
            "# EncryptedScalar<Integer<unsigned, 2 bits>>"
            "\n%1 = Constant(4)   "
            "# ClearScalar<Integer<unsigned, 3 bits>>"
            "\n%2 = Add(%0, %1)   "
            "# EncryptedScalar<Integer<unsigned, 3 bits>>"
            "\n%3 = TLU(%2)       "
            "# EncryptedScalar<Integer<unsigned, 2 bits>>"
            "\n%4 = TLU(%3)       "
            "# EncryptedScalar<Integer<unsigned, 4 bits>>"
            "\nreturn(%4)\n",
        ),
    ],
)
def test_print_with_show_data_types_with_direct_tlu(lambda_f, params, ref_graph_str):
    """Check formatting with data types on graphs containing direct table lookups."""
    op_graph = tracing.trace_numpy_function(lambda_f, params)
    # Drawing has no asserted output; this only checks it does not raise.
    draw_graph(op_graph, show=False)
    formatted = format_operation_graph(op_graph, show_data_types=True)
    failure_message = (
        f"\n==================\nGot \n{formatted}"
        f"==================\nExpected \n{ref_graph_str}"
        f"==================\n"
    )
    assert formatted == ref_graph_str, failure_message
def test_numpy_long_constant():
    """Check that large tensor constants appear abbreviated (with `...`) in the formatted graph."""

    def all_explicit_operations(x):
        # Mix of add/sub and arctan2 with 10x10 / 1x10 constants; the 100-element
        # constants force the formatter's abbreviated numpy repr in the output.
        intermediate = numpy.add(x, numpy.arange(100).reshape(10, 10))
        intermediate = numpy.subtract(intermediate, numpy.arange(10).reshape(1, 10))
        intermediate = numpy.arctan2(numpy.arange(10, 20).reshape(1, 10), intermediate)
        intermediate = numpy.arctan2(numpy.arange(100, 200).reshape(10, 10), intermediate)
        return intermediate

    op_graph = tracing.trace_numpy_function(
        all_explicit_operations, {"x": EncryptedTensor(Integer(32, True), shape=(10, 10))}
    )

    # NOTE(review): whitespace inside this literal is significant (comment-column
    # alignment and numpy's element padding) — do not reflow.
    expected = """
%0 = Constant([[100 101 ... 198 199]])    # ClearTensor<Integer<unsigned, 8 bits>, shape=(10, 10)>
%1 = Constant([[10 11 12 ... 17 18 19]])  # ClearTensor<Integer<unsigned, 5 bits>, shape=(1, 10)>
%2 = Constant([[0 1 2 3 4 5 6 7 8 9]])    # ClearTensor<Integer<unsigned, 4 bits>, shape=(1, 10)>
%3 = x                                    # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
%4 = Constant([[ 0  1  2 ... 97 98 99]])  # ClearTensor<Integer<unsigned, 7 bits>, shape=(10, 10)>
%5 = Add(%3, %4)                          # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
%6 = Sub(%5, %2)                          # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
%7 = np.arctan2(%1, %6)                   # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
%8 = np.arctan2(%0, %7)                   # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
return(%8)
""".lstrip()  # noqa: E501

    str_of_the_graph = format_operation_graph(op_graph, show_data_types=True)
    assert str_of_the_graph == expected, (
        f"\n==================\nGot \n{str_of_the_graph}"
        f"==================\nExpected \n{expected}"
        f"==================\n"
    )

View File

@@ -188,24 +188,22 @@ def test_numpy_tracing_tensors():
all_operations, {"x": EncryptedTensor(Integer(32, True), shape=(2, 2))}
)
expected = """
%0 = Constant([[2 1] [1 2]]) # ClearTensor<Integer<unsigned, 2 bits>, shape=(2, 2)>
%1 = Constant([[1 2] [2 1]]) # ClearTensor<Integer<unsigned, 2 bits>, shape=(2, 2)>
%2 = Constant([[10 20] [30 40]]) # ClearTensor<Integer<unsigned, 6 bits>, shape=(2, 2)>
%3 = Constant([[100 200] [300 400]]) # ClearTensor<Integer<unsigned, 9 bits>, shape=(2, 2)>
%4 = Constant([[5 6] [7 8]]) # ClearTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>
%5 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
%6 = Constant([[1 2] [3 4]]) # ClearTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
%7 = Add(%5, %6) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
%8 = Add(%4, %7) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
%9 = Sub(%3, %8) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
%10 = Sub(%9, %2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
%11 = Mul(%10, %1) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
%12 = Mul(%0, %11) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
return(%12)
""".lstrip() # noqa: E501
expected = """ %0 = [[2 1] [1 2]] # ClearTensor<uint2, shape=(2, 2)>
%1 = [[1 2] [2 1]] # ClearTensor<uint2, shape=(2, 2)>
%2 = [[10 20] [30 40]] # ClearTensor<uint6, shape=(2, 2)>
%3 = [[100 200] [300 400]] # ClearTensor<uint9, shape=(2, 2)>
%4 = [[5 6] [7 8]] # ClearTensor<uint4, shape=(2, 2)>
%5 = x # EncryptedTensor<int32, shape=(2, 2)>
%6 = [[1 2] [3 4]] # ClearTensor<uint3, shape=(2, 2)>
%7 = add(%5, %6) # EncryptedTensor<int32, shape=(2, 2)>
%8 = add(%4, %7) # EncryptedTensor<int32, shape=(2, 2)>
%9 = sub(%3, %8) # EncryptedTensor<int32, shape=(2, 2)>
%10 = sub(%9, %2) # EncryptedTensor<int32, shape=(2, 2)>
%11 = mul(%10, %1) # EncryptedTensor<int32, shape=(2, 2)>
%12 = mul(%0, %11) # EncryptedTensor<int32, shape=(2, 2)>
return %12""" # noqa: E501
assert format_operation_graph(op_graph, show_data_types=True) == expected
assert format_operation_graph(op_graph) == expected, format_operation_graph(op_graph)
def test_numpy_explicit_tracing_tensors():
@@ -227,24 +225,22 @@ def test_numpy_explicit_tracing_tensors():
all_explicit_operations, {"x": EncryptedTensor(Integer(32, True), shape=(2, 2))}
)
expected = """
%0 = Constant([[2 1] [1 2]]) # ClearTensor<Integer<unsigned, 2 bits>, shape=(2, 2)>
%1 = Constant([[1 2] [2 1]]) # ClearTensor<Integer<unsigned, 2 bits>, shape=(2, 2)>
%2 = Constant([[10 20] [30 40]]) # ClearTensor<Integer<unsigned, 6 bits>, shape=(2, 2)>
%3 = Constant([[100 200] [300 400]]) # ClearTensor<Integer<unsigned, 9 bits>, shape=(2, 2)>
%4 = Constant([[5 6] [7 8]]) # ClearTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>
%5 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
%6 = Constant([[1 2] [3 4]]) # ClearTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
%7 = Add(%5, %6) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
%8 = Add(%4, %7) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
%9 = Sub(%3, %8) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
%10 = Sub(%9, %2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
%11 = Mul(%10, %1) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
%12 = Mul(%0, %11) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
return(%12)
""".lstrip() # noqa: E501
expected = """ %0 = [[2 1] [1 2]] # ClearTensor<uint2, shape=(2, 2)>
%1 = [[1 2] [2 1]] # ClearTensor<uint2, shape=(2, 2)>
%2 = [[10 20] [30 40]] # ClearTensor<uint6, shape=(2, 2)>
%3 = [[100 200] [300 400]] # ClearTensor<uint9, shape=(2, 2)>
%4 = [[5 6] [7 8]] # ClearTensor<uint4, shape=(2, 2)>
%5 = x # EncryptedTensor<int32, shape=(2, 2)>
%6 = [[1 2] [3 4]] # ClearTensor<uint3, shape=(2, 2)>
%7 = add(%5, %6) # EncryptedTensor<int32, shape=(2, 2)>
%8 = add(%4, %7) # EncryptedTensor<int32, shape=(2, 2)>
%9 = sub(%3, %8) # EncryptedTensor<int32, shape=(2, 2)>
%10 = sub(%9, %2) # EncryptedTensor<int32, shape=(2, 2)>
%11 = mul(%10, %1) # EncryptedTensor<int32, shape=(2, 2)>
%12 = mul(%0, %11) # EncryptedTensor<int32, shape=(2, 2)>
return %12""" # noqa: E501
assert format_operation_graph(op_graph, show_data_types=True) == expected
assert format_operation_graph(op_graph) == expected
@pytest.mark.parametrize(