mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
refactor(drawing): start using graphviz for visualization
This commit is contained in:
@@ -24,6 +24,8 @@ jobs:
|
||||
steps:
|
||||
- name: Install Git
|
||||
run: apt-get install git -y
|
||||
- name: Install Graphviz
|
||||
run: apt-get install graphviz* -y
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
FROM ghcr.io/zama-ai/zamalang-compiler
|
||||
|
||||
RUN apt-get install --no-install-recommends -y python3.8 python3.8-venv python-is-python3 git && \
|
||||
RUN apt-get install --no-install-recommends -y \
|
||||
python3.8 python3.8-venv python-is-python3 git graphviz* && \
|
||||
pip install --no-cache-dir --upgrade pip && \
|
||||
pip install --no-cache-dir poetry && \
|
||||
echo "source /hdk/.docker_venv/bin/activate" >> /root/.bashrc && \
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -72,9 +72,8 @@ class CompilationArtifacts:
|
||||
|
||||
draw_graph(
|
||||
self.operation_graph,
|
||||
show=False,
|
||||
save_to=output_directory.joinpath("graph.png"),
|
||||
block_until_user_closes_graph=False,
|
||||
draw_edge_numbers=True,
|
||||
)
|
||||
|
||||
if self.bounds is not None:
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
"""functions to draw the different graphs we can generate in the package, eg to debug."""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
from PIL import Image
|
||||
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation import intermediate as ir
|
||||
@@ -22,222 +24,71 @@ IR_NODE_COLOR_MAPPING = {
|
||||
}
|
||||
|
||||
|
||||
def human_readable_layout(graph: nx.Graph, x_delta: float = 1.0, y_delta: float = 1.0) -> Dict:
|
||||
"""Returns positions for graphs, to make them easy to read.
|
||||
|
||||
Returns a pos to be used later with eg nx.draw_networkx_nodes, so that nodes
|
||||
are ordered by depth from input along the x axis and have a uniform
|
||||
distribution along the y axis
|
||||
|
||||
Args:
|
||||
graph (nx.Graph): The graph that we want to draw
|
||||
x_delta (float): Parameter used to set the increment in x
|
||||
y_delta (float): Parameter used to set the increment in y
|
||||
|
||||
Returns:
|
||||
pos (Dict): the argument to use with eg nx.draw_networkx_nodes
|
||||
|
||||
"""
|
||||
nodes_depth = {node: 0 for node in graph.nodes()}
|
||||
input_nodes = [node for node in graph.nodes() if len(list(graph.predecessors(node))) == 0]
|
||||
|
||||
# Init a layout so that unreachable nodes have a pos, avoids potential crashes wiht networkx
|
||||
# use a cheap layout
|
||||
pos = nx.random_layout(graph)
|
||||
|
||||
curr_x = 0.0
|
||||
curr_y = -(len(input_nodes) - 1) / 2 * y_delta
|
||||
|
||||
for in_node in input_nodes:
|
||||
pos[in_node] = (curr_x, curr_y)
|
||||
curr_y += y_delta
|
||||
|
||||
curr_x += x_delta
|
||||
|
||||
curr_nodes = input_nodes
|
||||
|
||||
current_depth = 0
|
||||
while len(curr_nodes) > 0:
|
||||
current_depth += 1
|
||||
next_nodes_set = set()
|
||||
for node in curr_nodes:
|
||||
next_nodes_set.update(graph.successors(node))
|
||||
|
||||
curr_nodes = list(next_nodes_set)
|
||||
for node in curr_nodes:
|
||||
nodes_depth[node] = current_depth
|
||||
|
||||
nodes_by_depth: Dict[int, List[int]] = {}
|
||||
for node, depth in nodes_depth.items():
|
||||
nodes_for_depth = nodes_by_depth.get(depth, [])
|
||||
nodes_for_depth.append(node)
|
||||
nodes_by_depth[depth] = nodes_for_depth
|
||||
|
||||
depths = sorted(nodes_by_depth.keys())
|
||||
|
||||
for depth in depths:
|
||||
nodes_at_depth = nodes_by_depth[depth]
|
||||
|
||||
curr_y = -(len(nodes_at_depth) - 1) / 2 * y_delta
|
||||
for node in nodes_at_depth:
|
||||
pos[node] = (curr_x, curr_y)
|
||||
curr_y += y_delta
|
||||
|
||||
curr_x += x_delta
|
||||
|
||||
return pos
|
||||
|
||||
|
||||
def adjust_limits():
|
||||
"""Increases the limits of x and y axis of the current pyplot figure by 20%.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
x_lim = plt.xlim()
|
||||
x_distance = x_lim[1] - x_lim[0]
|
||||
plt.xlim([x_lim[0] - x_distance / 10, x_lim[1] + x_distance / 10])
|
||||
|
||||
y_lim = plt.ylim()
|
||||
y_distance = y_lim[1] - y_lim[0]
|
||||
plt.ylim([y_lim[0] - y_distance / 10, y_lim[1] + y_distance / 10])
|
||||
|
||||
|
||||
def draw_graph(
|
||||
opgraph: OPGraph,
|
||||
block_until_user_closes_graph: bool = True,
|
||||
draw_edge_numbers: bool = True,
|
||||
show: bool = False,
|
||||
vertical: bool = True,
|
||||
save_to: Optional[Path] = None,
|
||||
) -> None:
|
||||
"""Draw a graph.
|
||||
) -> Image.Image:
|
||||
"""Draws operation graphs and optionally saves/shows the drawing.
|
||||
|
||||
Args:
|
||||
opgraph (OPGraph): The graph that we want to draw
|
||||
block_until_user_closes_graph (bool): if True, will wait the user to
|
||||
close the figure before continuing; False is useful for the CI tests
|
||||
draw_edge_numbers (bool): if True, add the edge number on the arrow
|
||||
linking nodes, eg to differentiate the x and y in a Sub coding
|
||||
(x - y). This option is not that useful for commutative ops, and
|
||||
may make the picture a bit too dense, so could be deactivated
|
||||
save_to (Optional[Path]): if specified, the drawn graph will be saved
|
||||
to this path
|
||||
opgraph (OPGraph): the graph to be drawn and optionally saved/shown
|
||||
show (bool): if set to True, the drawing will be shown using matplotlib
|
||||
vertical (bool): if set to True, the orientation will be vertical
|
||||
save_to (Optional[Path]): if specified, the drawn graph will be saved to this path
|
||||
|
||||
Returns:
|
||||
None
|
||||
Pillow Image of the drawn graph.
|
||||
This is useful because you can use the drawing however you like.
|
||||
(check https://pillow.readthedocs.io/en/stable/reference/Image.html for further information)
|
||||
|
||||
"""
|
||||
assert isinstance(opgraph, OPGraph)
|
||||
set_of_nodes_which_are_outputs = set(opgraph.output_nodes.values())
|
||||
graph = opgraph.graph
|
||||
|
||||
# Positions of the node
|
||||
pos = human_readable_layout(graph)
|
||||
|
||||
# Colors and labels
|
||||
def get_color(node):
|
||||
def get_color(node, output_nodes):
|
||||
value_to_return = IR_NODE_COLOR_MAPPING[type(node)]
|
||||
if node in set_of_nodes_which_are_outputs:
|
||||
if node in output_nodes:
|
||||
value_to_return = IR_NODE_COLOR_MAPPING["output"]
|
||||
elif isinstance(node, ir.ArbitraryFunction):
|
||||
value_to_return = IR_NODE_COLOR_MAPPING.get(node.op_name, value_to_return)
|
||||
return value_to_return
|
||||
|
||||
color_map = [get_color(node) for node in graph.nodes()]
|
||||
graph = opgraph.graph
|
||||
output_nodes = set(opgraph.output_nodes.values())
|
||||
|
||||
# For most types, we just pick the operation as the label, but for Input,
|
||||
# we take the name of the variable, ie the argument name of the function
|
||||
# to compile
|
||||
def get_proper_name(node):
|
||||
if isinstance(node, ir.Input):
|
||||
return node.input_name
|
||||
if isinstance(node, ir.Constant):
|
||||
return str(node.constant_data)
|
||||
if isinstance(node, ir.ArbitraryFunction):
|
||||
return node.op_name
|
||||
return node.__class__.__name__
|
||||
attributes = {
|
||||
node: {
|
||||
"label": node.label(),
|
||||
"color": get_color(node, output_nodes),
|
||||
"penwidth": 2, # double thickness for circles
|
||||
"peripheries": 2 if node in output_nodes else 1, # double circle for output nodes
|
||||
}
|
||||
for node in graph.nodes
|
||||
}
|
||||
nx.set_node_attributes(graph, attributes)
|
||||
|
||||
label_dict = {node: get_proper_name(node) for node in graph.nodes()}
|
||||
for edge in graph.edges(keys=True):
|
||||
idx = graph.edges[edge]["input_idx"]
|
||||
graph.edges[edge]["label"] = f" {idx} " # spaces are there intentionally for a better look
|
||||
|
||||
# Draw nodes
|
||||
nx.draw_networkx_nodes(
|
||||
graph,
|
||||
pos,
|
||||
node_color=color_map,
|
||||
node_size=1000,
|
||||
alpha=1,
|
||||
)
|
||||
agraph = nx.nx_agraph.to_agraph(graph)
|
||||
agraph.graph_attr["rankdir"] = "TB" if vertical else "LR"
|
||||
agraph.layout("dot")
|
||||
|
||||
# Draw labels
|
||||
nx.draw_networkx_labels(graph, pos, labels=label_dict)
|
||||
if save_to is None:
|
||||
with tempfile.NamedTemporaryFile(suffix=".png") as tmp:
|
||||
agraph.draw(tmp.name)
|
||||
img = Image.open(tmp.name)
|
||||
else:
|
||||
agraph.draw(save_to)
|
||||
img = Image.open(save_to)
|
||||
|
||||
current_axes = plt.gca()
|
||||
if show: # pragma: no cover
|
||||
# We can't have coverage in this branch as `plt.show()` blocks and waits for user action.
|
||||
plt.close("all")
|
||||
plt.figure()
|
||||
plt.imshow(img)
|
||||
plt.axis("off")
|
||||
plt.show()
|
||||
|
||||
# And draw edges in a way which works when we have two "equivalent edges",
|
||||
# ie from the same node A to the same node B, like to represent y = x + x
|
||||
already_done = set()
|
||||
|
||||
for e in graph.edges:
|
||||
|
||||
# If we already drew the different edges from e[0] to e[1], continue
|
||||
if (e[0], e[1]) in already_done:
|
||||
continue
|
||||
|
||||
already_done.add((e[0], e[1]))
|
||||
|
||||
edges = graph.get_edge_data(e[0], e[1])
|
||||
|
||||
# Draw the different edges from e[0] to e[1], continue
|
||||
for which, edge in enumerate(edges.values()):
|
||||
edge_index = edge["input_idx"]
|
||||
|
||||
# Draw the edge
|
||||
current_axes.annotate(
|
||||
"",
|
||||
xy=pos[e[0]],
|
||||
xycoords="data",
|
||||
xytext=pos[e[1]],
|
||||
textcoords="data",
|
||||
arrowprops=dict(
|
||||
arrowstyle="<-",
|
||||
color="0.5",
|
||||
shrinkA=5,
|
||||
shrinkB=5,
|
||||
patchA=None,
|
||||
patchB=None,
|
||||
connectionstyle="arc3,rad=rrr".replace("rrr", str(0.3 * which)),
|
||||
),
|
||||
)
|
||||
|
||||
if draw_edge_numbers:
|
||||
# Print the number of the node on the edge. This is a bit artisanal,
|
||||
# since it seems not possible to add the text directly on the
|
||||
# previously drawn arrow. So, more or less, we try to put a text at
|
||||
# a position which is close to pos[e[1]] and which varies a bit with
|
||||
# 'which'
|
||||
a, b = pos[e[0]]
|
||||
c, d = pos[e[1]]
|
||||
const_0 = 1
|
||||
const_1 = 2
|
||||
|
||||
current_axes.annotate(
|
||||
str(edge_index),
|
||||
xycoords="data",
|
||||
xy=(
|
||||
(const_0 * a + const_1 * c) / (const_0 + const_1),
|
||||
(const_0 * b + const_1 * d + 0.1 * which) / (const_0 + const_1),
|
||||
),
|
||||
textcoords="data",
|
||||
)
|
||||
|
||||
plt.axis("off")
|
||||
|
||||
adjust_limits()
|
||||
|
||||
# save the figure if requested
|
||||
if save_to is not None:
|
||||
plt.savefig(save_to)
|
||||
|
||||
# block_until_user_closes_graph is used as True for real users and False
|
||||
# for CI
|
||||
plt.show(block=block_until_user_closes_graph)
|
||||
return img
|
||||
|
||||
@@ -106,6 +106,15 @@ class IntermediateNode(ABC):
|
||||
"""
|
||||
return cls.n_in() > 1
|
||||
|
||||
@abstractmethod
|
||||
def label(self) -> str:
|
||||
"""Function to get the label of the node.
|
||||
|
||||
Returns:
|
||||
str: the label of the node
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class Add(IntermediateNode):
|
||||
"""Addition between two values."""
|
||||
@@ -118,6 +127,9 @@ class Add(IntermediateNode):
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0] + inputs[1]
|
||||
|
||||
def label(self) -> str:
|
||||
return "+"
|
||||
|
||||
|
||||
class Sub(IntermediateNode):
|
||||
"""Subtraction between two values."""
|
||||
@@ -130,6 +142,9 @@ class Sub(IntermediateNode):
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0] - inputs[1]
|
||||
|
||||
def label(self) -> str:
|
||||
return "-"
|
||||
|
||||
|
||||
class Mul(IntermediateNode):
|
||||
"""Multiplication between two values."""
|
||||
@@ -142,6 +157,9 @@ class Mul(IntermediateNode):
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Any:
|
||||
return inputs[0] * inputs[1]
|
||||
|
||||
def label(self) -> str:
|
||||
return "*"
|
||||
|
||||
|
||||
class Input(IntermediateNode):
|
||||
"""Node representing an input of the program."""
|
||||
@@ -173,6 +191,9 @@ class Input(IntermediateNode):
|
||||
and super().is_equivalent_to(other)
|
||||
)
|
||||
|
||||
def label(self) -> str:
|
||||
return self.input_name
|
||||
|
||||
|
||||
class Constant(IntermediateNode):
|
||||
"""Node representing a constant of the program."""
|
||||
@@ -213,6 +234,9 @@ class Constant(IntermediateNode):
|
||||
"""
|
||||
return self._constant_data
|
||||
|
||||
def label(self) -> str:
|
||||
return str(self.constant_data)
|
||||
|
||||
|
||||
class ArbitraryFunction(IntermediateNode):
|
||||
"""Node representing a univariate arbitrary function, e.g. sin(x)."""
|
||||
@@ -257,3 +281,6 @@ class ArbitraryFunction(IntermediateNode):
|
||||
and self.op_name == other.op_name
|
||||
and super().is_equivalent_to(other)
|
||||
)
|
||||
|
||||
def label(self) -> str:
|
||||
return self.op_name
|
||||
|
||||
@@ -129,7 +129,7 @@ class NPTracer(BaseTracer):
|
||||
arbitrary_func=numpy.rint,
|
||||
output_dtype=common_output_dtypes[0],
|
||||
op_kwargs=deepcopy(kwargs),
|
||||
op_name="numpy.rint",
|
||||
op_name="np.rint",
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
input_tracers, traced_computation=traced_computation, output_index=0
|
||||
@@ -151,7 +151,7 @@ class NPTracer(BaseTracer):
|
||||
arbitrary_func=numpy.sin,
|
||||
output_dtype=common_output_dtypes[0],
|
||||
op_kwargs=deepcopy(kwargs),
|
||||
op_name="numpy.sin",
|
||||
op_name="np.sin",
|
||||
)
|
||||
output_tracer = self.__class__(
|
||||
input_tracers, traced_computation=traced_computation, output_index=0
|
||||
|
||||
44
poetry.lock
generated
44
poetry.lock
generated
@@ -229,7 +229,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
||||
|
||||
[[package]]
|
||||
name = "diff-cover"
|
||||
version = "6.3.1"
|
||||
version = "6.3.3"
|
||||
description = "Run coverage and linting reports on diffs"
|
||||
category = "dev"
|
||||
optional = false
|
||||
@@ -511,13 +511,14 @@ qtconsole = "*"
|
||||
|
||||
[[package]]
|
||||
name = "jupyter-client"
|
||||
version = "6.2.0"
|
||||
version = "7.0.1"
|
||||
description = "Jupyter protocol implementation and client libraries"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.6.1"
|
||||
|
||||
[package.dependencies]
|
||||
entrypoints = "*"
|
||||
jupyter-core = ">=4.6.0"
|
||||
nest-asyncio = ">=1.5"
|
||||
python-dateutil = ">=2.1"
|
||||
@@ -526,8 +527,8 @@ tornado = ">=4.1"
|
||||
traitlets = "*"
|
||||
|
||||
[package.extras]
|
||||
doc = ["sphinx (>=1.3.6)", "sphinx-rtd-theme", "sphinxcontrib-github-alt"]
|
||||
test = ["async-generator", "ipykernel", "ipython", "mock", "pytest-asyncio", "pytest-timeout", "pytest", "mypy", "pre-commit", "jedi (<0.18)"]
|
||||
doc = ["myst-parser", "sphinx (>=1.3.6)", "sphinx-rtd-theme", "sphinxcontrib-github-alt"]
|
||||
test = ["codecov", "coverage", "ipykernel", "ipython", "mock", "mypy", "pre-commit", "pytest", "pytest-asyncio", "pytest-cov", "pytest-timeout", "jedi (<0.18)"]
|
||||
|
||||
[[package]]
|
||||
name = "jupyter-console"
|
||||
@@ -974,11 +975,11 @@ twisted = ["twisted"]
|
||||
|
||||
[[package]]
|
||||
name = "prompt-toolkit"
|
||||
version = "3.0.19"
|
||||
version = "3.0.20"
|
||||
description = "Library for building powerful interactive command lines in Python"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.6.1"
|
||||
python-versions = ">=3.6.2"
|
||||
|
||||
[package.dependencies]
|
||||
wcwidth = "*"
|
||||
@@ -1068,6 +1069,14 @@ category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.5"
|
||||
|
||||
[[package]]
|
||||
name = "pygraphviz"
|
||||
version = "1.7"
|
||||
description = "Python interface to Graphviz"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
|
||||
[[package]]
|
||||
name = "pylint"
|
||||
version = "2.9.6"
|
||||
@@ -1413,7 +1422,7 @@ test = ["pytest"]
|
||||
|
||||
[[package]]
|
||||
name = "terminado"
|
||||
version = "0.11.0"
|
||||
version = "0.11.1"
|
||||
description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library."
|
||||
category = "dev"
|
||||
optional = false
|
||||
@@ -1555,7 +1564,7 @@ testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytes
|
||||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = ">=3.7,<3.10"
|
||||
content-hash = "65489a7f8c03f8825d0948ffec2ef5809e53d82e5b9eb3c77d5fa512c175d0fd"
|
||||
content-hash = "382f9225cb89e407123521c4a9ec5aedc021701f7af8ef79e2959ca5996a6be1"
|
||||
|
||||
[metadata.files]
|
||||
alabaster = [
|
||||
@@ -1756,8 +1765,8 @@ defusedxml = [
|
||||
{file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"},
|
||||
]
|
||||
diff-cover = [
|
||||
{file = "diff_cover-6.3.1-py3-none-any.whl", hash = "sha256:2578fb51c4a5ce162d9ba7f5dcc28132b55539c889e9b648f9df77d4fdcf8fb4"},
|
||||
{file = "diff_cover-6.3.1.tar.gz", hash = "sha256:21baf9d6f40ef352df4adf19b5bb4d47249c540a648fecda4647a41ff558d47c"},
|
||||
{file = "diff_cover-6.3.3-py3-none-any.whl", hash = "sha256:4aaffc7051dd6b0e4e39170d2a69f412a21bbbf8497c85654a8d0c1fd44be534"},
|
||||
{file = "diff_cover-6.3.3.tar.gz", hash = "sha256:487b9babf6d1a7d73b9f72c2ee4cbed2840bf2f0e203e184b9ef632532115665"},
|
||||
]
|
||||
docutils = [
|
||||
{file = "docutils-0.16-py2.py3-none-any.whl", hash = "sha256:0c5b78adfbf7762415433f5515cd5c9e762339e23369dbe8000d84a4bf4ab3af"},
|
||||
@@ -1837,8 +1846,8 @@ jupyter = [
|
||||
{file = "jupyter-1.0.0.zip", hash = "sha256:3e1f86076bbb7c8c207829390305a2b1fe836d471ed54be66a3b8c41e7f46cc7"},
|
||||
]
|
||||
jupyter-client = [
|
||||
{file = "jupyter_client-6.2.0-py3-none-any.whl", hash = "sha256:9715152067e3f7ea3b56f341c9a0f9715c8c7cc316ee0eb13c3c84f5ca0065f5"},
|
||||
{file = "jupyter_client-6.2.0.tar.gz", hash = "sha256:e2ab61d79fbf8b56734a4c2499f19830fbd7f6fefb3e87868ef0545cb3c17eb9"},
|
||||
{file = "jupyter_client-7.0.1-py3-none-any.whl", hash = "sha256:07b9566979546004c089afe7c9bf9e96224ec5f8421fe0ae460759fa593c6b1d"},
|
||||
{file = "jupyter_client-7.0.1.tar.gz", hash = "sha256:48822a93d9d75daa5fde235c35cf7a92fc979384735962501d4eb60b197fb43a"},
|
||||
]
|
||||
jupyter-console = [
|
||||
{file = "jupyter_console-6.4.0-py3-none-any.whl", hash = "sha256:7799c4ea951e0e96ba8260575423cb323ea5a03fcf5503560fa3e15748869e27"},
|
||||
@@ -2182,8 +2191,8 @@ prometheus-client = [
|
||||
{file = "prometheus_client-0.11.0.tar.gz", hash = "sha256:3a8baade6cb80bcfe43297e33e7623f3118d660d41387593758e2fb1ea173a86"},
|
||||
]
|
||||
prompt-toolkit = [
|
||||
{file = "prompt_toolkit-3.0.19-py3-none-any.whl", hash = "sha256:7089d8d2938043508aa9420ec18ce0922885304cddae87fb96eebca942299f88"},
|
||||
{file = "prompt_toolkit-3.0.19.tar.gz", hash = "sha256:08360ee3a3148bdb5163621709ee322ec34fc4375099afa4bbf751e9b7b7fa4f"},
|
||||
{file = "prompt_toolkit-3.0.20-py3-none-any.whl", hash = "sha256:6076e46efae19b1e0ca1ec003ed37a933dc94b4d20f486235d436e64771dcd5c"},
|
||||
{file = "prompt_toolkit-3.0.20.tar.gz", hash = "sha256:eb71d5a6b72ce6db177af4a7d4d7085b99756bf656d98ffcc4fecd36850eea6c"},
|
||||
]
|
||||
ptyprocess = [
|
||||
{file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"},
|
||||
@@ -2240,6 +2249,9 @@ pygments = [
|
||||
{file = "Pygments-2.10.0-py3-none-any.whl", hash = "sha256:b8e67fe6af78f492b3c4b3e2970c0624cbf08beb1e493b2c99b9fa1b67a20380"},
|
||||
{file = "Pygments-2.10.0.tar.gz", hash = "sha256:f398865f7eb6874156579fdf36bc840a03cab64d1cde9e93d68f46a425ec52c6"},
|
||||
]
|
||||
pygraphviz = [
|
||||
{file = "pygraphviz-1.7.zip", hash = "sha256:a7bec6609f37cf1e64898c59f075afd659106cf9356c5f387cecaa2e0cdb2304"},
|
||||
]
|
||||
pylint = [
|
||||
{file = "pylint-2.9.6-py3-none-any.whl", hash = "sha256:2e1a0eb2e8ab41d6b5dbada87f066492bb1557b12b76c47c2ee8aa8a11186594"},
|
||||
{file = "pylint-2.9.6.tar.gz", hash = "sha256:8b838c8983ee1904b2de66cce9d0b96649a91901350e956d78f289c3bc87b48e"},
|
||||
@@ -2472,8 +2484,8 @@ sphinxcontrib-serializinghtml = [
|
||||
{file = "sphinxcontrib_serializinghtml-1.1.5-py2.py3-none-any.whl", hash = "sha256:352a9a00ae864471d3a7ead8d7d79f5fc0b57e8b3f95e9867eb9eb28999b92fd"},
|
||||
]
|
||||
terminado = [
|
||||
{file = "terminado-0.11.0-py3-none-any.whl", hash = "sha256:221eef83e6a504894842f7dccfa971ca2e98ec22a8a9118577e5257527674b42"},
|
||||
{file = "terminado-0.11.0.tar.gz", hash = "sha256:1e01183885f64c1bba3cf89a5a995ad4acfed4e5f00aebcce1bf7f089b0825a1"},
|
||||
{file = "terminado-0.11.1-py3-none-any.whl", hash = "sha256:9e0457334863be3e6060c487ad60e0995fa1df54f109c67b24ff49a4f2f34df5"},
|
||||
{file = "terminado-0.11.1.tar.gz", hash = "sha256:962b402edbb480718054dc37027bada293972ecadfb587b89f01e2b8660a2132"},
|
||||
]
|
||||
testpath = [
|
||||
{file = "testpath-0.5.0-py3-none-any.whl", hash = "sha256:8044f9a0bab6567fc644a3593164e872543bb44225b0e24846e2c89237937589"},
|
||||
|
||||
@@ -9,6 +9,8 @@ python = ">=3.7,<3.10"
|
||||
networkx = "^2.6.1"
|
||||
matplotlib = "^3.4.2"
|
||||
numpy = "^1.21.1"
|
||||
pygraphviz = "^1.7"
|
||||
Pillow = "^8.3.1"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
isort = "^5.9.2"
|
||||
|
||||
@@ -61,7 +61,7 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
|
||||
|
||||
# TODO: For the moment, we don't have really checks, but some printfs. Later,
|
||||
# when we have the converter, we can check the MLIR
|
||||
draw_graph(op_graph, block_until_user_closes_graph=False)
|
||||
draw_graph(op_graph, show=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
|
||||
print(f"\n{str_of_the_graph}\n")
|
||||
|
||||
@@ -137,7 +137,7 @@ def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
|
||||
x, y = x_y
|
||||
graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y})
|
||||
|
||||
draw_graph(graph, block_until_user_closes_graph=False)
|
||||
draw_graph(graph, show=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph)
|
||||
|
||||
@@ -167,7 +167,7 @@ def test_hnumpy_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph
|
||||
"Test hnumpy get_printable_graph and draw_graph on graphs with direct table lookup"
|
||||
graph = tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
draw_graph(graph, block_until_user_closes_graph=False)
|
||||
draw_graph(graph, show=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph)
|
||||
|
||||
@@ -257,7 +257,7 @@ def test_hnumpy_print_with_show_data_types_with_direct_tlu(lambda_f, params, ref
|
||||
"""Test hnumpy get_printable_graph with show_data_types on graphs with direct table lookup"""
|
||||
graph = tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
draw_graph(graph, block_until_user_closes_graph=False)
|
||||
draw_graph(graph, show=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph, show_data_types=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user