feat: implement representation module

This commit is contained in:
Umut
2022-04-04 13:27:37 +02:00
parent 5baa96664b
commit 92efb20e09
5 changed files with 964 additions and 0 deletions

View File

@@ -0,0 +1,7 @@
"""
Declaration of `concrete.numpy.representation` namespace.
"""
from .graph import Graph
from .node import Node
from .operation import Operation

View File

@@ -0,0 +1,498 @@
"""
Declaration of `Graph` class.
"""
import os
import tempfile
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from PIL import Image
from ..dtypes import Float, Integer, UnsignedInteger
from .node import Node
from .operation import OPERATION_COLOR_MAPPING, Operation
class Graph:
"""
Graph class, to represent computation graphs.
"""
graph: nx.MultiDiGraph
input_nodes: Dict[int, Node]
output_nodes: Dict[int, Node]
input_indices: Dict[Node, int]
def __init__(
self,
graph: nx.MultiDiGraph,
input_nodes: Dict[int, Node],
output_nodes: Dict[int, Node],
):
self.graph = graph
self.input_nodes = input_nodes
self.output_nodes = output_nodes
self.input_indices = {node: index for index, node in input_nodes.items()}
self.prune_useless_nodes()
def __call__(
self,
*args: Any,
) -> Union[
np.bool_,
np.integer,
np.floating,
np.ndarray,
Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray], ...],
]:
evaluation = self.evaluate(*args)
result = tuple(evaluation[node] for node in self.ordered_outputs())
return result if len(result) > 1 else result[0]
def evaluate(
self,
*args: Any,
) -> Dict[Node, Union[np.bool_, np.integer, np.floating, np.ndarray]]:
r"""
Perform the computation `Graph` represents and get resulting values for all nodes.
Args:
*args (List[Any]):
inputs to the computation
Returns:
Dict[Node, Union[np.bool\_, np.integer, np.floating, np.ndarray]]:
nodes and their values during computation
"""
node_results: Dict[Node, Union[np.bool_, np.integer, np.floating, np.ndarray]] = {}
for node in nx.topological_sort(self.graph):
if node.operation == Operation.Input:
node_results[node] = node(args[self.input_indices[node]])
continue
pred_results = [node_results[pred] for pred in self.ordered_preds_of(node)]
node_results[node] = node(*pred_results)
return node_results
def draw(
self,
show: bool = False,
horizontal: bool = False,
save_to: Optional[Union[Path, str]] = None,
) -> Path:
"""
Draw the `Graph` and optionally save/show the drawing.
note that this function requires the python `pygraphviz` package
which itself requires the installation of `graphviz` packages
see https://pygraphviz.github.io/documentation/stable/install.html
Args:
show (bool, default = False):
whether to show the drawing using matplotlib or not
horizontal (bool, default = False):
whether to draw horizontally or not
save_to (Optional[Path], default = None):
path to save the drawing
a temporary file will be used if it's None
Returns:
Path:
path to the saved drawing
"""
def get_color(node, output_nodes):
if node in output_nodes:
return OPERATION_COLOR_MAPPING["output"]
return OPERATION_COLOR_MAPPING[node.operation]
graph = self.graph
output_nodes = set(self.output_nodes.values())
attributes = {
node: {
"label": node.label(),
"color": get_color(node, output_nodes),
"penwidth": 2, # double thickness for circles
"peripheries": 2 if node in output_nodes else 1, # two circles for output nodes
}
for node in graph.nodes
}
nx.set_node_attributes(graph, attributes)
for edge in graph.edges(keys=True):
idx = graph.edges[edge]["input_idx"]
graph.edges[edge]["label"] = f" {idx} "
try:
agraph = nx.nx_agraph.to_agraph(graph)
except ImportError as error: # pragma: no cover
if "pygraphviz" in str(error):
raise ImportError(
"Graph.draw requires pygraphviz. Install graphviz distribution to your OS "
"following https://pygraphviz.github.io/documentation/stable/install.html "
"and reinstall concrete-numpy with extras: `pip install --force-reinstall "
"concrete-numpy[full]`"
) from error
raise
agraph.graph_attr["rankdir"] = "LR" if horizontal else "TB"
agraph.layout("dot")
if save_to is None:
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
# we need to change the permissions of the temporary file
# so that it can be read by all users
# (https://stackoverflow.com/a/44130605)
# get the old umask and replace it with 0o666
old_umask = os.umask(0o666)
# restore the old umask back
os.umask(old_umask)
# combine the old umask with the wanted permissions
permissions = 0o666 & ~old_umask
# set new permissions
os.chmod(tmp.name, permissions)
save_to_str = str(tmp.name)
else:
save_to_str = str(save_to)
agraph.draw(save_to_str)
if show: # pragma: no cover
plt.close("all")
plt.figure()
img = Image.open(save_to_str)
plt.imshow(img)
img.close()
plt.axis("off")
plt.show()
return Path(save_to_str)
def format(
self,
maximum_constant_length: int = 25,
highlighted_nodes: Optional[Dict[Node, List[str]]] = None,
) -> str:
"""
Get the textual representation of the `Graph`.
Args:
maximum_constant_length (int, default = 25):
maximum length of formatted constants
highlighted_nodes (Optional[Dict[Node, List[str]]], default = None):
nodes to be highlighted and their corresponding messages
Returns:
str:
textual representation of the `Graph`
"""
# node -> identifier
# e.g., id_map[node1] = 2
# means line for node1 is in this form %2 = node1.format(...)
id_map: Dict[Node, int] = {}
# lines that will be merged at the end
lines: List[str] = []
# type information to add to each line
# (for alingment, this is done after lines are determined)
type_informations: List[str] = []
# default highlighted nodes is empty
highlighted_nodes = highlighted_nodes if highlighted_nodes is not None else {}
# highlight information for lines, this is required because highlights are added to lines
# after their type information is added, and we only have line numbers, not nodes
highlighted_lines: Dict[int, List[str]] = {}
# subgraphs to format after the main graph is formatted
subgraphs: Dict[str, Graph] = {}
# format nodes
for node in nx.topological_sort(self.graph):
# assign a unique id to outputs of node
id_map[node] = len(id_map)
# remember highlights of the node
if node in highlighted_nodes:
highlighted_lines[len(lines)] = highlighted_nodes[node]
# extract predecessors and their ids
predecessors = []
for predecessor in self.ordered_preds_of(node):
predecessors.append(f"%{id_map[predecessor]}")
# start the build the line for the node
line = ""
# add output information to the line
line += f"%{id_map[node]}"
# add node information to the line
line += " = "
line += node.format(predecessors, maximum_constant_length)
# append line to list of lines
lines.append(line)
# if exists, save the subgraph
if node.operation == Operation.Generic and "subgraph" in node.properties["kwargs"]:
subgraphs[line] = node.properties["kwargs"]["subgraph"]
# remember type information of the node
type_informations.append(str(node.output))
# align = signs
#
# e.g.,
#
# %1 = ...
# %2 = ...
# ...
# %8 = ...
# %9 = ...
# %10 = ...
# %11 = ...
# ...
longest_length_before_equals_sign = max(len(line.split("=")[0]) for line in lines)
for i, line in enumerate(lines):
length_before_equals_sign = len(line.split("=")[0])
lines[i] = (
" " * (longest_length_before_equals_sign - length_before_equals_sign)
) + line
# add type information
longest_line_length = max(len(line) for line in lines)
for i, line in enumerate(lines):
lines[i] += " " * (longest_line_length - len(line))
lines[i] += f" # {type_informations[i]}"
# add highlights (this is done in reverse to keep indices consistent)
for i in reversed(range(len(lines))):
if i in highlighted_lines:
for j, message in enumerate(highlighted_lines[i]):
highlight = "^" if j == 0 else " "
lines.insert(i + 1 + j, f"{highlight * len(lines[i])} {message}")
# add return information
# (if there is a single return, it's in the form `return %id`
# (otherwise, it's in the form `return (%id1, %id2, ..., %idN)`
returns: List[str] = []
for node in self.output_nodes.values():
returns.append(f"%{id_map[node]}")
lines.append("return " + (returns[0] if len(returns) == 1 else f"({', '.join(returns)})"))
# format subgraphs after the actual graph
result = "\n".join(lines)
if len(subgraphs) > 0:
result += "\n\n"
result += "Subgraphs:"
for line, subgraph in subgraphs.items():
subgraph_lines = subgraph.format(maximum_constant_length).split("\n")
result += "\n\n"
result += f" {line}:\n\n"
result += "\n".join(f" {line}" for line in subgraph_lines)
return result
def measure_bounds(
self,
inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]],
) -> Dict[Node, Dict[str, Union[np.integer, np.floating]]]:
"""
Evaluate the `Graph` using an inputset and measure bounds.
inputset is either an iterable of anything
for a single parameter
or
an iterable of tuples of anything (of rank number of parameters)
for multiple parameters
e.g.,
.. code-block:: python
inputset = [1, 3, 5, 2, 4]
def f(x):
...
inputset = [(1, 2), (2, 4), (3, 1), (2, 2)]
def g(x, y):
...
Args:
inputset (Union[Iterable[Any], Iterable[Tuple[Any, ...]]]):
inputset to use
Returns:
Dict[Node, Dict[str, Union[np.integer, np.floating]]]:
bounds of each node in the `Graph`
"""
bounds = {}
inputset_iterator = iter(inputset)
sample = next(inputset_iterator)
if not isinstance(sample, tuple):
sample = (sample,)
evaluation = self.evaluate(*sample)
for node, value in evaluation.items():
bounds[node] = {
"min": value.min(),
"max": value.max(),
}
for sample in inputset_iterator:
if not isinstance(sample, tuple):
sample = (sample,)
evaluation = self.evaluate(*sample)
for node, value in evaluation.items():
bounds[node] = {
"min": np.minimum(bounds[node]["min"], value.min()),
"max": np.maximum(bounds[node]["max"], value.max()),
}
return bounds
def update_with_bounds(self, bounds: Dict[Node, Dict[str, Union[np.integer, np.floating]]]):
"""
Update `Value`s within the `Graph` according to measured bounds.
Args:
bounds (Dict[Node, Dict[str, Union[np.integer, np.floating]]]):
bounds of each node in the `Graph`
"""
for node in self.graph.nodes():
if node in bounds:
min_bound = bounds[node]["min"]
max_bound = bounds[node]["max"]
new_value = deepcopy(node.output)
if isinstance(min_bound, np.integer):
new_value.dtype = Integer.that_can_represent(np.array([min_bound, max_bound]))
else:
new_value.dtype = {
np.bool_: UnsignedInteger(1),
np.float64: Float(64),
np.float32: Float(32),
np.float16: Float(16),
}[type(min_bound)]
node.output = new_value
if node.operation == Operation.Input:
node.inputs[0] = new_value
for successor in self.graph.successors(node):
edge_data = self.graph.get_edge_data(node, successor)
for edge in edge_data.values():
input_idx = edge["input_idx"]
successor.inputs[input_idx] = node.output
def ordered_inputs(self) -> List[Node]:
"""
Get the input nodes of the `Graph`, ordered by their indices.
Returns:
List[Node]:
ordered input nodes
"""
return [self.input_nodes[idx] for idx in range(len(self.input_nodes))]
def ordered_outputs(self) -> List[Node]:
"""
Get the output nodes of the `Graph`, ordered by their indices.
Returns:
List[Node]:
ordered output nodes
"""
return [self.output_nodes[idx] for idx in range(len(self.output_nodes))]
def ordered_preds_of(self, node: Node) -> List[Node]:
"""
Get predecessors of `node`, ordered by their indices.
Args:
node (Node):
node whose predecessors are requested
Returns:
List[Node]:
ordered predecessors of `node`.
"""
idx_to_pred: Dict[int, Node] = {}
for pred in self.graph.predecessors(node):
edge_data = self.graph.get_edge_data(pred, node)
idx_to_pred.update((data["input_idx"], pred) for data in edge_data.values())
return [idx_to_pred[i] for i in range(len(idx_to_pred))]
def prune_useless_nodes(self):
"""
Remove unreachable nodes from the graph.
"""
useful_nodes: Dict[Node, None] = {}
current_nodes = {node: None for node in self.ordered_outputs()}
while current_nodes:
useful_nodes.update(current_nodes)
next_nodes: Dict[Node, None] = {}
for node in current_nodes:
next_nodes.update({node: None for node in self.graph.predecessors(node)})
current_nodes = next_nodes
useless_nodes = [node for node in self.graph.nodes() if node not in useful_nodes]
self.graph.remove_nodes_from(useless_nodes)
def maximum_integer_bit_width(self) -> int:
"""
Get maximum integer bit-width within the graph.
Returns:
int:
maximum integer bit-width within the graph (-1 is there are no integer nodes)
"""
result = -1
for node in self.graph.nodes():
if isinstance(node.output.dtype, Integer):
result = max(result, node.output.dtype.bit_width)
return result

View File

@@ -0,0 +1,316 @@
"""
Declaration of `Node` class.
"""
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
from ..internal.utils import assert_that
from ..values import Value
from .operation import Operation
from .utils import KWARGS_IGNORED_IN_FORMATTING, format_constant, format_indexing_element
class Node:
"""
Node class, to represent computation in a computation graph.
"""
inputs: List[Value]
output: Value
operation: Operation
evaluator: Callable
properties: Dict[str, Any]
@staticmethod
def constant(constant: Any) -> "Node":
"""
Create an Operation.Constant node.
Args:
constant (Any):
constant to represent
Returns:
Node:
node representing constant
Raises:
ValueError:
if the constant is not representable
"""
try:
value = Value.of(constant)
except Exception as error:
raise ValueError(f"Constant {repr(constant)} is not supported") from error
properties = {"constant": np.array(constant)}
return Node([], value, Operation.Constant, lambda: properties["constant"], properties)
@staticmethod
def generic(
name: str,
inputs: List[Value],
output: Value,
operation: Callable,
args: Optional[Tuple[Any, ...]] = None,
kwargs: Optional[Dict[str, Any]] = None,
attributes: Optional[Dict[str, Any]] = None,
):
"""
Create an Operation.Generic node.
Args:
name (str):
name of the operation
inputs (List[Value]):
inputs to the operation
output (Value):
output of the operation
operation (Callable):
operation itself
args (Optional[Tuple[Any, ...]]):
args to pass to operation during evaluation
kwargs (Optional[Dict[str, Any]]):
kwargs to pass to operation during evaluation
attributes (Optional[Dict[str, Any]]):
attributes of the operation
Returns:
Node:
node representing operation
"""
properties = {
"name": name,
"args": args if args is not None else (),
"kwargs": kwargs if kwargs is not None else {},
"attributes": attributes if attributes is not None else {},
}
if name == "concatenate":
def evaluator(*args):
return operation(tuple(args), *properties["args"], **properties["kwargs"])
else:
def evaluator(*args):
return operation(*args, *properties["args"], **properties["kwargs"])
return Node(inputs, output, Operation.Generic, evaluator, properties)
@staticmethod
def input(name: str, value: Value) -> "Node":
"""
Create an Operation.Input node.
Args:
name (Any):
name of the input
value (Any):
value of the input
Returns:
Node:
node representing input
"""
return Node([value], value, Operation.Input, lambda x: x, {"name": name})
def __init__(
self,
inputs: List[Value],
output: Value,
operation: Operation,
evaluator: Callable,
properties: Optional[Dict[str, Any]] = None,
):
self.inputs = inputs
self.output = output
self.operation = operation
self.evaluator = evaluator # type: ignore
self.properties = properties if properties is not None else {}
def __call__(self, *args: List[Any]) -> Union[np.bool_, np.integer, np.floating, np.ndarray]:
def generic_error_message() -> str:
result = f"Evaluation of {self.operation.value} '{self.label()}' node"
if len(args) != 0:
result += f" using {', '.join(repr(arg) for arg in args)}"
return result
if len(args) != len(self.inputs):
raise ValueError(
f"{generic_error_message()} failed because of invalid number of arguments"
)
for arg, input_ in zip(args, self.inputs):
try:
arg_value = Value.of(arg)
except Exception as error:
arg_str = "the argument" if len(args) == 1 else f"argument {repr(arg)}"
raise ValueError(
f"{generic_error_message()} failed because {arg_str} is not valid"
) from error
if input_.shape != arg_value.shape:
arg_str = "the argument" if len(args) == 1 else f"argument {repr(arg)}"
raise ValueError(
f"{generic_error_message()} failed because "
f"{arg_str} does not have the expected "
f"shape of {input_.shape}"
)
result = self.evaluator(*args)
if isinstance(result, int) and -(2 ** 63) < result < (2 ** 63) - 1:
result = np.int64(result)
if isinstance(result, float):
result = np.float64(result)
if isinstance(result, list):
try:
np_result = np.array(result)
result = np_result
except Exception: # pylint: disable=broad-except
# here we try our best to convert the list to np.ndarray
# if it fails we raise the exception below
pass
if not isinstance(result, (np.bool_, np.integer, np.floating, np.ndarray)):
raise ValueError(
f"{generic_error_message()} resulted in {repr(result)} "
f"of type {result.__class__.__name__} "
f"which is not acceptable either because of the type or because of overflow"
)
if isinstance(result, np.ndarray):
dtype = result.dtype
if (
not np.issubdtype(dtype, np.integer)
and not np.issubdtype(dtype, np.floating)
and not np.issubdtype(dtype, np.bool_)
):
raise ValueError(
f"{generic_error_message()} resulted in {repr(result)} "
f"of type np.ndarray and of underlying type '{type(dtype).__name__}' "
f"which is not acceptable because of the underlying type"
)
if result.shape != self.output.shape:
raise ValueError(
f"{generic_error_message()} resulted in {repr(result)} "
f"which does not have the expected "
f"shape of {self.output.shape}"
)
return result
def format(self, predecessors: List[str], maximum_constant_length: int = 45) -> str:
"""
Get the textual representation of the `Node` (for printing).
Args:
predecessors (List[str]):
predecessor names to this node
maximum_constant_length (int, default = 45):
maximum length of formatted constants
Returns:
str:
textual representation of the `Node` (for printing)
"""
if self.operation == Operation.Constant:
return format_constant(self(), maximum_constant_length)
if self.operation == Operation.Input:
return self.properties["name"]
assert_that(self.operation == Operation.Generic)
name = self.properties["name"]
if name == "index.static":
index = self.properties["attributes"]["index"]
elements = [format_indexing_element(element) for element in index]
return f"{predecessors[0]}[{', '.join(elements)}]"
if name == "concatenate":
args = [f"({', '.join(predecessors)})"]
else:
args = deepcopy(predecessors)
args.extend(
format_constant(value, maximum_constant_length) for value in self.properties["args"]
)
args.extend(
f"{name}={format_constant(value, maximum_constant_length)}"
for name, value in self.properties["kwargs"].items()
if name not in KWARGS_IGNORED_IN_FORMATTING
)
return f"{name}({', '.join(args)})"
def label(self) -> str:
"""
Get the textual representation of the `Node` (for drawing).
Returns:
str:
textual representation of the `Node` (for drawing).
"""
if self.operation == Operation.Constant:
return format_constant(self(), maximum_length=45, keep_newlines=True)
if self.operation == Operation.Input:
return self.properties["name"]
assert_that(self.operation == Operation.Generic)
name = self.properties["name"]
if name == "index.static":
name = self.format(["index"])
return name
@property
def converted_to_table_lookup(self) -> bool:
"""
Get whether the node is converted to a table lookup during MLIR conversion.
Returns:
bool:
True if the node is converted to a table lookup, False otherwise
"""
return self.operation == Operation.Generic and self.properties["name"] not in [
"add",
"concatenate",
"conv2d",
"dot",
"index.static",
"matmul",
"multiply",
"negative",
"reshape",
"subtract",
"sum",
]

View File

@@ -0,0 +1,29 @@
"""
Declaration of `Operation` enum.
"""
from enum import Enum
class Operation(Enum):
"""
Operation enum, to distinguish nodes within a computation graph.
"""
# pylint: disable=invalid-name
Constant = "constant"
Generic = "generic"
Input = "input"
# pylint: enable=invalid-name
# https://graphviz.org/doc/info/colors.html#svg
OPERATION_COLOR_MAPPING = {
Operation.Constant: "grey",
Operation.Generic: "black",
Operation.Input: "crimson",
"output": "gold",
}

View File

@@ -0,0 +1,114 @@
"""
Declaration of various functions and constants related to representation of computation.
"""
from typing import Any, Dict, Hashable, Set, Union
import numpy as np
from ..internal.utils import assert_that
KWARGS_IGNORED_IN_FORMATTING: Set[str] = {
"subgraph",
"terminal_node",
}
SPECIAL_OBJECT_MAPPING: Dict[Any, str] = {
np.float16: "float16",
np.float32: "float32",
np.float64: "float64",
np.int8: "int8",
np.int16: "int16",
np.int32: "int32",
np.int64: "int64",
np.uint8: "uint8",
np.uint16: "uint16",
np.uint32: "uint32",
np.uint64: "uint64",
np.byte: "byte",
np.short: "short",
np.intc: "intc",
np.int_: "int_",
np.longlong: "longlong",
np.ubyte: "ubyte",
np.ushort: "ushort",
np.uintc: "uintc",
np.uint: "uint",
np.ulonglong: "ulonglong",
}
def format_constant(constant: Any, maximum_length: int = 45, keep_newlines: bool = False) -> str:
"""
Get the textual representation of a constant.
Args:
constant (Any):
constant to format
maximum_length (int, default = 45):
maximum length of the resulting string
keep_newlines (bool, default = False):
whether to keep newlines or not
Returns:
str:
textual representation of `constant`
"""
if isinstance(constant, Hashable) and constant in SPECIAL_OBJECT_MAPPING:
return SPECIAL_OBJECT_MAPPING[constant]
# maximum_length should not be smaller than 7 characters because
# the constant will be formatted to `x ... y`
# where x and y are part of the constant, and they are at least 1 character
assert_that(maximum_length >= 7)
result = str(constant)
if not keep_newlines:
result = result.replace("\n", "")
if len(result) > maximum_length:
from_start = (maximum_length - 5) // 2
from_end = (maximum_length - 5) - from_start
if keep_newlines and "\n" in result:
result = f"{result[:from_start]}\n...\n{result[-from_end:]}"
else:
result = f"{result[:from_start]} ... {result[-from_end:]}"
return result
def format_indexing_element(indexing_element: Union[int, np.integer, slice]):
"""
Format an indexing element.
This is required mainly for slices. The reason is that string representation of slices
are very long and verbose. To give an example, `x[:, 2:]` will have the following index
`[slice(None, None, None), slice(2, None, None)]` if printed naively. With this helper,
it will be formatted as `[:, 2:]`.
Args:
indexing_element (Union[int, np.integer, slice]):
indexing element to format
Returns:
str:
textual representation of `indexing_element`
"""
result = ""
if isinstance(indexing_element, slice):
if indexing_element.start is not None:
result += str(indexing_element.start)
result += ":"
if indexing_element.stop is not None:
result += str(indexing_element.stop)
if indexing_element.step is not None:
result += ":"
result += str(indexing_element.step)
else:
result += str(indexing_element)
return result.replace("\n", " ")