Files
concrete/hdk/common/compilation/artifacts.py
2021-09-01 10:23:14 +02:00

208 lines
7.6 KiB
Python

"""Module for compilation artifacts."""
import inspect
import platform
import shutil
import subprocess
import sys
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union

import networkx as nx
from PIL import Image

from ..debugging import draw_graph, get_printable_graph
from ..operator_graph import OPGraph
from ..representation import intermediate as ir
from ..values import BaseValue
# Relative directory used by `CompilationArtifacts` when no output directory is given.
DEFAULT_OUTPUT_DIRECTORY: Path = Path(".") / ".artifacts"
class CompilationArtifacts:
    """Class that conveys information about compilation process.

    Artifacts are collected through the ``add_*`` methods during compilation and are
    written to ``output_directory`` by ``export``.
    """

    # directory that `export` writes to (it is deleted and recreated on every export)
    output_directory: Path

    # source code of the function to compile, if it was provided
    source_code_of_the_function_to_compile: Optional[str]
    # parameter name -> textual representation of its value
    parameters_of_the_function_to_compile: Dict[str, str]

    # graph name -> drawing / textual representation, kept in insertion order
    drawings_of_operation_graphs: Dict[str, Image.Image]
    textual_representations_of_operation_graphs: Dict[str, str]

    # the most recently added operation graph and the information attached to it
    final_operation_graph: Optional[OPGraph]
    bounds_of_the_final_operation_graph: Optional[Dict[ir.IntermediateNode, Dict[str, Any]]]
    mlir_of_the_final_operation_graph: Optional[str]

    def __init__(self, output_directory: Union[str, Path] = DEFAULT_OUTPUT_DIRECTORY):
        """Creates an empty set of artifacts.

        Args:
            output_directory (Union[str, Path]): directory the artifacts will be exported to
        """
        # accept plain strings as well; `Path(path)` leaves an existing path unchanged
        self.output_directory = Path(output_directory)

        self.source_code_of_the_function_to_compile = None
        self.parameters_of_the_function_to_compile = {}

        self.drawings_of_operation_graphs = {}
        self.textual_representations_of_operation_graphs = {}

        self.final_operation_graph = None
        self.bounds_of_the_final_operation_graph = None
        self.mlir_of_the_final_operation_graph = None

    def add_function_to_compile(self, function: Union[Callable, str]) -> None:
        """Adds the function to compile to artifacts.

        Args:
            function (Union[Callable, str]): the function to compile or source code of it

        Raises:
            OSError, TypeError: propagated from ``inspect.getsource`` when the source
                code of a callable cannot be retrieved

        Returns:
            None
        """
        self.source_code_of_the_function_to_compile = (
            function if isinstance(function, str) else inspect.getsource(function)
        )

    def add_parameter_of_function_to_compile(self, name: str, value: Union[BaseValue, str]) -> None:
        """Adds a parameter of the function to compile to the artifacts.

        Args:
            name (str): name of the parameter
            value (Union[BaseValue, str]): value of the parameter or textual representation of it

        Returns:
            None
        """
        self.parameters_of_the_function_to_compile[name] = str(value)

    def add_operation_graph(self, name: str, operation_graph: OPGraph) -> None:
        """Adds an operation graph to the artifacts.

        The drawing and the textual representation of the graph are stored under
        ``name``, and the graph also becomes the final operation graph, i.e. the one
        whose bounds and MLIR can be attached afterwards.

        Args:
            name (str): name of the graph
            operation_graph (OPGraph): the operation graph itself

        Returns:
            None
        """
        drawing = draw_graph(operation_graph)
        textual_representation = get_printable_graph(operation_graph, show_data_types=True)

        self.drawings_of_operation_graphs[name] = drawing
        self.textual_representations_of_operation_graphs[name] = textual_representation

        # the most recently added graph is considered the final one
        self.final_operation_graph = operation_graph

    def add_final_operation_graph_bounds(self, bounds: Dict[ir.IntermediateNode, Dict[str, Any]]) -> None:
        """Adds the bounds of the final operation graph to the artifacts.

        An operation graph must have been added with ``add_operation_graph`` first.

        Args:
            bounds (Dict[ir.IntermediateNode, Dict[str, Any]]): the bound dictionary

        Returns:
            None
        """
        assert self.final_operation_graph is not None
        self.bounds_of_the_final_operation_graph = bounds

    def add_final_operation_graph_mlir(self, mlir: str) -> None:
        """Adds the mlir of the final operation graph to the artifacts.

        An operation graph must have been added with ``add_operation_graph`` first.

        Args:
            mlir (str): the mlir code of the final operation graph

        Returns:
            None
        """
        assert self.final_operation_graph is not None
        self.mlir_of_the_final_operation_graph = mlir

    def export(self) -> None:
        """Exports the artifacts to the output directory.

        The output directory is deleted first if it already exists, so anything
        previously stored in it is lost.

        Returns:
            None
        """
        output_directory = Path(self.output_directory)
        if output_directory.exists():
            shutil.rmtree(output_directory)
        # `parents=True` so that nested output directories can be used as well
        output_directory.mkdir(parents=True)

        self._export_environment(output_directory)
        self._export_requirements(output_directory)
        self._export_function(output_directory)
        self._export_operation_graphs(output_directory)
        self._export_final_operation_graph(output_directory)

    @staticmethod
    def _export_environment(output_directory: Path) -> None:
        """Writes the platform description and the Python version to `environment.txt`."""
        with open(output_directory.joinpath("environment.txt"), "w", encoding="utf-8") as f:
            f.write(f"{platform.platform()} {platform.version()}\n")
            f.write(f"Python {platform.python_version()}\n")

    @staticmethod
    def _export_requirements(output_directory: Path) -> None:
        """Writes the installed packages as `name==version` lines to `requirements.txt`."""
        # `sys.executable -m pip` queries the environment of the interpreter running this
        # code (a bare `pip` on PATH may belong to another environment), and
        # `--format=freeze` prints `name==version` lines directly, so no parsing of the
        # column layout of `pip list` is needed
        pip_process = subprocess.run(
            [sys.executable, "-m", "pip", "list", "--format=freeze"],
            stdout=subprocess.PIPE,
            check=True,
        )
        dependencies = pip_process.stdout.decode("utf-8").splitlines()

        # pip runs before the file is opened, so a failing pip leaves no partial file
        with open(output_directory.joinpath("requirements.txt"), "w", encoding="utf-8") as f:
            for dependency in dependencies:
                dependency = dependency.strip()
                if dependency:
                    f.write(f"{dependency}\n")

    def _export_function(self, output_directory: Path) -> None:
        """Writes the source (`function.txt`) and parameters (`parameters.txt`), if set."""
        if self.source_code_of_the_function_to_compile is not None:
            with open(output_directory.joinpath("function.txt"), "w", encoding="utf-8") as f:
                f.write(self.source_code_of_the_function_to_compile)

        if self.parameters_of_the_function_to_compile:
            with open(output_directory.joinpath("parameters.txt"), "w", encoding="utf-8") as f:
                for name, parameter in self.parameters_of_the_function_to_compile.items():
                    f.write(f"{name} :: {parameter}\n")

    def _export_operation_graphs(self, output_directory: Path) -> None:
        """Writes the drawing (`.png`) and textual form (`.txt`) of every added graph."""
        drawings = self.drawings_of_operation_graphs.items()
        for index, (name, drawing) in enumerate(drawings):
            identifier = CompilationArtifacts._identifier(index, name)
            drawing.save(output_directory.joinpath(f"{identifier}.png"))

        textual_representations = self.textual_representations_of_operation_graphs.items()
        for index, (name, representation) in enumerate(textual_representations):
            identifier = CompilationArtifacts._identifier(index, name)
            with open(output_directory.joinpath(f"{identifier}.txt"), "w", encoding="utf-8") as f:
                f.write(f"{representation}\n")

    def _export_final_operation_graph(self, output_directory: Path) -> None:
        """Writes the bounds (`bounds.txt`) and MLIR (`mlir.txt`) of the final graph, if set."""
        if self.bounds_of_the_final_operation_graph is not None:
            assert self.final_operation_graph is not None
            with open(output_directory.joinpath("bounds.txt"), "w", encoding="utf-8") as f:
                # TODO:
                #   if nx.topological_sort is not deterministic between calls,
                #   the lines below will not work properly
                #   thus, we may want to change this in the future
                for index, node in enumerate(nx.topological_sort(self.final_operation_graph.graph)):
                    bounds = self.bounds_of_the_final_operation_graph.get(node)
                    assert bounds is not None
                    # a missing 'min'/'max' key is written as `None`
                    f.write(f"%{index} :: [{bounds.get('min')}, {bounds.get('max')}]\n")

        if self.mlir_of_the_final_operation_graph is not None:
            assert self.final_operation_graph is not None
            with open(output_directory.joinpath("mlir.txt"), "w", encoding="utf-8") as f:
                f.write(self.mlir_of_the_final_operation_graph)

    @staticmethod
    def _identifier(index: int, name: str) -> str:
        """Returns the 1-based file stem used for the graph at position ``index``."""
        return f"{index + 1}.{name}.graph"