mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
refactor(debugging): re-write graph formatting
This commit is contained in:
@@ -85,7 +85,7 @@ class CompilationArtifacts:
|
||||
"""
|
||||
|
||||
drawing = draw_graph(operation_graph)
|
||||
textual_representation = format_operation_graph(operation_graph, show_data_types=True)
|
||||
textual_representation = format_operation_graph(operation_graph)
|
||||
|
||||
self.drawings_of_operation_graphs[name] = drawing
|
||||
self.textual_representations_of_operation_graphs[name] = textual_representation
|
||||
|
||||
@@ -21,6 +21,9 @@ class Float(base.BaseDataType):
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}<{self.bit_width} bits>"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"float{self.bit_width}"
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, self.__class__) and self.bit_width == other.bit_width
|
||||
|
||||
|
||||
@@ -23,6 +23,9 @@ class Integer(base.BaseDataType):
|
||||
signed_str = "signed" if self.is_signed else "unsigned"
|
||||
return f"{self.__class__.__name__}<{signed_str}, {self.bit_width} bits>"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{('int' if self.is_signed else 'uint')}{self.bit_width}"
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Module for debugging."""
|
||||
from .custom_assert import assert_true
|
||||
from .drawing import draw_graph
|
||||
from .printing import format_operation_graph
|
||||
from .formatting import format_operation_graph
|
||||
|
||||
@@ -84,7 +84,7 @@ def draw_graph(
|
||||
|
||||
attributes = {
|
||||
node: {
|
||||
"label": node.label(),
|
||||
"label": node.text_for_drawing(),
|
||||
"color": get_color(node, output_nodes),
|
||||
"penwidth": 2, # double thickness for circles
|
||||
"peripheries": 2 if node in output_nodes else 1, # double circle for output nodes
|
||||
|
||||
124
concrete/common/debugging/formatting.py
Normal file
124
concrete/common/debugging/formatting.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""Functions to format operation graphs for debugging purposes."""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import IntermediateNode
|
||||
|
||||
|
||||
def format_operation_graph(
|
||||
opgraph: OPGraph,
|
||||
maximum_constant_length: int = 25,
|
||||
highlighted_nodes: Optional[Dict[IntermediateNode, List[str]]] = None,
|
||||
) -> str:
|
||||
"""Format an operation graph.
|
||||
|
||||
Args:
|
||||
opgraph (OPGraph):
|
||||
the operation graph to format
|
||||
|
||||
maximum_constant_length (int):
|
||||
maximum length of the constant throughout the formatting
|
||||
|
||||
highlighted_nodes (Optional[Dict[IntermediateNode, List[str]]] = None):
|
||||
the dict of nodes and their corresponding messages which will be highlighted
|
||||
|
||||
Returns:
|
||||
str: formatted operation graph
|
||||
"""
|
||||
assert_true(isinstance(opgraph, OPGraph))
|
||||
|
||||
# (node, output_index) -> identifier
|
||||
# e.g., id_map[(node1, 0)] = 2 and id_map[(node1, 1)] = 3
|
||||
# means line for node1 is in this form (%2, %3) = node1.format(...)
|
||||
id_map: Dict[Tuple[IntermediateNode, int], int] = {}
|
||||
|
||||
# lines that will be merged at the end
|
||||
lines: List[str] = []
|
||||
|
||||
# type information to add to each line (for alingment, this is done after lines are determined)
|
||||
type_informations: List[str] = []
|
||||
|
||||
# default highlighted nodes is empty
|
||||
highlighted_nodes = highlighted_nodes if highlighted_nodes is not None else {}
|
||||
|
||||
# highlight information for lines, this is required because highlights are added to lines
|
||||
# after their type information is added and we only have line numbers, not nodes
|
||||
highlighted_lines: Dict[int, List[str]] = {}
|
||||
|
||||
for node in nx.topological_sort(opgraph.graph):
|
||||
# assign a unique id to outputs of node
|
||||
assert_true(len(node.outputs) > 0)
|
||||
for i in range(len(node.outputs)):
|
||||
id_map[(node, i)] = len(id_map)
|
||||
|
||||
# remember highlights of the node
|
||||
if node in highlighted_nodes:
|
||||
highlighted_lines[len(lines)] = highlighted_nodes[node]
|
||||
|
||||
# extract predecessors and their ids
|
||||
predecessors = []
|
||||
for predecessor, output_idx in opgraph.get_ordered_inputs_of(node):
|
||||
predecessors.append(f"%{id_map[(predecessor, output_idx)]}")
|
||||
|
||||
# start the build the line for the node
|
||||
line = ""
|
||||
|
||||
# add output information to the line
|
||||
outputs = ", ".join(f"%{id_map[(node, i)]}" for i in range(len(node.outputs)))
|
||||
line += outputs if len(node.outputs) == 1 else f"({outputs})"
|
||||
|
||||
# add node information to the line
|
||||
line += " = "
|
||||
line += node.text_for_formatting(predecessors, maximum_constant_length)
|
||||
|
||||
# append line to list of lines
|
||||
lines.append(line)
|
||||
|
||||
# remember type information of the node
|
||||
types = ", ".join(str(output) for output in node.outputs)
|
||||
type_informations.append(types if len(node.outputs) == 1 else f"({types})")
|
||||
|
||||
# align = signs
|
||||
#
|
||||
# e.g.,
|
||||
#
|
||||
# %1 = ...
|
||||
# %2 = ...
|
||||
# ...
|
||||
# %8 = ...
|
||||
# %9 = ...
|
||||
# %10 = ...
|
||||
# %11 = ...
|
||||
# ...
|
||||
longest_length_before_equals_sign = max(len(line.split("=")[0]) for line in lines)
|
||||
for i, line in enumerate(lines):
|
||||
length_before_equals_sign = len(line.split("=")[0])
|
||||
lines[i] = (" " * (longest_length_before_equals_sign - length_before_equals_sign)) + line
|
||||
|
||||
# add type informations
|
||||
longest_line_length = max(len(line) for line in lines)
|
||||
for i, line in enumerate(lines):
|
||||
lines[i] += " " * (longest_line_length - len(line))
|
||||
lines[i] += f" # {type_informations[i]}"
|
||||
|
||||
# add highlights (this is done in reverse to keep indices consistent)
|
||||
for i in reversed(range(len(lines))):
|
||||
if i in highlighted_lines:
|
||||
for j, message in enumerate(highlighted_lines[i]):
|
||||
highlight = "^" if j == 0 else " "
|
||||
lines.insert(i + 1 + j, f"{highlight * len(lines[i])} {message}")
|
||||
|
||||
# add return information
|
||||
# (if there is a single return, it's in the form `return %id`
|
||||
# (otherwise, it's in the form `return (%id1, %id2, ..., %idN)`
|
||||
returns: List[str] = []
|
||||
for node in opgraph.output_nodes.values():
|
||||
outputs = ", ".join(f"%{id_map[(node, i)]}" for i in range(len(node.outputs)))
|
||||
returns.append(outputs if len(node.outputs) == 1 else f"({outputs})")
|
||||
lines.append("return " + returns[0] if len(returns) == 1 else f"({', '.join(returns)})")
|
||||
|
||||
return "\n".join(lines)
|
||||
@@ -1,146 +0,0 @@
|
||||
"""functions to print the different graphs we can generate in the package, eg to debug."""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import (
|
||||
Constant,
|
||||
GenericFunction,
|
||||
IndexConstant,
|
||||
Input,
|
||||
IntermediateNode,
|
||||
)
|
||||
|
||||
|
||||
def output_data_type_to_string(node):
|
||||
"""Return the datatypes of the outputs of the node.
|
||||
|
||||
Args:
|
||||
node: a graph node
|
||||
|
||||
Returns:
|
||||
str: a string representing the datatypes of the outputs of the node
|
||||
|
||||
"""
|
||||
return ", ".join([str(o) for o in node.outputs])
|
||||
|
||||
|
||||
def shorten_a_constant(constant_data: str):
|
||||
"""Return a constant (if small) or an extra of the constant (if too large).
|
||||
|
||||
Args:
|
||||
constant (str): The constant we want to shorten
|
||||
|
||||
Returns:
|
||||
str: a string to represent the constant
|
||||
"""
|
||||
|
||||
content = str(constant_data).replace("\n", "")
|
||||
# if content is longer than 25 chars, only show the first and the last 10 chars of it
|
||||
# 25 is selected using the spaces available before data type information
|
||||
short_content = f"{content[:10]} ... {content[-10:]}" if len(content) > 25 else content
|
||||
return short_content
|
||||
|
||||
|
||||
def format_operation_graph(
|
||||
opgraph: OPGraph,
|
||||
show_data_types: bool = False,
|
||||
highlighted_nodes: Optional[Dict[IntermediateNode, List[str]]] = None,
|
||||
) -> str:
|
||||
"""Return a string representing a graph.
|
||||
|
||||
Args:
|
||||
opgraph (OPGraph): The graph that we want to draw
|
||||
show_data_types (bool, optional): Whether or not showing data_types of nodes, eg to see
|
||||
their width. Defaults to False.
|
||||
highlighted_nodes (Optional[Dict[IntermediateNode, List[str]]], optional): The dict of nodes
|
||||
which will be highlighted and their corresponding messages. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: a string to print or save in a file
|
||||
"""
|
||||
assert_true(isinstance(opgraph, OPGraph))
|
||||
|
||||
highlighted_nodes = highlighted_nodes if highlighted_nodes is not None else {}
|
||||
|
||||
list_of_nodes_which_are_outputs = list(opgraph.output_nodes.values())
|
||||
graph = opgraph.graph
|
||||
|
||||
returned_str = ""
|
||||
|
||||
i = 0
|
||||
map_table: Dict[Any, int] = {}
|
||||
|
||||
for node in nx.topological_sort(graph):
|
||||
|
||||
# TODO: #640
|
||||
# This code doesn't work with more than a single output. For more outputs,
|
||||
# we would need to change the way the destination are created: currently,
|
||||
# they only are done by incrementing i
|
||||
assert_true(len(node.outputs) == 1)
|
||||
|
||||
if isinstance(node, Input):
|
||||
what_to_print = node.input_name
|
||||
elif isinstance(node, Constant):
|
||||
to_show = shorten_a_constant(node.constant_data)
|
||||
what_to_print = f"Constant({to_show})"
|
||||
else:
|
||||
|
||||
base_name = node.__class__.__name__
|
||||
|
||||
if isinstance(node, GenericFunction):
|
||||
base_name = node.op_name
|
||||
|
||||
what_to_print = base_name + "("
|
||||
|
||||
# Find all the names of the current predecessors of the node
|
||||
list_of_arg_name = []
|
||||
|
||||
for pred, index_list in graph.pred[node].items():
|
||||
for index in index_list.values():
|
||||
# Remark that we keep the index of the predecessor and its
|
||||
# name, to print sources in the right order, which is
|
||||
# important for eg non commutative operations
|
||||
list_of_arg_name += [(index["input_idx"], str(map_table[pred]))]
|
||||
|
||||
# Some checks, because the previous algorithm is not clear
|
||||
assert_true(len(list_of_arg_name) == len(set(x[0] for x in list_of_arg_name)))
|
||||
list_of_arg_name.sort()
|
||||
assert_true([x[0] for x in list_of_arg_name] == list(range(len(list_of_arg_name))))
|
||||
|
||||
prefix_to_add_to_what_to_print = ""
|
||||
suffix_to_add_to_what_to_print = ""
|
||||
|
||||
# Then, just print the predecessors in the right order
|
||||
what_to_print += prefix_to_add_to_what_to_print
|
||||
what_to_print += ", ".join(["%" + x[1] for x in list_of_arg_name])
|
||||
what_to_print += suffix_to_add_to_what_to_print
|
||||
what_to_print += (
|
||||
f"{node.label().replace('value', '')}" if isinstance(node, IndexConstant) else ""
|
||||
)
|
||||
what_to_print += ")"
|
||||
|
||||
# This code doesn't work with more than a single output
|
||||
new_line = f"%{i} = {what_to_print}"
|
||||
|
||||
# Manage datatypes
|
||||
if show_data_types:
|
||||
new_line = f"{new_line: <50s} # {output_data_type_to_string(node)}"
|
||||
|
||||
returned_str += f"{new_line}\n"
|
||||
|
||||
if node in highlighted_nodes:
|
||||
new_line_len = len(new_line)
|
||||
message = f"\n{' ' * new_line_len} ".join(highlighted_nodes[node])
|
||||
returned_str += f"{'^' * new_line_len} {message}\n"
|
||||
|
||||
map_table[node] = i
|
||||
i += 1
|
||||
|
||||
return_part = ", ".join(["%" + str(map_table[n]) for n in list_of_nodes_which_are_outputs])
|
||||
returned_str += f"return({return_part})\n"
|
||||
|
||||
return returned_str
|
||||
@@ -30,6 +30,9 @@ class LookupTable:
|
||||
self.table = table
|
||||
self.output_dtype = make_integer_to_hold(table, force_signed=False)
|
||||
|
||||
def __repr__(self):
|
||||
return str(list(self.table))
|
||||
|
||||
def __getitem__(self, key: Union[int, Iterable, BaseTracer]):
|
||||
# if a tracer is used for indexing,
|
||||
# we need to create an `GenericFunction` node
|
||||
|
||||
@@ -21,7 +21,7 @@ class FHECircuit:
|
||||
self.engine = engine
|
||||
|
||||
def __str__(self):
|
||||
return format_operation_graph(self.opgraph, show_data_types=True)
|
||||
return format_operation_graph(self.opgraph)
|
||||
|
||||
def draw(
|
||||
self,
|
||||
|
||||
47
concrete/common/helpers/formatting_helpers.py
Normal file
47
concrete/common/helpers/formatting_helpers.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Helpers for formatting functionality."""
|
||||
|
||||
from typing import Any, Dict, Hashable
|
||||
|
||||
import numpy
|
||||
|
||||
from ..debugging.custom_assert import assert_true
|
||||
|
||||
SPECIAL_OBJECT_MAPPING: Dict[Any, str] = {
|
||||
numpy.float32: "float32",
|
||||
numpy.float64: "float64",
|
||||
numpy.int8: "int8",
|
||||
numpy.int16: "int16",
|
||||
numpy.int32: "int32",
|
||||
numpy.int64: "int64",
|
||||
numpy.uint8: "uint8",
|
||||
numpy.uint16: "uint16",
|
||||
numpy.uint32: "uint32",
|
||||
numpy.uint64: "uint64",
|
||||
}
|
||||
|
||||
|
||||
def format_constant(constant: Any, maximum_length: int = 45) -> str:
|
||||
"""Format a constant.
|
||||
|
||||
Args:
|
||||
constant (Any): the constant to format
|
||||
maximum_length (int): maximum length of the resulting string
|
||||
|
||||
Returns:
|
||||
str: the formatted constant
|
||||
"""
|
||||
|
||||
if isinstance(constant, Hashable) and constant in SPECIAL_OBJECT_MAPPING:
|
||||
return SPECIAL_OBJECT_MAPPING[constant]
|
||||
|
||||
# maximum_length should not be smaller than 7 characters because
|
||||
# the constant will be formatted to `x ... y`
|
||||
# where x and y are part of the constant and they are at least 1 character
|
||||
assert_true(maximum_length >= 7)
|
||||
|
||||
content = str(constant).replace("\n", "")
|
||||
if len(content) > maximum_length:
|
||||
from_start = (maximum_length - 5) // 2
|
||||
from_end = (maximum_length - 5) - from_start
|
||||
content = f"{content[:from_start]} ... {content[-from_end:]}"
|
||||
return content
|
||||
@@ -96,9 +96,7 @@ def check_node_compatibility_with_mlir(
|
||||
# e.g., `np.absolute is not supported for the time being`
|
||||
return f"{node.op_name} is not supported for the time being"
|
||||
else:
|
||||
return (
|
||||
f"{node.op_name} of kind {node.op_kind.value} is not supported for the time being"
|
||||
)
|
||||
return f"{node.op_name} is not supported for the time being"
|
||||
|
||||
elif isinstance(node, intermediate.Dot): # constraints for dot product
|
||||
assert_true(len(inputs) == 2)
|
||||
@@ -206,10 +204,8 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
|
||||
raise RuntimeError(
|
||||
f"max_bit_width of some nodes is too high for the current version of "
|
||||
f"the compiler (maximum must be {ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB}) "
|
||||
f"which is not compatible with:\n"
|
||||
+ format_operation_graph(
|
||||
op_graph, show_data_types=True, highlighted_nodes=offending_nodes
|
||||
)
|
||||
f"which is not compatible with:\n\n"
|
||||
+ format_operation_graph(op_graph, highlighted_nodes=offending_nodes)
|
||||
)
|
||||
|
||||
_set_all_bit_width(op_graph, max_bit_width)
|
||||
|
||||
@@ -131,6 +131,24 @@ class OPGraph:
|
||||
idx_to_pred.update((data["input_idx"], pred) for data in edge_data.values())
|
||||
return [idx_to_pred[i] for i in range(len(idx_to_pred))]
|
||||
|
||||
def get_ordered_inputs_of(self, node: IntermediateNode) -> List[Tuple[IntermediateNode, int]]:
|
||||
"""Get node inputs ordered by their indices.
|
||||
|
||||
Args:
|
||||
node (IntermediateNode): the node for which we want the ordered inputs
|
||||
|
||||
Returns:
|
||||
List[Tuple[IntermediateNode, int]]: the ordered list of inputs
|
||||
"""
|
||||
|
||||
idx_to_inp: Dict[int, Tuple[IntermediateNode, int]] = {}
|
||||
for pred in self.graph.pred[node]:
|
||||
edge_data = self.graph.get_edge_data(pred, node)
|
||||
idx_to_inp.update(
|
||||
(data["input_idx"], (pred, data["output_idx"])) for data in edge_data.values()
|
||||
)
|
||||
return [idx_to_inp[i] for i in range(len(idx_to_inp))]
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Dict[IntermediateNode, Any]:
|
||||
"""Evaluate a graph and get intermediate values for all nodes.
|
||||
|
||||
|
||||
@@ -133,10 +133,9 @@ def convert_float_subgraph_to_fused_node(
|
||||
|
||||
printable_graph = format_operation_graph(
|
||||
float_subgraph_as_op_graph,
|
||||
show_data_types=True,
|
||||
highlighted_nodes=node_with_issues_for_fusing,
|
||||
)
|
||||
message = f"The following subgraph is not fusable:\n{printable_graph}"
|
||||
message = f"The following subgraph is not fusable:\n\n{printable_graph}"
|
||||
logger.warning(message)
|
||||
return None
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from ..data_types.dtypes_helpers import (
|
||||
from ..data_types.integers import Integer
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from ..helpers import indexing_helpers
|
||||
from ..helpers.formatting_helpers import format_constant
|
||||
from ..helpers.python_helpers import catch, update_and_return_dict
|
||||
from ..values import (
|
||||
BaseValue,
|
||||
@@ -64,6 +65,27 @@ class IntermediateNode(ABC):
|
||||
|
||||
self.outputs = [mix_values_func(self.inputs[0], self.inputs[1])]
|
||||
|
||||
def text_for_formatting(self, predecessors: List[str], _maximum_constant_length: int) -> str:
|
||||
"""Get the formatted node (used in formatting opgraph).
|
||||
|
||||
Args:
|
||||
predecessors (List[str]): predecessor names to this node
|
||||
_maximum_constant_length (int): desired maximum constant length
|
||||
|
||||
Returns:
|
||||
str: the formatted node
|
||||
"""
|
||||
|
||||
return f"{self.__class__.__name__.lower()}({', '.join(predecessors)})"
|
||||
|
||||
@abstractmethod
|
||||
def text_for_drawing(self) -> str:
|
||||
"""Get the label of the node (used in drawing opgraph).
|
||||
|
||||
Returns:
|
||||
str: the label of the node
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
"""Simulate what the represented computation would output for the given inputs.
|
||||
@@ -93,15 +115,6 @@ class IntermediateNode(ABC):
|
||||
"""
|
||||
return cls.n_in() > 1
|
||||
|
||||
@abstractmethod
|
||||
def label(self) -> str:
|
||||
"""Get the label of the node.
|
||||
|
||||
Returns:
|
||||
str: the label of the node
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class Add(IntermediateNode):
|
||||
"""Addition between two values."""
|
||||
@@ -110,12 +123,12 @@ class Add(IntermediateNode):
|
||||
|
||||
__init__ = IntermediateNode._init_binary
|
||||
|
||||
def text_for_drawing(self) -> str:
|
||||
return "+"
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0] + inputs[1]
|
||||
|
||||
def label(self) -> str:
|
||||
return "+"
|
||||
|
||||
|
||||
class Sub(IntermediateNode):
|
||||
"""Subtraction between two values."""
|
||||
@@ -124,12 +137,12 @@ class Sub(IntermediateNode):
|
||||
|
||||
__init__ = IntermediateNode._init_binary
|
||||
|
||||
def text_for_drawing(self) -> str:
|
||||
return "-"
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0] - inputs[1]
|
||||
|
||||
def label(self) -> str:
|
||||
return "-"
|
||||
|
||||
|
||||
class Mul(IntermediateNode):
|
||||
"""Multiplication between two values."""
|
||||
@@ -138,12 +151,12 @@ class Mul(IntermediateNode):
|
||||
|
||||
__init__ = IntermediateNode._init_binary
|
||||
|
||||
def text_for_drawing(self) -> str:
|
||||
return "*"
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0] * inputs[1]
|
||||
|
||||
def label(self) -> str:
|
||||
return "*"
|
||||
|
||||
|
||||
class Input(IntermediateNode):
|
||||
"""Node representing an input of the program."""
|
||||
@@ -164,12 +177,16 @@ class Input(IntermediateNode):
|
||||
self.program_input_idx = program_input_idx
|
||||
self.outputs = [deepcopy(self.inputs[0])]
|
||||
|
||||
def text_for_formatting(self, predecessors: List[str], _maximum_constant_length: int) -> str:
|
||||
assert_true(len(predecessors) == 0)
|
||||
return self.input_name
|
||||
|
||||
def text_for_drawing(self) -> str:
|
||||
return self.input_name
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0]
|
||||
|
||||
def label(self) -> str:
|
||||
return self.input_name
|
||||
|
||||
|
||||
class Constant(IntermediateNode):
|
||||
"""Node representing a constant of the program."""
|
||||
@@ -191,6 +208,13 @@ class Constant(IntermediateNode):
|
||||
self._constant_data = constant_data
|
||||
self.outputs = [base_value_class(is_encrypted=False)]
|
||||
|
||||
def text_for_formatting(self, predecessors: List[str], maximum_constant_length: int) -> str:
|
||||
assert_true(len(predecessors) == 0)
|
||||
return format_constant(self.constant_data, maximum_constant_length)
|
||||
|
||||
def text_for_drawing(self) -> str:
|
||||
return format_constant(self.constant_data)
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return self.constant_data
|
||||
|
||||
@@ -203,9 +227,6 @@ class Constant(IntermediateNode):
|
||||
"""
|
||||
return self._constant_data
|
||||
|
||||
def label(self) -> str:
|
||||
return str(self.constant_data)
|
||||
|
||||
|
||||
class IndexConstant(IntermediateNode):
|
||||
"""Node representing a constant indexing in the program.
|
||||
@@ -242,19 +263,17 @@ class IndexConstant(IntermediateNode):
|
||||
else ClearTensor(output_dtype, output_shape)
|
||||
]
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0][self.index]
|
||||
|
||||
def label(self) -> str:
|
||||
"""Label of the node to show during drawings.
|
||||
|
||||
It can be used for some other places after `"value"` below is replaced by `""`.
|
||||
This note will no longer be necessary after #707 is addressed.
|
||||
|
||||
"""
|
||||
def text_for_formatting(self, predecessors: List[str], _maximum_constant_length: int) -> str:
|
||||
assert_true(len(predecessors) == 1)
|
||||
elements = [indexing_helpers.format_indexing_element(element) for element in self.index]
|
||||
index = ", ".join(elements)
|
||||
return f"value[{index}]"
|
||||
return f"{predecessors[0]}[{index}]"
|
||||
|
||||
def text_for_drawing(self) -> str:
|
||||
return self.text_for_formatting(["value"], 0) # 0 is unused
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0][self.index]
|
||||
|
||||
|
||||
def flood_replace_none_values(table: list):
|
||||
@@ -335,15 +354,26 @@ class GenericFunction(IntermediateNode):
|
||||
|
||||
self.op_name = op_name if op_name is not None else self.__class__.__name__
|
||||
|
||||
def text_for_formatting(self, predecessors: List[str], maximum_constant_length: int) -> str:
|
||||
all_args = deepcopy(predecessors)
|
||||
|
||||
all_args.extend(format_constant(value, maximum_constant_length) for value in self.op_args)
|
||||
all_args.extend(
|
||||
f"{name}={format_constant(value, maximum_constant_length)}"
|
||||
for name, value in self.op_kwargs.items()
|
||||
)
|
||||
|
||||
return f"{self.op_name}({', '.join(all_args)})"
|
||||
|
||||
def text_for_drawing(self) -> str:
|
||||
return self.op_name
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
# This is the continuation of the mypy bug workaround
|
||||
assert self.arbitrary_func is not None
|
||||
ordered_inputs = [inputs[idx] for idx in range(len(inputs))]
|
||||
return self.arbitrary_func(*ordered_inputs, *self.op_args, **self.op_kwargs)
|
||||
|
||||
def label(self) -> str:
|
||||
return self.op_name
|
||||
|
||||
def get_table(self, ordered_preds: List[IntermediateNode]) -> List[Any]:
|
||||
"""Get the table for the current input value of this GenericFunction.
|
||||
|
||||
@@ -466,14 +496,14 @@ class Dot(IntermediateNode):
|
||||
self.outputs = [output_scalar_value(output_dtype)]
|
||||
self.evaluation_function = delegate_evaluation_function
|
||||
|
||||
def text_for_drawing(self) -> str:
|
||||
return "dot"
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
# This is the continuation of the mypy bug workaround
|
||||
assert self.evaluation_function is not None
|
||||
return self.evaluation_function(inputs[0], inputs[1])
|
||||
|
||||
def label(self) -> str:
|
||||
return "dot"
|
||||
|
||||
|
||||
class MatMul(IntermediateNode):
|
||||
"""Return the node representing a matrix multiplication."""
|
||||
@@ -513,8 +543,8 @@ class MatMul(IntermediateNode):
|
||||
|
||||
self.outputs = [output_value]
|
||||
|
||||
def text_for_drawing(self) -> str:
|
||||
return "matmul"
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0] @ inputs[1]
|
||||
|
||||
def label(self) -> str:
|
||||
return "@"
|
||||
|
||||
@@ -262,9 +262,7 @@ def prepare_op_graph_for_mlir(op_graph):
|
||||
if offending_nodes is not None:
|
||||
raise RuntimeError(
|
||||
"function you are trying to compile isn't supported for MLIR lowering\n\n"
|
||||
+ format_operation_graph(
|
||||
op_graph, show_data_types=True, highlighted_nodes=offending_nodes
|
||||
)
|
||||
+ format_operation_graph(op_graph, highlighted_nodes=offending_nodes)
|
||||
)
|
||||
|
||||
# Update bit_width for MLIR
|
||||
|
||||
@@ -104,7 +104,7 @@ class NPTracer(BaseTracer):
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="TLU",
|
||||
op_kwargs={"dtype": normalized_numpy_dtype.type},
|
||||
op_name=f"astype({normalized_numpy_dtype})",
|
||||
op_name="astype",
|
||||
)
|
||||
output_tracer = self.__class__([self], traced_computation=traced_computation, output_idx=0)
|
||||
return output_tracer
|
||||
@@ -320,7 +320,7 @@ class NPTracer(BaseTracer):
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="Memory",
|
||||
op_kwargs=deepcopy(kwargs),
|
||||
op_name="np.transpose",
|
||||
op_name="transpose",
|
||||
op_attributes={"fusable": transpose_is_fusable},
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
@@ -367,7 +367,7 @@ class NPTracer(BaseTracer):
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="Memory",
|
||||
op_kwargs=deepcopy(kwargs),
|
||||
op_name="np.ravel",
|
||||
op_name="ravel",
|
||||
op_attributes={"fusable": ravel_is_fusable},
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
@@ -405,12 +405,9 @@ class NPTracer(BaseTracer):
|
||||
assert_true(isinstance(first_arg_output, TensorValue))
|
||||
first_arg_output = cast(TensorValue, first_arg_output)
|
||||
|
||||
newshape = deepcopy(arg1)
|
||||
|
||||
if isinstance(newshape, int):
|
||||
# Make numpy.reshape(x, (170)) and numpy.reshape(x, 170) work, while classical form is
|
||||
# numpy.reshape(x, (170,))
|
||||
newshape = (newshape,)
|
||||
# Make numpy.reshape(x, (170)) and numpy.reshape(x, 170) work,
|
||||
# while classical form is numpy.reshape(x, (170,))
|
||||
newshape = deepcopy(arg1) if not isinstance(arg1, int) else (arg1,)
|
||||
|
||||
# Check shape compatibility
|
||||
assert_true(
|
||||
@@ -435,7 +432,7 @@ class NPTracer(BaseTracer):
|
||||
output_value=generic_function_output_value,
|
||||
op_kind="Memory",
|
||||
op_kwargs={"newshape": newshape},
|
||||
op_name="np.reshape",
|
||||
op_name="reshape",
|
||||
op_attributes={"fusable": reshape_is_fusable},
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
@@ -575,7 +572,7 @@ def _get_unary_fun(function: numpy.ufunc):
|
||||
# dynamically
|
||||
# pylint: disable=protected-access
|
||||
return lambda *input_tracers, **kwargs: NPTracer._unary_operator(
|
||||
function, f"np.{function.__name__}", *input_tracers, **kwargs
|
||||
function, f"{function.__name__}", *input_tracers, **kwargs
|
||||
)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
@@ -587,7 +584,7 @@ def _get_binary_fun(function: numpy.ufunc):
|
||||
# dynamically
|
||||
# pylint: disable=protected-access
|
||||
return lambda *input_tracers, **kwargs: NPTracer._binary_operator(
|
||||
function, f"np.{function.__name__}", *input_tracers, **kwargs
|
||||
function, f"{function.__name__}", *input_tracers, **kwargs
|
||||
)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
@@ -328,11 +328,11 @@ def test_eval_op_graph_bounds_on_non_conformant_inputset_default(capsys):
|
||||
captured = capsys.readouterr()
|
||||
assert (
|
||||
captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `x` "
|
||||
"but got EncryptedTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
|
||||
"(expected EncryptedTensor<uint2, shape=(3,)> for parameter `x` "
|
||||
"but got EncryptedTensor<uint2, shape=(4,)> which is not compatible)\n"
|
||||
"Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected ClearTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `y` "
|
||||
"but got ClearTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
|
||||
"(expected ClearTensor<uint2, shape=(3,)> for parameter `y` "
|
||||
"but got ClearTensor<uint2, shape=(4,)> which is not compatible)\n"
|
||||
)
|
||||
|
||||
|
||||
@@ -365,14 +365,14 @@ def test_eval_op_graph_bounds_on_non_conformant_inputset_check_all(capsys):
|
||||
captured = capsys.readouterr()
|
||||
assert (
|
||||
captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `x` "
|
||||
"but got EncryptedTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
|
||||
"(expected EncryptedTensor<uint2, shape=(3,)> for parameter `x` "
|
||||
"but got EncryptedTensor<uint2, shape=(4,)> which is not compatible)\n"
|
||||
"Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected ClearTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `y` "
|
||||
"but got ClearTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
|
||||
"(expected ClearTensor<uint2, shape=(3,)> for parameter `y` "
|
||||
"but got ClearTensor<uint2, shape=(4,)> which is not compatible)\n"
|
||||
"Warning: Input #1 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected ClearTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `y` "
|
||||
"but got ClearTensor<Integer<unsigned, 3 bits>, shape=(3,)> which is not compatible)\n"
|
||||
"(expected ClearTensor<uint2, shape=(3,)> for parameter `y` "
|
||||
"but got ClearTensor<uint3, shape=(3,)> which is not compatible)\n"
|
||||
)
|
||||
|
||||
|
||||
@@ -436,14 +436,14 @@ def test_eval_op_graph_bounds_on_non_conformant_numpy_inputset_check_all(capsys)
|
||||
captured = capsys.readouterr()
|
||||
assert (
|
||||
captured.err == "Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `x` "
|
||||
"but got EncryptedTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
|
||||
"(expected EncryptedTensor<uint2, shape=(3,)> for parameter `x` "
|
||||
"but got EncryptedTensor<uint2, shape=(4,)> which is not compatible)\n"
|
||||
"Warning: Input #0 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected ClearTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `y` "
|
||||
"but got ClearTensor<Integer<unsigned, 2 bits>, shape=(4,)> which is not compatible)\n"
|
||||
"(expected ClearTensor<uint2, shape=(3,)> for parameter `y` "
|
||||
"but got ClearTensor<uint2, shape=(4,)> which is not compatible)\n"
|
||||
"Warning: Input #1 (0-indexed) is not coherent with the hinted parameters "
|
||||
"(expected ClearTensor<Integer<unsigned, 2 bits>, shape=(3,)> for parameter `y` "
|
||||
"but got ClearTensor<Integer<unsigned, 3 bits>, shape=(3,)> which is not compatible)\n"
|
||||
"(expected ClearTensor<uint2, shape=(3,)> for parameter `y` "
|
||||
"but got ClearTensor<uint3, shape=(3,)> which is not compatible)\n"
|
||||
)
|
||||
|
||||
|
||||
|
||||
78
tests/common/debugging/test_formatting.py
Normal file
78
tests/common/debugging/test_formatting.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Test file for formatting"""
|
||||
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import format_operation_graph
|
||||
from concrete.common.values import EncryptedScalar
|
||||
from concrete.numpy.compile import compile_numpy_function_into_op_graph
|
||||
|
||||
|
||||
def test_format_operation_graph_with_multiple_edges(default_compilation_configuration):
|
||||
"""Test format_operation_graph with multiple edges"""
|
||||
|
||||
def function(x):
|
||||
return x + x
|
||||
|
||||
opgraph = compile_numpy_function_into_op_graph(
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(4, True))},
|
||||
[(i,) for i in range(0, 10)],
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
formatted_graph = format_operation_graph(opgraph)
|
||||
assert (
|
||||
formatted_graph
|
||||
== """
|
||||
|
||||
%0 = x # EncryptedScalar<uint4>
|
||||
%1 = add(%0, %0) # EncryptedScalar<uint5>
|
||||
return %1
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def test_format_operation_graph_with_offending_nodes(default_compilation_configuration):
|
||||
"""Test format_operation_graph with offending nodes"""
|
||||
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
opgraph = compile_numpy_function_into_op_graph(
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(7, True))},
|
||||
[(i,) for i in range(-5, 5)],
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
highlighted_nodes = {opgraph.input_nodes[0]: ["foo"]}
|
||||
formatted_graph = format_operation_graph(opgraph, highlighted_nodes=highlighted_nodes).strip()
|
||||
assert (
|
||||
formatted_graph
|
||||
== """
|
||||
|
||||
%0 = x # EncryptedScalar<int4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo
|
||||
%1 = 42 # ClearScalar<uint6>
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint6>
|
||||
return %2
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
highlighted_nodes = {opgraph.input_nodes[0]: ["foo"], opgraph.output_nodes[0]: ["bar", "baz"]}
|
||||
formatted_graph = format_operation_graph(opgraph, highlighted_nodes=highlighted_nodes).strip()
|
||||
assert (
|
||||
formatted_graph
|
||||
== """
|
||||
|
||||
%0 = x # EncryptedScalar<int4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo
|
||||
%1 = 42 # ClearScalar<uint6>
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint6>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ bar
|
||||
baz
|
||||
return %2
|
||||
|
||||
""".strip()
|
||||
)
|
||||
@@ -1,94 +0,0 @@
|
||||
"""Test file for printing"""
|
||||
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import format_operation_graph
|
||||
from concrete.common.values import EncryptedScalar
|
||||
from concrete.numpy.compile import compile_numpy_function_into_op_graph
|
||||
|
||||
|
||||
def test_format_operation_graph_with_offending_nodes(default_compilation_configuration):
|
||||
"""Test format_operation_graph with offending nodes"""
|
||||
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
opgraph = compile_numpy_function_into_op_graph(
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(7, True))},
|
||||
[(i,) for i in range(-5, 5)],
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
highlighted_nodes = {opgraph.input_nodes[0]: ["foo"]}
|
||||
|
||||
without_types = format_operation_graph(
|
||||
opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes
|
||||
).strip()
|
||||
with_types = format_operation_graph(
|
||||
opgraph, show_data_types=True, highlighted_nodes=highlighted_nodes
|
||||
).strip()
|
||||
|
||||
assert (
|
||||
without_types
|
||||
== """
|
||||
|
||||
%0 = x
|
||||
^^^^^^ foo
|
||||
%1 = Constant(42)
|
||||
%2 = Add(%0, %1)
|
||||
return(%2)
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
assert (
|
||||
with_types
|
||||
== """
|
||||
|
||||
%0 = x # EncryptedScalar<Integer<signed, 4 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo
|
||||
%1 = Constant(42) # ClearScalar<Integer<unsigned, 6 bits>>
|
||||
%2 = Add(%0, %1) # EncryptedScalar<Integer<unsigned, 6 bits>>
|
||||
return(%2)
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
highlighted_nodes = {opgraph.input_nodes[0]: ["foo"], opgraph.output_nodes[0]: ["bar", "baz"]}
|
||||
|
||||
without_types = format_operation_graph(
|
||||
opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes
|
||||
).strip()
|
||||
with_types = format_operation_graph(
|
||||
opgraph, show_data_types=True, highlighted_nodes=highlighted_nodes
|
||||
).strip()
|
||||
|
||||
assert (
|
||||
without_types
|
||||
== """
|
||||
|
||||
%0 = x
|
||||
^^^^^^ foo
|
||||
%1 = Constant(42)
|
||||
%2 = Add(%0, %1)
|
||||
^^^^^^^^^^^^^^^^ bar
|
||||
baz
|
||||
return(%2)
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
assert (
|
||||
with_types
|
||||
== """
|
||||
|
||||
%0 = x # EncryptedScalar<Integer<signed, 4 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo
|
||||
%1 = Constant(42) # ClearScalar<Integer<unsigned, 6 bits>>
|
||||
%2 = Add(%0, %1) # EncryptedScalar<Integer<unsigned, 6 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ bar
|
||||
baz
|
||||
return(%2)
|
||||
|
||||
""".strip()
|
||||
)
|
||||
@@ -157,88 +157,118 @@ def get_func_params_int32(func, scalar=True):
|
||||
no_fuse_unhandled,
|
||||
False,
|
||||
get_func_params_int32(no_fuse_unhandled),
|
||||
"""The following subgraph is not fusable:
|
||||
%0 = x # EncryptedScalar<Integer<signed, 32 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing)
|
||||
%1 = Constant(0.7) # ClearScalar<Float<64 bits>>
|
||||
%2 = y # EncryptedScalar<Integer<signed, 32 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing)
|
||||
%3 = Constant(1.3) # ClearScalar<Float<64 bits>>
|
||||
%4 = Add(%0, %1) # EncryptedScalar<Float<64 bits>>
|
||||
%5 = Add(%2, %3) # EncryptedScalar<Float<64 bits>>
|
||||
%6 = Add(%4, %5) # EncryptedScalar<Float<64 bits>>
|
||||
%7 = astype(int32)(%6) # EncryptedScalar<Integer<signed, 32 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ cannot fuse here as the subgraph has 2 variable inputs
|
||||
return(%7)""", # noqa: E501 # pylint: disable=line-too-long
|
||||
"""
|
||||
|
||||
The following subgraph is not fusable:
|
||||
|
||||
%0 = x # EncryptedScalar<int32>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing)
|
||||
%1 = 0.7 # ClearScalar<float64>
|
||||
%2 = y # EncryptedScalar<int32>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of 2 variable inputs (can only have 1 for fusing)
|
||||
%3 = 1.3 # ClearScalar<float64>
|
||||
%4 = add(%0, %1) # EncryptedScalar<float64>
|
||||
%5 = add(%2, %3) # EncryptedScalar<float64>
|
||||
%6 = add(%4, %5) # EncryptedScalar<float64>
|
||||
%7 = astype(%6, dtype=int32) # EncryptedScalar<int32>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ cannot fuse here as the subgraph has 2 variable inputs
|
||||
return %7
|
||||
|
||||
""".strip(), # noqa: E501 # pylint: disable=line-too-long
|
||||
id="no_fuse_unhandled",
|
||||
),
|
||||
pytest.param(
|
||||
no_fuse_dot,
|
||||
False,
|
||||
{"x": EncryptedTensor(Integer(32, True), (10,))},
|
||||
"""The following subgraph is not fusable:
|
||||
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10,)
|
||||
%1 = Constant([1.33 1.33 ... 1.33 1.33]) # ClearTensor<Float<64 bits>, shape=(10,)>
|
||||
%2 = Dot(%0, %1) # EncryptedScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,)
|
||||
%3 = astype(int32)(%2) # EncryptedScalar<Integer<signed, 32 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,)
|
||||
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
|
||||
"""
|
||||
|
||||
The following subgraph is not fusable:
|
||||
|
||||
%0 = x # EncryptedTensor<int32, shape=(10,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10,)
|
||||
%1 = [1.33 1.33 ... 1.33 1.33] # ClearTensor<float64, shape=(10,)>
|
||||
%2 = dot(%0, %1) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,)
|
||||
%3 = astype(%2, dtype=int32) # EncryptedScalar<int32>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ output shapes: #0, () are not the same as the subgraph's input: (10,)
|
||||
return %3
|
||||
|
||||
""".strip(), # noqa: E501 # pylint: disable=line-too-long
|
||||
id="no_fuse_dot",
|
||||
),
|
||||
pytest.param(
|
||||
ravel_cases,
|
||||
False,
|
||||
{"x": EncryptedTensor(Integer(32, True), (10, 20))},
|
||||
"""The following subgraph is not fusable:
|
||||
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 20)>
|
||||
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(10, 20)>
|
||||
%2 = np.ravel(%1) # EncryptedTensor<Float<64 bits>, shape=(200,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
|
||||
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(200,)>
|
||||
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
|
||||
"""
|
||||
|
||||
The following subgraph is not fusable:
|
||||
|
||||
%0 = x # EncryptedTensor<int32, shape=(10, 20)>
|
||||
%1 = astype(%0, dtype=float64) # EncryptedTensor<float64, shape=(10, 20)>
|
||||
%2 = ravel(%1) # EncryptedTensor<float64, shape=(200,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
|
||||
%3 = astype(%2, dtype=int32) # EncryptedTensor<int32, shape=(200,)>
|
||||
return %3
|
||||
|
||||
""".strip(), # noqa: E501 # pylint: disable=line-too-long
|
||||
id="no_fuse_explicitely_ravel",
|
||||
),
|
||||
pytest.param(
|
||||
transpose_cases,
|
||||
False,
|
||||
{"x": EncryptedTensor(Integer(32, True), (10, 20))},
|
||||
"""The following subgraph is not fusable:
|
||||
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 20)>
|
||||
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(10, 20)>
|
||||
%2 = np.transpose(%1) # EncryptedTensor<Float<64 bits>, shape=(20, 10)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
|
||||
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(20, 10)>
|
||||
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
|
||||
"""
|
||||
|
||||
The following subgraph is not fusable:
|
||||
|
||||
%0 = x # EncryptedTensor<int32, shape=(10, 20)>
|
||||
%1 = astype(%0, dtype=float64) # EncryptedTensor<float64, shape=(10, 20)>
|
||||
%2 = transpose(%1) # EncryptedTensor<float64, shape=(20, 10)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
|
||||
%3 = astype(%2, dtype=int32) # EncryptedTensor<int32, shape=(20, 10)>
|
||||
return %3
|
||||
|
||||
""".strip(), # noqa: E501 # pylint: disable=line-too-long
|
||||
id="no_fuse_explicitely_transpose",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: reshape_cases(x, (20, 10)),
|
||||
False,
|
||||
{"x": EncryptedTensor(Integer(32, True), (10, 20))},
|
||||
"""The following subgraph is not fusable:
|
||||
%0 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 20)>
|
||||
%1 = astype(float64)(%0) # EncryptedTensor<Float<64 bits>, shape=(10, 20)>
|
||||
%2 = np.reshape(%1) # EncryptedTensor<Float<64 bits>, shape=(20, 10)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
|
||||
%3 = astype(int32)(%2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(20, 10)>
|
||||
return(%3)""", # noqa: E501 # pylint: disable=line-too-long
|
||||
"""
|
||||
|
||||
The following subgraph is not fusable:
|
||||
|
||||
%0 = x # EncryptedTensor<int32, shape=(10, 20)>
|
||||
%1 = astype(%0, dtype=float64) # EncryptedTensor<float64, shape=(10, 20)>
|
||||
%2 = reshape(%1, newshape=(20, 10)) # EncryptedTensor<float64, shape=(20, 10)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this node is explicitely marked by the package as non-fusable
|
||||
%3 = astype(%2, dtype=int32) # EncryptedTensor<int32, shape=(20, 10)>
|
||||
return %3
|
||||
|
||||
""".strip(), # noqa: E501 # pylint: disable=line-too-long
|
||||
id="no_fuse_explicitely_reshape",
|
||||
),
|
||||
pytest.param(
|
||||
no_fuse_big_constant_3_10_10,
|
||||
False,
|
||||
{"x": EncryptedTensor(Integer(32, True), (10, 10))},
|
||||
"""The following subgraph is not fusable:
|
||||
%0 = Constant([[[1. 1. 1 ... . 1. 1.]]]) # ClearTensor<Float<64 bits>, shape=(3, 10, 10)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this constant node has a bigger shape (3, 10, 10) than the subgraph's input: (10, 10)
|
||||
%1 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10, 10)
|
||||
%2 = astype(float64)(%1) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
|
||||
%3 = Add(%2, %0) # EncryptedTensor<Float<64 bits>, shape=(3, 10, 10)>
|
||||
%4 = astype(int32)(%3) # EncryptedTensor<Integer<signed, 32 bits>, shape=(3, 10, 10)>
|
||||
return(%4)""", # noqa: E501 # pylint: disable=line-too-long
|
||||
"""
|
||||
|
||||
The following subgraph is not fusable:
|
||||
|
||||
%0 = [[[1. 1. 1 ... . 1. 1.]]] # ClearTensor<float64, shape=(3, 10, 10)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this constant node has a bigger shape (3, 10, 10) than the subgraph's input: (10, 10)
|
||||
%1 = x # EncryptedTensor<int32, shape=(10, 10)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input node with shape (10, 10)
|
||||
%2 = astype(%1, dtype=float64) # EncryptedTensor<float64, shape=(10, 10)>
|
||||
%3 = add(%2, %0) # EncryptedTensor<float64, shape=(3, 10, 10)>
|
||||
%4 = astype(%3, dtype=int32) # EncryptedTensor<int32, shape=(3, 10, 10)>
|
||||
return %4
|
||||
|
||||
""".strip(), # noqa: E501 # pylint: disable=line-too-long
|
||||
id="no_fuse_big_constant_3_10_10",
|
||||
),
|
||||
pytest.param(
|
||||
@@ -322,7 +352,7 @@ def test_fuse_float_operations(
|
||||
else:
|
||||
assert fused_num_nodes == orig_num_nodes
|
||||
captured = capfd.readouterr()
|
||||
assert warning_message in remove_color_codes(captured.err)
|
||||
assert warning_message in (output := remove_color_codes(captured.err)), output
|
||||
|
||||
for input_ in [0, 2, 42, 44]:
|
||||
inputs = ()
|
||||
|
||||
@@ -17,7 +17,7 @@ def test_circuit_str(default_compilation_configuration):
|
||||
inputset = [(i,) for i in range(2 ** 3)]
|
||||
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration)
|
||||
|
||||
assert str(circuit) == format_operation_graph(circuit.opgraph, show_data_types=True)
|
||||
assert str(circuit) == format_operation_graph(circuit.opgraph)
|
||||
|
||||
|
||||
def test_circuit_draw(default_compilation_configuration):
|
||||
|
||||
@@ -495,7 +495,7 @@ def test_compile_function_multiple_outputs(
|
||||
# when we have the converter, we can check the MLIR
|
||||
draw_graph(op_graph, show=False)
|
||||
|
||||
str_of_the_graph = format_operation_graph(op_graph, show_data_types=True)
|
||||
str_of_the_graph = format_operation_graph(op_graph)
|
||||
print(f"\n{str_of_the_graph}\n")
|
||||
|
||||
|
||||
@@ -960,7 +960,7 @@ def test_compile_function_with_direct_tlu(default_compilation_configuration):
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
str_of_the_graph = format_operation_graph(op_graph, show_data_types=True)
|
||||
str_of_the_graph = format_operation_graph(op_graph)
|
||||
print(f"\n{str_of_the_graph}\n")
|
||||
|
||||
|
||||
@@ -991,14 +991,16 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
|
||||
[(i,) for i in range(8)],
|
||||
(
|
||||
"""
|
||||
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>
|
||||
%1 = x # EncryptedScalar<Integer<unsigned, 3 bits>>
|
||||
%2 = Sub(%0, %1) # EncryptedScalar<Integer<signed, 4 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported
|
||||
return(%2)
|
||||
""".lstrip() # noqa: E501
|
||||
%0 = 1 # ClearScalar<uint1>
|
||||
%1 = x # EncryptedScalar<uint3>
|
||||
%2 = sub(%0, %1) # EncryptedScalar<int4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported
|
||||
return %2
|
||||
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -1021,16 +1023,18 @@ return(%2)
|
||||
],
|
||||
(
|
||||
"""
|
||||
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<Integer<signed, 2 bits>, shape=(1,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported
|
||||
%1 = y # EncryptedTensor<Integer<signed, 2 bits>, shape=(1,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported
|
||||
%2 = Dot(%0, %1) # EncryptedScalar<Integer<signed, 4 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer dot product is supported
|
||||
return(%2)
|
||||
""".lstrip() # noqa: E501
|
||||
%0 = x # EncryptedTensor<int2, shape=(1,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported
|
||||
%1 = y # EncryptedTensor<int2, shape=(1,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported
|
||||
%2 = dot(%0, %1) # EncryptedScalar<int4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer dot product is supported
|
||||
return %2
|
||||
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -1039,14 +1043,16 @@ return(%2)
|
||||
[(numpy.random.randint(-4, 2 ** 2, size=(2, 2)),) for i in range(10)],
|
||||
(
|
||||
"""
|
||||
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<Integer<signed, 3 bits>, shape=(2, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported
|
||||
%1 = IndexConstant(%0[0]) # EncryptedTensor<Integer<signed, 3 bits>, shape=(2,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ indexing is not supported for the time being
|
||||
return(%1)
|
||||
""".lstrip() # noqa: E501
|
||||
%0 = x # EncryptedTensor<int3, shape=(2, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported
|
||||
%1 = %0[0] # EncryptedTensor<int3, shape=(2,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ indexing is not supported for the time being
|
||||
return %1
|
||||
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -1055,28 +1061,30 @@ return(%1)
|
||||
[(numpy.array(i), numpy.array(i)) for i in range(10)],
|
||||
(
|
||||
"""
|
||||
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = Constant(1.5) # ClearScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
|
||||
%1 = x # EncryptedScalar<Integer<unsigned, 4 bits>>
|
||||
%2 = Constant(2.8) # ClearScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
|
||||
%3 = y # EncryptedScalar<Integer<unsigned, 4 bits>>
|
||||
%4 = Constant(9.3) # ClearScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
|
||||
%5 = Add(%1, %2) # EncryptedScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer addition is supported
|
||||
%6 = Add(%3, %4) # EncryptedScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer addition is supported
|
||||
%7 = Sub(%5, %6) # EncryptedScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer subtraction is supported
|
||||
%8 = Mul(%7, %0) # EncryptedScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer multiplication is supported
|
||||
%9 = astype(int32)(%8) # EncryptedScalar<Integer<signed, 5 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ astype(int32) is not supported without fusing
|
||||
return(%9)
|
||||
""".lstrip() # noqa: E501
|
||||
%0 = 1.5 # ClearScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
|
||||
%1 = x # EncryptedScalar<uint4>
|
||||
%2 = 2.8 # ClearScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
|
||||
%3 = y # EncryptedScalar<uint4>
|
||||
%4 = 9.3 # ClearScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
|
||||
%5 = add(%1, %2) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer addition is supported
|
||||
%6 = add(%3, %4) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer addition is supported
|
||||
%7 = sub(%5, %6) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer subtraction is supported
|
||||
%8 = mul(%7, %0) # EncryptedScalar<float64>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer multiplication is supported
|
||||
%9 = astype(%8, dtype=int32) # EncryptedScalar<int5>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ astype is not supported without fusing
|
||||
return %9
|
||||
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -1084,13 +1092,17 @@ return(%9)
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)],
|
||||
(
|
||||
"function you are trying to compile isn't supported for MLIR lowering\n"
|
||||
"\n"
|
||||
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 2)>\n" # noqa: E501
|
||||
"%1 = Constant([[1 1 1] [1 1 1]]) # ClearTensor<Integer<unsigned, 1 bits>, shape=(2, 3)>\n" # noqa: E501
|
||||
"%2 = MatMul(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(3, 3)>\n" # noqa: E501
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being\n" # noqa: E501
|
||||
"return(%2)\n"
|
||||
"""
|
||||
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
|
||||
%1 = [[1 1 1] [1 1 1]] # ClearTensor<uint1, shape=(2, 3)>
|
||||
%2 = matmul(%0, %1) # EncryptedTensor<uint4, shape=(3, 3)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being
|
||||
return %2
|
||||
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -1098,13 +1110,15 @@ return(%9)
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)],
|
||||
(
|
||||
"function you are trying to compile isn't supported for MLIR lowering\n"
|
||||
"\n"
|
||||
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 2)>\n" # noqa: E501
|
||||
"%1 = Constant([[1 1 1] [1 1 1]]) # ClearTensor<Integer<unsigned, 1 bits>, shape=(2, 3)>\n" # noqa: E501
|
||||
"%2 = MatMul(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(3, 3)>\n" # noqa: E501
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being\n" # noqa: E501
|
||||
"return(%2)\n"
|
||||
"""
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
|
||||
%1 = [[1 1 1] [1 1 1]] # ClearTensor<uint1, shape=(2, 3)>
|
||||
%2 = matmul(%0, %1) # EncryptedTensor<uint4, shape=(3, 3)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being
|
||||
return %2
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -1112,13 +1126,17 @@ return(%9)
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)],
|
||||
(
|
||||
"function you are trying to compile isn't supported for MLIR lowering\n"
|
||||
"\n"
|
||||
"%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 2)>\n" # noqa: E501
|
||||
"%1 = Constant([[1 1 1] [1 1 1]]) # ClearTensor<Integer<unsigned, 1 bits>, shape=(2, 3)>\n" # noqa: E501
|
||||
"%2 = MatMul(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(3, 3)>\n" # noqa: E501
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being\n" # noqa: E501
|
||||
"return(%2)\n"
|
||||
"""
|
||||
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
|
||||
%1 = [[1 1 1] [1 1 1]] # ClearTensor<uint1, shape=(2, 3)>
|
||||
%2 = matmul(%0, %1) # EncryptedTensor<uint4, shape=(3, 3)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being
|
||||
return %2
|
||||
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -1127,13 +1145,15 @@ return(%9)
|
||||
[(numpy.random.randint(0, 2 ** 2, size=(3, 2)),) for _ in range(32)],
|
||||
(
|
||||
"""
|
||||
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3, 2)>
|
||||
%1 = MultiTLU(%0) # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ direct multi table lookup is not supported for the time being
|
||||
return(%1)
|
||||
""".lstrip() # noqa: E501
|
||||
%0 = x # EncryptedTensor<uint2, shape=(3, 2)>
|
||||
%1 = MultiTLU(%0, input_shape=(3, 2), tables=[[[1, 2, 1 ... 1, 2, 0]]]) # EncryptedTensor<uint2, shape=(3, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ direct multi table lookup is not supported for the time being
|
||||
return %1
|
||||
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -1141,13 +1161,16 @@ return(%1)
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)],
|
||||
(
|
||||
"""function you are trying to compile isn't supported for MLIR lowering
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 2)>
|
||||
%1 = np.transpose(%0) # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 3)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ np.transpose of kind Memory is not supported for the time being
|
||||
return(%1)
|
||||
""" # noqa: E501
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
|
||||
%1 = transpose(%0) # EncryptedTensor<uint3, shape=(2, 3)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ transpose is not supported for the time being
|
||||
return %1
|
||||
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -1155,13 +1178,16 @@ return(%1)
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)],
|
||||
(
|
||||
"""function you are trying to compile isn't supported for MLIR lowering
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 2)>
|
||||
%1 = np.ravel(%0) # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(6,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ np.ravel of kind Memory is not supported for the time being
|
||||
return(%1)
|
||||
""" # noqa: E501
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
|
||||
%1 = ravel(%0) # EncryptedTensor<uint3, shape=(6,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ravel is not supported for the time being
|
||||
return %1
|
||||
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
@@ -1169,13 +1195,16 @@ return(%1)
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 4))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 4)),) for i in range(10)],
|
||||
(
|
||||
"""function you are trying to compile isn't supported for MLIR lowering
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(3, 4)>
|
||||
%1 = np.reshape(%0) # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 6)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ np.reshape of kind Memory is not supported for the time being
|
||||
return(%1)
|
||||
""" # noqa: E501
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<uint3, shape=(3, 4)>
|
||||
%1 = reshape(%0, newshape=(2, 6)) # EncryptedTensor<uint3, shape=(2, 6)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ reshape is not supported for the time being
|
||||
return %1
|
||||
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
],
|
||||
@@ -1217,19 +1246,21 @@ def test_fail_with_intermediate_signed_values(default_compilation_configuration)
|
||||
)
|
||||
except RuntimeError as error:
|
||||
match = """
|
||||
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = y # EncryptedScalar<Integer<unsigned, 2 bits>>
|
||||
%1 = Constant(10) # ClearScalar<Integer<unsigned, 4 bits>>
|
||||
%2 = x # EncryptedScalar<Integer<unsigned, 2 bits>>
|
||||
%3 = np.negative(%2) # EncryptedScalar<Integer<signed, 3 bits>>
|
||||
%4 = Mul(%3, %1) # EncryptedScalar<Integer<signed, 6 bits>>
|
||||
%5 = np.absolute(%4) # EncryptedScalar<Integer<unsigned, 5 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ np.absolute is not supported for the time being
|
||||
%6 = astype(int32)(%5) # EncryptedScalar<Integer<unsigned, 5 bits>>
|
||||
%7 = Add(%6, %0) # EncryptedScalar<Integer<unsigned, 6 bits>>
|
||||
return(%7)
|
||||
""".lstrip() # noqa: E501 # pylint: disable=line-too-long
|
||||
%0 = y # EncryptedScalar<uint2>
|
||||
%1 = 10 # ClearScalar<uint4>
|
||||
%2 = x # EncryptedScalar<uint2>
|
||||
%3 = negative(%2) # EncryptedScalar<int3>
|
||||
%4 = mul(%3, %1) # EncryptedScalar<int6>
|
||||
%5 = absolute(%4) # EncryptedScalar<uint5>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ absolute is not supported for the time being
|
||||
%6 = astype(%5, dtype=int32) # EncryptedScalar<uint5>
|
||||
%7 = add(%6, %0) # EncryptedScalar<uint6>
|
||||
return %7
|
||||
|
||||
""".strip() # noqa: E501 # pylint: disable=line-too-long
|
||||
assert str(error) == match
|
||||
raise
|
||||
|
||||
@@ -1270,13 +1301,14 @@ def test_small_inputset_treat_warnings_as_errors():
|
||||
(4,),
|
||||
# Remark that, when you do the dot of tensors of 4 values between 0 and 3,
|
||||
# you can get a maximal value of 4*3*3 = 36, ie something on 6 bits
|
||||
"%0 = x "
|
||||
"# EncryptedTensor<Integer<unsigned, 2 bits>, shape=(4,)>"
|
||||
"\n%1 = y "
|
||||
"# EncryptedTensor<Integer<unsigned, 2 bits>, shape=(4,)>"
|
||||
"\n%2 = Dot(%0, %1) "
|
||||
"# EncryptedScalar<Integer<unsigned, 6 bits>>"
|
||||
"\nreturn(%2)\n",
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedTensor<uint2, shape=(4,)>
|
||||
%1 = y # EncryptedTensor<uint2, shape=(4,)>
|
||||
%2 = dot(%0, %1) # EncryptedScalar<uint6>
|
||||
return %2
|
||||
|
||||
""".strip(),
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -1303,7 +1335,7 @@ def test_compile_function_with_dot(
|
||||
data_gen(max_for_ij, repeat),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
str_of_the_graph = format_operation_graph(op_graph, show_data_types=True)
|
||||
str_of_the_graph = format_operation_graph(op_graph)
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
f"==================\nExpected \n{ref_graph_str}"
|
||||
@@ -1373,13 +1405,16 @@ def test_compile_too_high_bitwidth(default_compilation_configuration):
|
||||
assert (
|
||||
str(excinfo.value)
|
||||
== """
|
||||
|
||||
max_bit_width of some nodes is too high for the current version of the compiler (maximum must be 7) which is not compatible with:
|
||||
%0 = x # EncryptedScalar<Integer<unsigned, 7 bits>>
|
||||
%1 = y # EncryptedScalar<Integer<unsigned, 5 bits>>
|
||||
%2 = Add(%0, %1) # EncryptedScalar<Integer<unsigned, 8 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 8 bits is not supported for the time being
|
||||
return(%2)
|
||||
""".lstrip() # noqa: E501 # pylint: disable=line-too-long
|
||||
|
||||
%0 = x # EncryptedScalar<uint7>
|
||||
%1 = y # EncryptedScalar<uint5>
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint8>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 8 bits is not supported for the time being
|
||||
return %2
|
||||
|
||||
""".strip() # noqa: E501 # pylint: disable=line-too-long
|
||||
)
|
||||
|
||||
# Just ok
|
||||
@@ -1418,14 +1453,16 @@ def test_failure_for_signed_output(default_compilation_configuration):
|
||||
assert (
|
||||
str(excinfo.value)
|
||||
== """
|
||||
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedScalar<Integer<unsigned, 4 bits>>
|
||||
%1 = Constant(-3) # ClearScalar<Integer<signed, 3 bits>>
|
||||
%2 = Add(%0, %1) # EncryptedScalar<Integer<signed, 4 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported
|
||||
return(%2)
|
||||
""".lstrip() # noqa: E501 # pylint: disable=line-too-long
|
||||
%0 = x # EncryptedScalar<uint4>
|
||||
%1 = -3 # ClearScalar<int3>
|
||||
%2 = add(%0, %1) # EncryptedScalar<int4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported
|
||||
return %2
|
||||
|
||||
""".strip() # noqa: E501 # pylint: disable=line-too-long
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -368,8 +368,7 @@ def test_constant_indexing(
|
||||
EncryptedScalar(UnsignedInteger(1)),
|
||||
lambda x: x[0],
|
||||
TypeError,
|
||||
"Only tensors can be indexed "
|
||||
"but you tried to index EncryptedScalar<Integer<unsigned, 1 bits>>",
|
||||
"Only tensors can be indexed but you tried to index EncryptedScalar<uint1>",
|
||||
),
|
||||
pytest.param(
|
||||
EncryptedTensor(UnsignedInteger(1), shape=(3,)),
|
||||
|
||||
@@ -1,449 +0,0 @@
|
||||
"""Test file for debugging functions"""
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import draw_graph, format_operation_graph
|
||||
from concrete.common.extensions.table import LookupTable
|
||||
from concrete.common.values import ClearScalar, EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import tracing
|
||||
|
||||
LOOKUP_TABLE_FROM_2B_TO_4B = LookupTable([9, 2, 4, 11])
|
||||
LOOKUP_TABLE_FROM_3B_TO_2B = LookupTable([0, 1, 3, 2, 2, 3, 1, 0])
|
||||
|
||||
|
||||
def issue_130_a(x, y):
|
||||
"""Test case derived from issue #130"""
|
||||
# pylint: disable=unused-argument
|
||||
intermediate = x + 1
|
||||
return (intermediate, intermediate)
|
||||
# pylint: enable=unused-argument
|
||||
|
||||
|
||||
def issue_130_b(x, y):
|
||||
"""Test case derived from issue #130"""
|
||||
# pylint: disable=unused-argument
|
||||
intermediate = x - 1
|
||||
return (intermediate, intermediate)
|
||||
# pylint: enable=unused-argument
|
||||
|
||||
|
||||
def issue_130_c(x, y):
|
||||
"""Test case derived from issue #130"""
|
||||
# pylint: disable=unused-argument
|
||||
intermediate = 1 - x
|
||||
return (intermediate, intermediate)
|
||||
# pylint: enable=unused-argument
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lambda_f,ref_graph_str",
|
||||
[
|
||||
(lambda x, y: x + y, "%0 = x\n%1 = y\n%2 = Add(%0, %1)\nreturn(%2)\n"),
|
||||
(lambda x, y: x - y, "%0 = x\n%1 = y\n%2 = Sub(%0, %1)\nreturn(%2)\n"),
|
||||
(lambda x, y: x + x, "%0 = x\n%1 = Add(%0, %0)\nreturn(%1)\n"),
|
||||
(
|
||||
lambda x, y: x + x - y * y * y + x,
|
||||
"%0 = x\n%1 = y\n%2 = Add(%0, %0)\n%3 = Mul(%1, %1)"
|
||||
"\n%4 = Mul(%3, %1)\n%5 = Sub(%2, %4)\n%6 = Add(%5, %0)\nreturn(%6)\n",
|
||||
),
|
||||
(lambda x, y: x + 1, "%0 = x\n%1 = Constant(1)\n%2 = Add(%0, %1)\nreturn(%2)\n"),
|
||||
(lambda x, y: 1 + x, "%0 = x\n%1 = Constant(1)\n%2 = Add(%0, %1)\nreturn(%2)\n"),
|
||||
(lambda x, y: (-1) + x, "%0 = x\n%1 = Constant(-1)\n%2 = Add(%0, %1)\nreturn(%2)\n"),
|
||||
(lambda x, y: 3 * x, "%0 = x\n%1 = Constant(3)\n%2 = Mul(%0, %1)\nreturn(%2)\n"),
|
||||
(lambda x, y: x * 3, "%0 = x\n%1 = Constant(3)\n%2 = Mul(%0, %1)\nreturn(%2)\n"),
|
||||
(lambda x, y: x * (-3), "%0 = x\n%1 = Constant(-3)\n%2 = Mul(%0, %1)\nreturn(%2)\n"),
|
||||
(lambda x, y: x - 11, "%0 = x\n%1 = Constant(11)\n%2 = Sub(%0, %1)\nreturn(%2)\n"),
|
||||
(lambda x, y: 11 - x, "%0 = Constant(11)\n%1 = x\n%2 = Sub(%0, %1)\nreturn(%2)\n"),
|
||||
(lambda x, y: (-11) - x, "%0 = Constant(-11)\n%1 = x\n%2 = Sub(%0, %1)\nreturn(%2)\n"),
|
||||
(
|
||||
lambda x, y: x + 13 - y * (-21) * y + 44,
|
||||
"%0 = Constant(44)"
|
||||
"\n%1 = x"
|
||||
"\n%2 = Constant(13)"
|
||||
"\n%3 = y"
|
||||
"\n%4 = Constant(-21)"
|
||||
"\n%5 = Add(%1, %2)"
|
||||
"\n%6 = Mul(%3, %4)"
|
||||
"\n%7 = Mul(%6, %3)"
|
||||
"\n%8 = Sub(%5, %7)"
|
||||
"\n%9 = Add(%8, %0)"
|
||||
"\nreturn(%9)\n",
|
||||
),
|
||||
# Multiple outputs
|
||||
(
|
||||
lambda x, y: (x + 1, x + y + 2),
|
||||
"%0 = x"
|
||||
"\n%1 = Constant(1)"
|
||||
"\n%2 = Constant(2)"
|
||||
"\n%3 = y"
|
||||
"\n%4 = Add(%0, %1)"
|
||||
"\n%5 = Add(%0, %3)"
|
||||
"\n%6 = Add(%5, %2)"
|
||||
"\nreturn(%4, %6)\n",
|
||||
),
|
||||
(
|
||||
lambda x, y: (y, x),
|
||||
"%0 = y\n%1 = x\nreturn(%0, %1)\n",
|
||||
),
|
||||
(
|
||||
lambda x, y: (x, x + 1),
|
||||
"%0 = x\n%1 = Constant(1)\n%2 = Add(%0, %1)\nreturn(%0, %2)\n",
|
||||
),
|
||||
(
|
||||
lambda x, y: (x + 1, x + 1),
|
||||
"%0 = x"
|
||||
"\n%1 = Constant(1)"
|
||||
"\n%2 = Constant(1)"
|
||||
"\n%3 = Add(%0, %1)"
|
||||
"\n%4 = Add(%0, %2)"
|
||||
"\nreturn(%3, %4)\n",
|
||||
),
|
||||
(
|
||||
issue_130_a,
|
||||
"%0 = x\n%1 = Constant(1)\n%2 = Add(%0, %1)\nreturn(%2, %2)\n",
|
||||
),
|
||||
(
|
||||
issue_130_b,
|
||||
"%0 = x\n%1 = Constant(1)\n%2 = Sub(%0, %1)\nreturn(%2, %2)\n",
|
||||
),
|
||||
(
|
||||
issue_130_c,
|
||||
"%0 = Constant(1)\n%1 = x\n%2 = Sub(%0, %1)\nreturn(%2, %2)\n",
|
||||
),
|
||||
(
|
||||
lambda x, y: numpy.arctan2(x, 42) + y,
|
||||
"""%0 = y
|
||||
%1 = x
|
||||
%2 = Constant(42)
|
||||
%3 = np.arctan2(%1, %2)
|
||||
%4 = Add(%3, %0)
|
||||
return(%4)
|
||||
""",
|
||||
),
|
||||
(
|
||||
lambda x, y: numpy.arctan2(43, x) + y,
|
||||
"""%0 = y
|
||||
%1 = Constant(43)
|
||||
%2 = x
|
||||
%3 = np.arctan2(%1, %2)
|
||||
%4 = Add(%3, %0)
|
||||
return(%4)
|
||||
""",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"x_y",
|
||||
[
|
||||
pytest.param(
|
||||
(
|
||||
EncryptedScalar(Integer(64, is_signed=False)),
|
||||
EncryptedScalar(Integer(64, is_signed=False)),
|
||||
),
|
||||
id="Encrypted uint",
|
||||
),
|
||||
pytest.param(
|
||||
(
|
||||
EncryptedScalar(Integer(64, is_signed=False)),
|
||||
ClearScalar(Integer(64, is_signed=False)),
|
||||
),
|
||||
id="Clear uint",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
|
||||
"Test format_operation_graph and draw_graph"
|
||||
x, y = x_y
|
||||
graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y})
|
||||
|
||||
draw_graph(graph, show=False)
|
||||
|
||||
str_of_the_graph = format_operation_graph(graph)
|
||||
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
f"==================\nExpected \n{ref_graph_str}"
|
||||
f"==================\n"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lambda_f,params,ref_graph_str",
|
||||
[
|
||||
(
|
||||
lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x],
|
||||
{"x": EncryptedScalar(Integer(2, is_signed=False))},
|
||||
"%0 = x\n%1 = TLU(%0)\nreturn(%1)\n",
|
||||
),
|
||||
(
|
||||
lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4],
|
||||
{"x": EncryptedScalar(Integer(2, is_signed=False))},
|
||||
"%0 = x\n%1 = Constant(4)\n%2 = Add(%0, %1)\n%3 = TLU(%2)\nreturn(%3)\n",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph_str):
|
||||
"Test format_operation_graph and draw_graph on graphs with direct table lookup"
|
||||
graph = tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
draw_graph(graph, show=False)
|
||||
|
||||
str_of_the_graph = format_operation_graph(graph)
|
||||
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
f"==================\nExpected \n{ref_graph_str}"
|
||||
f"==================\n"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lambda_f,params,ref_graph_str",
|
||||
[
|
||||
(
|
||||
lambda x, y: numpy.dot(x, y),
|
||||
{
|
||||
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(3,)),
|
||||
"y": EncryptedTensor(Integer(2, is_signed=False), shape=(3,)),
|
||||
},
|
||||
"%0 = x\n%1 = y\n%2 = Dot(%0, %1)\nreturn(%2)\n",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
|
||||
"Test format_operation_graph and draw_graph on graphs with dot"
|
||||
graph = tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
draw_graph(graph, show=False)
|
||||
|
||||
str_of_the_graph = format_operation_graph(graph)
|
||||
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
f"==================\nExpected \n{ref_graph_str}"
|
||||
f"==================\n"
|
||||
)
|
||||
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
@pytest.mark.parametrize(
|
||||
"lambda_f,params,ref_graph_str",
|
||||
[
|
||||
(
|
||||
lambda x: numpy.transpose(x),
|
||||
{
|
||||
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5)),
|
||||
},
|
||||
"""
|
||||
%0 = x # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3, 5)>
|
||||
%1 = np.transpose(%0) # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(5, 3)>
|
||||
return(%1)
|
||||
""".lstrip(), # noqa: E501
|
||||
),
|
||||
(
|
||||
lambda x: numpy.ravel(x),
|
||||
{
|
||||
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5)),
|
||||
},
|
||||
"""
|
||||
%0 = x # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3, 5)>
|
||||
%1 = np.ravel(%0) # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(15,)>
|
||||
return(%1)
|
||||
""".lstrip(), # noqa: E501
|
||||
),
|
||||
(
|
||||
lambda x: numpy.reshape(x, (5, 3)),
|
||||
{
|
||||
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5)),
|
||||
},
|
||||
"""
|
||||
%0 = x # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3, 5)>
|
||||
%1 = np.reshape(%0) # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(5, 3)>
|
||||
return(%1)
|
||||
""".lstrip(), # noqa: E501
|
||||
),
|
||||
(
|
||||
lambda x: numpy.reshape(x, (170,)),
|
||||
{
|
||||
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(17, 10)),
|
||||
},
|
||||
"""
|
||||
%0 = x # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(17, 10)>
|
||||
%1 = np.reshape(%0) # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(170,)>
|
||||
return(%1)
|
||||
""".lstrip(), # noqa: E501
|
||||
),
|
||||
(
|
||||
lambda x: numpy.reshape(x, (170)),
|
||||
{
|
||||
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(17, 10)),
|
||||
},
|
||||
"""
|
||||
%0 = x # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(17, 10)>
|
||||
%1 = np.reshape(%0) # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(170,)>
|
||||
return(%1)
|
||||
""".lstrip(), # noqa: E501
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_print_and_draw_graph_with_generic_function(lambda_f, params, ref_graph_str):
|
||||
"Test format_operation_graph and draw_graph on graphs with generic function"
|
||||
graph = tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
draw_graph(graph, show=False)
|
||||
|
||||
str_of_the_graph = format_operation_graph(graph, show_data_types=True)
|
||||
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
f"==================\nExpected \n{ref_graph_str}"
|
||||
f"==================\n"
|
||||
)
|
||||
|
||||
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
|
||||
# Remark that the bitwidths are not particularly correct (eg, a MUL of a 17b times 23b
|
||||
# returning 23b), since they are replaced later by the real bitwidths computed on the
|
||||
# inputset
|
||||
@pytest.mark.parametrize(
|
||||
"lambda_f,x_y,ref_graph_str",
|
||||
[
|
||||
(
|
||||
lambda x, y: x + y,
|
||||
(
|
||||
EncryptedScalar(Integer(64, is_signed=False)),
|
||||
EncryptedScalar(Integer(32, is_signed=True)),
|
||||
),
|
||||
"%0 = x "
|
||||
"# EncryptedScalar<Integer<unsigned, 64 bits>>"
|
||||
"\n%1 = y "
|
||||
" # EncryptedScalar<Integer<signed, 32 bits>>"
|
||||
"\n%2 = Add(%0, %1) "
|
||||
" # EncryptedScalar<Integer<signed, 65 bits>>"
|
||||
"\nreturn(%2)\n",
|
||||
),
|
||||
(
|
||||
lambda x, y: x * y,
|
||||
(
|
||||
EncryptedScalar(Integer(17, is_signed=False)),
|
||||
EncryptedScalar(Integer(23, is_signed=False)),
|
||||
),
|
||||
"%0 = x "
|
||||
"# EncryptedScalar<Integer<unsigned, 17 bits>>"
|
||||
"\n%1 = y "
|
||||
"# EncryptedScalar<Integer<unsigned, 23 bits>>"
|
||||
"\n%2 = Mul(%0, %1) "
|
||||
"# EncryptedScalar<Integer<unsigned, 23 bits>>"
|
||||
"\nreturn(%2)\n",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
|
||||
"""Test format_operation_graph with show_data_types"""
|
||||
x, y = x_y
|
||||
graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y})
|
||||
|
||||
str_of_the_graph = format_operation_graph(graph, show_data_types=True)
|
||||
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
f"==================\nExpected \n{ref_graph_str}"
|
||||
f"==================\n"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lambda_f,params,ref_graph_str",
|
||||
[
|
||||
(
|
||||
lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x],
|
||||
{"x": EncryptedScalar(Integer(2, is_signed=False))},
|
||||
"%0 = x "
|
||||
"# EncryptedScalar<Integer<unsigned, 2 bits>>"
|
||||
"\n%1 = TLU(%0) "
|
||||
"# EncryptedScalar<Integer<unsigned, 4 bits>>"
|
||||
"\nreturn(%1)\n",
|
||||
),
|
||||
(
|
||||
lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4],
|
||||
{"x": EncryptedScalar(Integer(2, is_signed=False))},
|
||||
"%0 = x "
|
||||
"# EncryptedScalar<Integer<unsigned, 2 bits>>"
|
||||
"\n%1 = Constant(4) "
|
||||
"# ClearScalar<Integer<unsigned, 3 bits>>"
|
||||
"\n%2 = Add(%0, %1) "
|
||||
"# EncryptedScalar<Integer<unsigned, 3 bits>>"
|
||||
"\n%3 = TLU(%2) "
|
||||
"# EncryptedScalar<Integer<unsigned, 2 bits>>"
|
||||
"\nreturn(%3)\n",
|
||||
),
|
||||
(
|
||||
lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[LOOKUP_TABLE_FROM_3B_TO_2B[x + 4]],
|
||||
{"x": EncryptedScalar(Integer(2, is_signed=False))},
|
||||
"%0 = x "
|
||||
"# EncryptedScalar<Integer<unsigned, 2 bits>>"
|
||||
"\n%1 = Constant(4) "
|
||||
"# ClearScalar<Integer<unsigned, 3 bits>>"
|
||||
"\n%2 = Add(%0, %1) "
|
||||
"# EncryptedScalar<Integer<unsigned, 3 bits>>"
|
||||
"\n%3 = TLU(%2) "
|
||||
"# EncryptedScalar<Integer<unsigned, 2 bits>>"
|
||||
"\n%4 = TLU(%3) "
|
||||
"# EncryptedScalar<Integer<unsigned, 4 bits>>"
|
||||
"\nreturn(%4)\n",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_print_with_show_data_types_with_direct_tlu(lambda_f, params, ref_graph_str):
|
||||
"""Test format_operation_graph with show_data_types on graphs with direct table lookup"""
|
||||
graph = tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
draw_graph(graph, show=False)
|
||||
|
||||
str_of_the_graph = format_operation_graph(graph, show_data_types=True)
|
||||
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
f"==================\nExpected \n{ref_graph_str}"
|
||||
f"==================\n"
|
||||
)
|
||||
|
||||
|
||||
def test_numpy_long_constant():
|
||||
"Test format_operation_graph with long constant"
|
||||
|
||||
def all_explicit_operations(x):
|
||||
intermediate = numpy.add(x, numpy.arange(100).reshape(10, 10))
|
||||
intermediate = numpy.subtract(intermediate, numpy.arange(10).reshape(1, 10))
|
||||
intermediate = numpy.arctan2(numpy.arange(10, 20).reshape(1, 10), intermediate)
|
||||
intermediate = numpy.arctan2(numpy.arange(100, 200).reshape(10, 10), intermediate)
|
||||
return intermediate
|
||||
|
||||
op_graph = tracing.trace_numpy_function(
|
||||
all_explicit_operations, {"x": EncryptedTensor(Integer(32, True), shape=(10, 10))}
|
||||
)
|
||||
|
||||
expected = """
|
||||
%0 = Constant([[100 101 ... 198 199]]) # ClearTensor<Integer<unsigned, 8 bits>, shape=(10, 10)>
|
||||
%1 = Constant([[10 11 12 ... 17 18 19]]) # ClearTensor<Integer<unsigned, 5 bits>, shape=(1, 10)>
|
||||
%2 = Constant([[0 1 2 3 4 5 6 7 8 9]]) # ClearTensor<Integer<unsigned, 4 bits>, shape=(1, 10)>
|
||||
%3 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
|
||||
%4 = Constant([[ 0 1 2 ... 97 98 99]]) # ClearTensor<Integer<unsigned, 7 bits>, shape=(10, 10)>
|
||||
%5 = Add(%3, %4) # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
|
||||
%6 = Sub(%5, %2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(10, 10)>
|
||||
%7 = np.arctan2(%1, %6) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
|
||||
%8 = np.arctan2(%0, %7) # EncryptedTensor<Float<64 bits>, shape=(10, 10)>
|
||||
return(%8)
|
||||
""".lstrip() # noqa: E501
|
||||
|
||||
str_of_the_graph = format_operation_graph(op_graph, show_data_types=True)
|
||||
|
||||
assert str_of_the_graph == expected, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
f"==================\nExpected \n{expected}"
|
||||
f"==================\n"
|
||||
)
|
||||
@@ -188,24 +188,22 @@ def test_numpy_tracing_tensors():
|
||||
all_operations, {"x": EncryptedTensor(Integer(32, True), shape=(2, 2))}
|
||||
)
|
||||
|
||||
expected = """
|
||||
%0 = Constant([[2 1] [1 2]]) # ClearTensor<Integer<unsigned, 2 bits>, shape=(2, 2)>
|
||||
%1 = Constant([[1 2] [2 1]]) # ClearTensor<Integer<unsigned, 2 bits>, shape=(2, 2)>
|
||||
%2 = Constant([[10 20] [30 40]]) # ClearTensor<Integer<unsigned, 6 bits>, shape=(2, 2)>
|
||||
%3 = Constant([[100 200] [300 400]]) # ClearTensor<Integer<unsigned, 9 bits>, shape=(2, 2)>
|
||||
%4 = Constant([[5 6] [7 8]]) # ClearTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>
|
||||
%5 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%6 = Constant([[1 2] [3 4]]) # ClearTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
|
||||
%7 = Add(%5, %6) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%8 = Add(%4, %7) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%9 = Sub(%3, %8) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%10 = Sub(%9, %2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%11 = Mul(%10, %1) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%12 = Mul(%0, %11) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
return(%12)
|
||||
""".lstrip() # noqa: E501
|
||||
expected = """ %0 = [[2 1] [1 2]] # ClearTensor<uint2, shape=(2, 2)>
|
||||
%1 = [[1 2] [2 1]] # ClearTensor<uint2, shape=(2, 2)>
|
||||
%2 = [[10 20] [30 40]] # ClearTensor<uint6, shape=(2, 2)>
|
||||
%3 = [[100 200] [300 400]] # ClearTensor<uint9, shape=(2, 2)>
|
||||
%4 = [[5 6] [7 8]] # ClearTensor<uint4, shape=(2, 2)>
|
||||
%5 = x # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%6 = [[1 2] [3 4]] # ClearTensor<uint3, shape=(2, 2)>
|
||||
%7 = add(%5, %6) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%8 = add(%4, %7) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%9 = sub(%3, %8) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%10 = sub(%9, %2) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%11 = mul(%10, %1) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%12 = mul(%0, %11) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
return %12""" # noqa: E501
|
||||
|
||||
assert format_operation_graph(op_graph, show_data_types=True) == expected
|
||||
assert format_operation_graph(op_graph) == expected, format_operation_graph(op_graph)
|
||||
|
||||
|
||||
def test_numpy_explicit_tracing_tensors():
|
||||
@@ -227,24 +225,22 @@ def test_numpy_explicit_tracing_tensors():
|
||||
all_explicit_operations, {"x": EncryptedTensor(Integer(32, True), shape=(2, 2))}
|
||||
)
|
||||
|
||||
expected = """
|
||||
%0 = Constant([[2 1] [1 2]]) # ClearTensor<Integer<unsigned, 2 bits>, shape=(2, 2)>
|
||||
%1 = Constant([[1 2] [2 1]]) # ClearTensor<Integer<unsigned, 2 bits>, shape=(2, 2)>
|
||||
%2 = Constant([[10 20] [30 40]]) # ClearTensor<Integer<unsigned, 6 bits>, shape=(2, 2)>
|
||||
%3 = Constant([[100 200] [300 400]]) # ClearTensor<Integer<unsigned, 9 bits>, shape=(2, 2)>
|
||||
%4 = Constant([[5 6] [7 8]]) # ClearTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>
|
||||
%5 = x # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%6 = Constant([[1 2] [3 4]]) # ClearTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
|
||||
%7 = Add(%5, %6) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%8 = Add(%4, %7) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%9 = Sub(%3, %8) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%10 = Sub(%9, %2) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%11 = Mul(%10, %1) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
%12 = Mul(%0, %11) # EncryptedTensor<Integer<signed, 32 bits>, shape=(2, 2)>
|
||||
return(%12)
|
||||
""".lstrip() # noqa: E501
|
||||
expected = """ %0 = [[2 1] [1 2]] # ClearTensor<uint2, shape=(2, 2)>
|
||||
%1 = [[1 2] [2 1]] # ClearTensor<uint2, shape=(2, 2)>
|
||||
%2 = [[10 20] [30 40]] # ClearTensor<uint6, shape=(2, 2)>
|
||||
%3 = [[100 200] [300 400]] # ClearTensor<uint9, shape=(2, 2)>
|
||||
%4 = [[5 6] [7 8]] # ClearTensor<uint4, shape=(2, 2)>
|
||||
%5 = x # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%6 = [[1 2] [3 4]] # ClearTensor<uint3, shape=(2, 2)>
|
||||
%7 = add(%5, %6) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%8 = add(%4, %7) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%9 = sub(%3, %8) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%10 = sub(%9, %2) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%11 = mul(%10, %1) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
%12 = mul(%0, %11) # EncryptedTensor<int32, shape=(2, 2)>
|
||||
return %12""" # noqa: E501
|
||||
|
||||
assert format_operation_graph(op_graph, show_data_types=True) == expected
|
||||
assert format_operation_graph(op_graph) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user