refactor(drawing): start using graphviz for visualization

This commit is contained in:
Umut
2021-08-20 17:09:19 +03:00
parent 2585eb7ed8
commit b41029d9c0
12 changed files with 186 additions and 248 deletions

View File

@@ -24,6 +24,8 @@ jobs:
steps:
- name: Install Git
run: apt-get install git -y
- name: Install Graphviz
run: apt-get install graphviz* -y
- name: Checkout Code
uses: actions/checkout@v2
with:

View File

@@ -1,6 +1,7 @@
FROM ghcr.io/zama-ai/zamalang-compiler
RUN apt-get install --no-install-recommends -y python3.8 python3.8-venv python-is-python3 git && \
RUN apt-get install --no-install-recommends -y \
python3.8 python3.8-venv python-is-python3 git graphviz* && \
pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir poetry && \
echo "source /hdk/.docker_venv/bin/activate" >> /root/.bashrc && \

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -72,9 +72,8 @@ class CompilationArtifacts:
draw_graph(
self.operation_graph,
show=False,
save_to=output_directory.joinpath("graph.png"),
block_until_user_closes_graph=False,
draw_edge_numbers=True,
)
if self.bounds is not None:

View File

@@ -1,10 +1,12 @@
"""functions to draw the different graphs we can generate in the package, eg to debug."""
import tempfile
from pathlib import Path
from typing import Dict, List, Optional
from typing import Optional
import matplotlib.pyplot as plt
import networkx as nx
from PIL import Image
from ..operator_graph import OPGraph
from ..representation import intermediate as ir
@@ -22,222 +24,71 @@ IR_NODE_COLOR_MAPPING = {
}
def human_readable_layout(graph: nx.Graph, x_delta: float = 1.0, y_delta: float = 1.0) -> Dict:
"""Returns positions for graphs, to make them easy to read.
Returns a pos to be used later with eg nx.draw_networkx_nodes, so that nodes
are ordered by depth from input along the x axis and have a uniform
distribution along the y axis
Args:
graph (nx.Graph): The graph that we want to draw
x_delta (float): Parameter used to set the increment in x
y_delta (float): Parameter used to set the increment in y
Returns:
pos (Dict): the argument to use with eg nx.draw_networkx_nodes
"""
nodes_depth = {node: 0 for node in graph.nodes()}
input_nodes = [node for node in graph.nodes() if len(list(graph.predecessors(node))) == 0]
# Init a layout so that unreachable nodes have a pos, avoids potential crashes wiht networkx
# use a cheap layout
pos = nx.random_layout(graph)
curr_x = 0.0
curr_y = -(len(input_nodes) - 1) / 2 * y_delta
for in_node in input_nodes:
pos[in_node] = (curr_x, curr_y)
curr_y += y_delta
curr_x += x_delta
curr_nodes = input_nodes
current_depth = 0
while len(curr_nodes) > 0:
current_depth += 1
next_nodes_set = set()
for node in curr_nodes:
next_nodes_set.update(graph.successors(node))
curr_nodes = list(next_nodes_set)
for node in curr_nodes:
nodes_depth[node] = current_depth
nodes_by_depth: Dict[int, List[int]] = {}
for node, depth in nodes_depth.items():
nodes_for_depth = nodes_by_depth.get(depth, [])
nodes_for_depth.append(node)
nodes_by_depth[depth] = nodes_for_depth
depths = sorted(nodes_by_depth.keys())
for depth in depths:
nodes_at_depth = nodes_by_depth[depth]
curr_y = -(len(nodes_at_depth) - 1) / 2 * y_delta
for node in nodes_at_depth:
pos[node] = (curr_x, curr_y)
curr_y += y_delta
curr_x += x_delta
return pos
def adjust_limits():
"""Increases the limits of x and y axis of the current pyplot figure by 20%.
Returns:
None
"""
x_lim = plt.xlim()
x_distance = x_lim[1] - x_lim[0]
plt.xlim([x_lim[0] - x_distance / 10, x_lim[1] + x_distance / 10])
y_lim = plt.ylim()
y_distance = y_lim[1] - y_lim[0]
plt.ylim([y_lim[0] - y_distance / 10, y_lim[1] + y_distance / 10])
def draw_graph(
opgraph: OPGraph,
block_until_user_closes_graph: bool = True,
draw_edge_numbers: bool = True,
show: bool = False,
vertical: bool = True,
save_to: Optional[Path] = None,
) -> None:
"""Draw a graph.
) -> Image.Image:
"""Draws operation graphs and optionally saves/shows the drawing.
Args:
opgraph (OPGraph): The graph that we want to draw
block_until_user_closes_graph (bool): if True, will wait the user to
close the figure before continuing; False is useful for the CI tests
draw_edge_numbers (bool): if True, add the edge number on the arrow
linking nodes, eg to differentiate the x and y in a Sub coding
(x - y). This option is not that useful for commutative ops, and
may make the picture a bit too dense, so could be deactivated
save_to (Optional[Path]): if specified, the drawn graph will be saved
to this path
opgraph (OPGraph): the graph to be drawn and optionally saved/shown
show (bool): if set to True, the drawing will be shown using matplotlib
vertical (bool): if set to True, the orientation will be vertical
save_to (Optional[Path]): if specified, the drawn graph will be saved to this path
Returns:
None
Pillow Image of the drawn graph.
This is useful because you can use the drawing however you like.
(check https://pillow.readthedocs.io/en/stable/reference/Image.html for further information)
"""
assert isinstance(opgraph, OPGraph)
set_of_nodes_which_are_outputs = set(opgraph.output_nodes.values())
graph = opgraph.graph
# Positions of the node
pos = human_readable_layout(graph)
# Colors and labels
def get_color(node):
def get_color(node, output_nodes):
value_to_return = IR_NODE_COLOR_MAPPING[type(node)]
if node in set_of_nodes_which_are_outputs:
if node in output_nodes:
value_to_return = IR_NODE_COLOR_MAPPING["output"]
elif isinstance(node, ir.ArbitraryFunction):
value_to_return = IR_NODE_COLOR_MAPPING.get(node.op_name, value_to_return)
return value_to_return
color_map = [get_color(node) for node in graph.nodes()]
graph = opgraph.graph
output_nodes = set(opgraph.output_nodes.values())
# For most types, we just pick the operation as the label, but for Input,
# we take the name of the variable, ie the argument name of the function
# to compile
def get_proper_name(node):
if isinstance(node, ir.Input):
return node.input_name
if isinstance(node, ir.Constant):
return str(node.constant_data)
if isinstance(node, ir.ArbitraryFunction):
return node.op_name
return node.__class__.__name__
attributes = {
node: {
"label": node.label(),
"color": get_color(node, output_nodes),
"penwidth": 2, # double thickness for circles
"peripheries": 2 if node in output_nodes else 1, # double circle for output nodes
}
for node in graph.nodes
}
nx.set_node_attributes(graph, attributes)
label_dict = {node: get_proper_name(node) for node in graph.nodes()}
for edge in graph.edges(keys=True):
idx = graph.edges[edge]["input_idx"]
graph.edges[edge]["label"] = f" {idx} " # spaces are there intentionally for a better look
# Draw nodes
nx.draw_networkx_nodes(
graph,
pos,
node_color=color_map,
node_size=1000,
alpha=1,
)
agraph = nx.nx_agraph.to_agraph(graph)
agraph.graph_attr["rankdir"] = "TB" if vertical else "LR"
agraph.layout("dot")
# Draw labels
nx.draw_networkx_labels(graph, pos, labels=label_dict)
if save_to is None:
with tempfile.NamedTemporaryFile(suffix=".png") as tmp:
agraph.draw(tmp.name)
img = Image.open(tmp.name)
else:
agraph.draw(save_to)
img = Image.open(save_to)
current_axes = plt.gca()
if show: # pragma: no cover
# We can't have coverage in this branch as `plt.show()` blocks and waits for user action.
plt.close("all")
plt.figure()
plt.imshow(img)
plt.axis("off")
plt.show()
# And draw edges in a way which works when we have two "equivalent edges",
# ie from the same node A to the same node B, like to represent y = x + x
already_done = set()
for e in graph.edges:
# If we already drew the different edges from e[0] to e[1], continue
if (e[0], e[1]) in already_done:
continue
already_done.add((e[0], e[1]))
edges = graph.get_edge_data(e[0], e[1])
# Draw the different edges from e[0] to e[1], continue
for which, edge in enumerate(edges.values()):
edge_index = edge["input_idx"]
# Draw the edge
current_axes.annotate(
"",
xy=pos[e[0]],
xycoords="data",
xytext=pos[e[1]],
textcoords="data",
arrowprops=dict(
arrowstyle="<-",
color="0.5",
shrinkA=5,
shrinkB=5,
patchA=None,
patchB=None,
connectionstyle="arc3,rad=rrr".replace("rrr", str(0.3 * which)),
),
)
if draw_edge_numbers:
# Print the number of the node on the edge. This is a bit artisanal,
# since it seems not possible to add the text directly on the
# previously drawn arrow. So, more or less, we try to put a text at
# a position which is close to pos[e[1]] and which varies a bit with
# 'which'
a, b = pos[e[0]]
c, d = pos[e[1]]
const_0 = 1
const_1 = 2
current_axes.annotate(
str(edge_index),
xycoords="data",
xy=(
(const_0 * a + const_1 * c) / (const_0 + const_1),
(const_0 * b + const_1 * d + 0.1 * which) / (const_0 + const_1),
),
textcoords="data",
)
plt.axis("off")
adjust_limits()
# save the figure if requested
if save_to is not None:
plt.savefig(save_to)
# block_until_user_closes_graph is used as True for real users and False
# for CI
plt.show(block=block_until_user_closes_graph)
return img

View File

@@ -106,6 +106,15 @@ class IntermediateNode(ABC):
"""
return cls.n_in() > 1
@abstractmethod
def label(self) -> str:
"""Function to get the label of the node.
Returns:
str: the label of the node
"""
class Add(IntermediateNode):
"""Addition between two values."""
@@ -118,6 +127,9 @@ class Add(IntermediateNode):
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return inputs[0] + inputs[1]
def label(self) -> str:
return "+"
class Sub(IntermediateNode):
"""Subtraction between two values."""
@@ -130,6 +142,9 @@ class Sub(IntermediateNode):
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return inputs[0] - inputs[1]
def label(self) -> str:
return "-"
class Mul(IntermediateNode):
"""Multiplication between two values."""
@@ -142,6 +157,9 @@ class Mul(IntermediateNode):
def evaluate(self, inputs: Dict[int, Any]) -> Any:
return inputs[0] * inputs[1]
def label(self) -> str:
return "*"
class Input(IntermediateNode):
"""Node representing an input of the program."""
@@ -173,6 +191,9 @@ class Input(IntermediateNode):
and super().is_equivalent_to(other)
)
def label(self) -> str:
return self.input_name
class Constant(IntermediateNode):
"""Node representing a constant of the program."""
@@ -213,6 +234,9 @@ class Constant(IntermediateNode):
"""
return self._constant_data
def label(self) -> str:
return str(self.constant_data)
class ArbitraryFunction(IntermediateNode):
"""Node representing a univariate arbitrary function, e.g. sin(x)."""
@@ -257,3 +281,6 @@ class ArbitraryFunction(IntermediateNode):
and self.op_name == other.op_name
and super().is_equivalent_to(other)
)
def label(self) -> str:
return self.op_name

View File

@@ -129,7 +129,7 @@ class NPTracer(BaseTracer):
arbitrary_func=numpy.rint,
output_dtype=common_output_dtypes[0],
op_kwargs=deepcopy(kwargs),
op_name="numpy.rint",
op_name="np.rint",
)
output_tracer = self.__class__(
input_tracers, traced_computation=traced_computation, output_index=0
@@ -151,7 +151,7 @@ class NPTracer(BaseTracer):
arbitrary_func=numpy.sin,
output_dtype=common_output_dtypes[0],
op_kwargs=deepcopy(kwargs),
op_name="numpy.sin",
op_name="np.sin",
)
output_tracer = self.__class__(
input_tracers, traced_computation=traced_computation, output_index=0

44
poetry.lock generated
View File

@@ -229,7 +229,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
[[package]]
name = "diff-cover"
version = "6.3.1"
version = "6.3.3"
description = "Run coverage and linting reports on diffs"
category = "dev"
optional = false
@@ -511,13 +511,14 @@ qtconsole = "*"
[[package]]
name = "jupyter-client"
version = "6.2.0"
version = "7.0.1"
description = "Jupyter protocol implementation and client libraries"
category = "dev"
optional = false
python-versions = ">=3.6.1"
[package.dependencies]
entrypoints = "*"
jupyter-core = ">=4.6.0"
nest-asyncio = ">=1.5"
python-dateutil = ">=2.1"
@@ -526,8 +527,8 @@ tornado = ">=4.1"
traitlets = "*"
[package.extras]
doc = ["sphinx (>=1.3.6)", "sphinx-rtd-theme", "sphinxcontrib-github-alt"]
test = ["async-generator", "ipykernel", "ipython", "mock", "pytest-asyncio", "pytest-timeout", "pytest", "mypy", "pre-commit", "jedi (<0.18)"]
doc = ["myst-parser", "sphinx (>=1.3.6)", "sphinx-rtd-theme", "sphinxcontrib-github-alt"]
test = ["codecov", "coverage", "ipykernel", "ipython", "mock", "mypy", "pre-commit", "pytest", "pytest-asyncio", "pytest-cov", "pytest-timeout", "jedi (<0.18)"]
[[package]]
name = "jupyter-console"
@@ -974,11 +975,11 @@ twisted = ["twisted"]
[[package]]
name = "prompt-toolkit"
version = "3.0.19"
version = "3.0.20"
description = "Library for building powerful interactive command lines in Python"
category = "dev"
optional = false
python-versions = ">=3.6.1"
python-versions = ">=3.6.2"
[package.dependencies]
wcwidth = "*"
@@ -1068,6 +1069,14 @@ category = "dev"
optional = false
python-versions = ">=3.5"
[[package]]
name = "pygraphviz"
version = "1.7"
description = "Python interface to Graphviz"
category = "main"
optional = false
python-versions = ">=3.7"
[[package]]
name = "pylint"
version = "2.9.6"
@@ -1413,7 +1422,7 @@ test = ["pytest"]
[[package]]
name = "terminado"
version = "0.11.0"
version = "0.11.1"
description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library."
category = "dev"
optional = false
@@ -1555,7 +1564,7 @@ testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytes
[metadata]
lock-version = "1.1"
python-versions = ">=3.7,<3.10"
content-hash = "65489a7f8c03f8825d0948ffec2ef5809e53d82e5b9eb3c77d5fa512c175d0fd"
content-hash = "382f9225cb89e407123521c4a9ec5aedc021701f7af8ef79e2959ca5996a6be1"
[metadata.files]
alabaster = [
@@ -1756,8 +1765,8 @@ defusedxml = [
{file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"},
]
diff-cover = [
{file = "diff_cover-6.3.1-py3-none-any.whl", hash = "sha256:2578fb51c4a5ce162d9ba7f5dcc28132b55539c889e9b648f9df77d4fdcf8fb4"},
{file = "diff_cover-6.3.1.tar.gz", hash = "sha256:21baf9d6f40ef352df4adf19b5bb4d47249c540a648fecda4647a41ff558d47c"},
{file = "diff_cover-6.3.3-py3-none-any.whl", hash = "sha256:4aaffc7051dd6b0e4e39170d2a69f412a21bbbf8497c85654a8d0c1fd44be534"},
{file = "diff_cover-6.3.3.tar.gz", hash = "sha256:487b9babf6d1a7d73b9f72c2ee4cbed2840bf2f0e203e184b9ef632532115665"},
]
docutils = [
{file = "docutils-0.16-py2.py3-none-any.whl", hash = "sha256:0c5b78adfbf7762415433f5515cd5c9e762339e23369dbe8000d84a4bf4ab3af"},
@@ -1837,8 +1846,8 @@ jupyter = [
{file = "jupyter-1.0.0.zip", hash = "sha256:3e1f86076bbb7c8c207829390305a2b1fe836d471ed54be66a3b8c41e7f46cc7"},
]
jupyter-client = [
{file = "jupyter_client-6.2.0-py3-none-any.whl", hash = "sha256:9715152067e3f7ea3b56f341c9a0f9715c8c7cc316ee0eb13c3c84f5ca0065f5"},
{file = "jupyter_client-6.2.0.tar.gz", hash = "sha256:e2ab61d79fbf8b56734a4c2499f19830fbd7f6fefb3e87868ef0545cb3c17eb9"},
{file = "jupyter_client-7.0.1-py3-none-any.whl", hash = "sha256:07b9566979546004c089afe7c9bf9e96224ec5f8421fe0ae460759fa593c6b1d"},
{file = "jupyter_client-7.0.1.tar.gz", hash = "sha256:48822a93d9d75daa5fde235c35cf7a92fc979384735962501d4eb60b197fb43a"},
]
jupyter-console = [
{file = "jupyter_console-6.4.0-py3-none-any.whl", hash = "sha256:7799c4ea951e0e96ba8260575423cb323ea5a03fcf5503560fa3e15748869e27"},
@@ -2182,8 +2191,8 @@ prometheus-client = [
{file = "prometheus_client-0.11.0.tar.gz", hash = "sha256:3a8baade6cb80bcfe43297e33e7623f3118d660d41387593758e2fb1ea173a86"},
]
prompt-toolkit = [
{file = "prompt_toolkit-3.0.19-py3-none-any.whl", hash = "sha256:7089d8d2938043508aa9420ec18ce0922885304cddae87fb96eebca942299f88"},
{file = "prompt_toolkit-3.0.19.tar.gz", hash = "sha256:08360ee3a3148bdb5163621709ee322ec34fc4375099afa4bbf751e9b7b7fa4f"},
{file = "prompt_toolkit-3.0.20-py3-none-any.whl", hash = "sha256:6076e46efae19b1e0ca1ec003ed37a933dc94b4d20f486235d436e64771dcd5c"},
{file = "prompt_toolkit-3.0.20.tar.gz", hash = "sha256:eb71d5a6b72ce6db177af4a7d4d7085b99756bf656d98ffcc4fecd36850eea6c"},
]
ptyprocess = [
{file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"},
@@ -2240,6 +2249,9 @@ pygments = [
{file = "Pygments-2.10.0-py3-none-any.whl", hash = "sha256:b8e67fe6af78f492b3c4b3e2970c0624cbf08beb1e493b2c99b9fa1b67a20380"},
{file = "Pygments-2.10.0.tar.gz", hash = "sha256:f398865f7eb6874156579fdf36bc840a03cab64d1cde9e93d68f46a425ec52c6"},
]
pygraphviz = [
{file = "pygraphviz-1.7.zip", hash = "sha256:a7bec6609f37cf1e64898c59f075afd659106cf9356c5f387cecaa2e0cdb2304"},
]
pylint = [
{file = "pylint-2.9.6-py3-none-any.whl", hash = "sha256:2e1a0eb2e8ab41d6b5dbada87f066492bb1557b12b76c47c2ee8aa8a11186594"},
{file = "pylint-2.9.6.tar.gz", hash = "sha256:8b838c8983ee1904b2de66cce9d0b96649a91901350e956d78f289c3bc87b48e"},
@@ -2472,8 +2484,8 @@ sphinxcontrib-serializinghtml = [
{file = "sphinxcontrib_serializinghtml-1.1.5-py2.py3-none-any.whl", hash = "sha256:352a9a00ae864471d3a7ead8d7d79f5fc0b57e8b3f95e9867eb9eb28999b92fd"},
]
terminado = [
{file = "terminado-0.11.0-py3-none-any.whl", hash = "sha256:221eef83e6a504894842f7dccfa971ca2e98ec22a8a9118577e5257527674b42"},
{file = "terminado-0.11.0.tar.gz", hash = "sha256:1e01183885f64c1bba3cf89a5a995ad4acfed4e5f00aebcce1bf7f089b0825a1"},
{file = "terminado-0.11.1-py3-none-any.whl", hash = "sha256:9e0457334863be3e6060c487ad60e0995fa1df54f109c67b24ff49a4f2f34df5"},
{file = "terminado-0.11.1.tar.gz", hash = "sha256:962b402edbb480718054dc37027bada293972ecadfb587b89f01e2b8660a2132"},
]
testpath = [
{file = "testpath-0.5.0-py3-none-any.whl", hash = "sha256:8044f9a0bab6567fc644a3593164e872543bb44225b0e24846e2c89237937589"},

View File

@@ -9,6 +9,8 @@ python = ">=3.7,<3.10"
networkx = "^2.6.1"
matplotlib = "^3.4.2"
numpy = "^1.21.1"
pygraphviz = "^1.7"
Pillow = "^8.3.1"
[tool.poetry.dev-dependencies]
isort = "^5.9.2"

View File

@@ -61,7 +61,7 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
# TODO: For the moment, we don't have really checks, but some printfs. Later,
# when we have the converter, we can check the MLIR
draw_graph(op_graph, block_until_user_closes_graph=False)
draw_graph(op_graph, show=False)
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
print(f"\n{str_of_the_graph}\n")

View File

@@ -137,7 +137,7 @@ def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
x, y = x_y
graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y})
draw_graph(graph, block_until_user_closes_graph=False)
draw_graph(graph, show=False)
str_of_the_graph = get_printable_graph(graph)
@@ -167,7 +167,7 @@ def test_hnumpy_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph
"Test hnumpy get_printable_graph and draw_graph on graphs with direct table lookup"
graph = tracing.trace_numpy_function(lambda_f, params)
draw_graph(graph, block_until_user_closes_graph=False)
draw_graph(graph, show=False)
str_of_the_graph = get_printable_graph(graph)
@@ -257,7 +257,7 @@ def test_hnumpy_print_with_show_data_types_with_direct_tlu(lambda_f, params, ref
"""Test hnumpy get_printable_graph with show_data_types on graphs with direct table lookup"""
graph = tracing.trace_numpy_function(lambda_f, params)
draw_graph(graph, block_until_user_closes_graph=False)
draw_graph(graph, show=False)
str_of_the_graph = get_printable_graph(graph, show_data_types=True)