mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
208 lines
7.6 KiB
Python
208 lines
7.6 KiB
Python
"""Module for compilation artifacts."""
|
|
|
|
import inspect
|
|
import platform
|
|
import shutil
|
|
import subprocess
|
|
from pathlib import Path
|
|
from typing import Any, Callable, Dict, Optional, Union
|
|
|
|
import networkx as nx
|
|
from PIL import Image
|
|
|
|
from ..debugging import draw_graph, get_printable_graph
|
|
from ..operator_graph import OPGraph
|
|
from ..representation import intermediate as ir
|
|
from ..values import BaseValue
|
|
|
|
# Directory the artifacts are exported to when CompilationArtifacts is created
# without an explicit output directory (relative path, resolved by the OS at export time)
DEFAULT_OUTPUT_DIRECTORY: Path = Path(".artifacts")
|
|
|
|
|
|
class CompilationArtifacts:
    """Class that conveys information about compilation process.

    Artifacts are accumulated in memory through the `add_*` methods and are only
    written to disk when `export` is called.
    """

    # directory that `export` writes the artifacts into
    output_directory: Path

    # source code and stringified parameters of the function to compile
    source_code_of_the_function_to_compile: Optional[str]
    parameters_of_the_function_to_compile: Dict[str, str]

    # drawing and textual representation of every added operation graph, keyed by graph name
    drawings_of_operation_graphs: Dict[str, Image.Image]
    textual_representations_of_operation_graphs: Dict[str, str]

    # the most recently added operation graph, and the bounds / mlir attached to it
    final_operation_graph: Optional[OPGraph]
    bounds_of_the_final_operation_graph: Optional[Dict[ir.IntermediateNode, Dict[str, Any]]]
    mlir_of_the_final_operation_graph: Optional[str]

    def __init__(self, output_directory: Path = DEFAULT_OUTPUT_DIRECTORY):
        """Creates an empty set of artifacts.

        Args:
            output_directory (Path): the directory `export` will write into

        Returns:
            None
        """

        self.output_directory = output_directory

        self.source_code_of_the_function_to_compile = None
        self.parameters_of_the_function_to_compile = {}

        self.drawings_of_operation_graphs = {}
        self.textual_representations_of_operation_graphs = {}

        self.final_operation_graph = None
        self.bounds_of_the_final_operation_graph = None
        self.mlir_of_the_final_operation_graph = None

    def add_function_to_compile(self, function: Union[Callable, str]):
        """Adds the function to compile to artifacts.

        Args:
            function (Union[Callable, str]): the function to compile or source code of it

        Returns:
            None
        """

        # a string is taken as the source code itself, anything else is inspected for its source
        self.source_code_of_the_function_to_compile = (
            function if isinstance(function, str) else inspect.getsource(function)
        )

    def add_parameter_of_function_to_compile(self, name: str, value: Union[BaseValue, str]):
        """Adds a parameter of the function to compile to the artifacts.

        Args:
            name (str): name of the parameter
            value (Union[BaseValue, str]): value of the parameter or textual representation of it

        Returns:
            None
        """

        # only the textual form is kept, so the value itself is not retained
        self.parameters_of_the_function_to_compile[name] = str(value)

    def add_operation_graph(self, name: str, operation_graph: OPGraph):
        """Adds an operation graph to the artifacts.

        The drawing and the textual representation are computed immediately.
        The graph also becomes the final operation graph, replacing any graph added before it.

        Args:
            name (str): name of the graph
            operation_graph (OPGraph): the operation graph itself

        Returns:
            None
        """

        drawing = draw_graph(operation_graph)
        textual_representation = get_printable_graph(operation_graph, show_data_types=True)

        self.drawings_of_operation_graphs[name] = drawing
        self.textual_representations_of_operation_graphs[name] = textual_representation

        # the last graph added is considered to be the final one
        self.final_operation_graph = operation_graph

    def add_final_operation_graph_bounds(self, bounds: Dict[ir.IntermediateNode, Dict[str, Any]]):
        """Adds the bounds of the final operation graph to the artifacts.

        An operation graph must have been added beforehand.

        Args:
            bounds (Dict[ir.IntermediateNode, Dict[str, Any]]): the bound dictionary

        Returns:
            None
        """

        assert self.final_operation_graph is not None
        self.bounds_of_the_final_operation_graph = bounds

    def add_final_operation_graph_mlir(self, mlir: str):
        """Adds the mlir of the final operation graph to the artifacts.

        An operation graph must have been added beforehand.

        Args:
            mlir (str): the mlir code of the final operation graph

        Returns:
            None
        """

        assert self.final_operation_graph is not None
        self.mlir_of_the_final_operation_graph = mlir

    def export(self):
        """Exports the artifacts to the output directory.

        The output directory is deleted if it already exists and is then recreated
        (along with any missing parent directories). The following files are written,
        all of them encoded in UTF-8:

        - environment.txt: platform and python version
        - requirements.txt: installed packages as reported by `pip list`
        - function.txt: source code of the function to compile (if it was added)
        - parameters.txt: parameters of the function to compile (if any was added)
        - {n}.{name}.graph.png / {n}.{name}.graph.txt: one pair per added operation graph
        - bounds.txt: bounds of the final operation graph (if they were added)
        - mlir.txt: mlir of the final operation graph (if it was added)

        Raises:
            subprocess.CalledProcessError: if `pip list` exits with a non-zero status

        Returns:
            None
        """

        # function-scope import so that this class does not depend on a file-level `import sys`
        import sys

        output_directory = self.output_directory
        if output_directory.exists():
            shutil.rmtree(output_directory)
        # parents=True so that nested output directories can be used as well
        output_directory.mkdir(parents=True)

        output_directory.joinpath("environment.txt").write_text(
            f"{platform.platform()} {platform.version()}\n"
            f"Python {platform.python_version()}\n",
            encoding="utf-8",
        )

        # pip is invoked through the running interpreter so that the listed packages
        # are the ones of the current environment and not of whichever `pip` is first on PATH
        pip_process = subprocess.run(
            [sys.executable, "-m", "pip", "list"], stdout=subprocess.PIPE, check=True
        )
        requirements = CompilationArtifacts._requirements_from_pip_list(
            pip_process.stdout.decode("utf-8")
        )
        output_directory.joinpath("requirements.txt").write_text(requirements, encoding="utf-8")

        if self.source_code_of_the_function_to_compile is not None:
            output_directory.joinpath("function.txt").write_text(
                self.source_code_of_the_function_to_compile, encoding="utf-8"
            )

        if self.parameters_of_the_function_to_compile:
            parameter_lines = (
                f"{name} :: {parameter}\n"
                for name, parameter in self.parameters_of_the_function_to_compile.items()
            )
            output_directory.joinpath("parameters.txt").write_text(
                "".join(parameter_lines), encoding="utf-8"
            )

        drawings = self.drawings_of_operation_graphs.items()
        for index, (name, drawing) in enumerate(drawings):
            identifier = CompilationArtifacts._identifier(index, name)
            drawing.save(output_directory.joinpath(f"{identifier}.png"))

        textual_representations = self.textual_representations_of_operation_graphs.items()
        for index, (name, representation) in enumerate(textual_representations):
            identifier = CompilationArtifacts._identifier(index, name)
            output_directory.joinpath(f"{identifier}.txt").write_text(
                f"{representation}\n", encoding="utf-8"
            )

        if self.bounds_of_the_final_operation_graph is not None:
            assert self.final_operation_graph is not None
            with open(output_directory.joinpath("bounds.txt"), "w", encoding="utf-8") as f:
                # TODO:
                #   if nx.topological_sort is not deterministic between calls,
                #   the lines below will not work properly
                #   thus, we may want to change this in the future
                for index, node in enumerate(nx.topological_sort(self.final_operation_graph.graph)):
                    bounds = self.bounds_of_the_final_operation_graph.get(node)
                    assert bounds is not None
                    f.write(f"%{index} :: [{bounds.get('min')}, {bounds.get('max')}]\n")

        if self.mlir_of_the_final_operation_graph is not None:
            assert self.final_operation_graph is not None
            output_directory.joinpath("mlir.txt").write_text(
                self.mlir_of_the_final_operation_graph, encoding="utf-8"
            )

    @staticmethod
    def _requirements_from_pip_list(pip_list_output: str) -> str:
        """Converts the table printed by `pip list` to `name==version` lines.

        Args:
            pip_list_output (str): the decoded standard output of `pip list`

        Returns:
            str: one `name==version` line per listed package, each ending with a newline
        """

        # example `pip list` output
        #
        # Package                       Version
        # ----------------------------- ---------
        # alabaster                     0.7.12
        # appdirs                       1.4.4
        # ...                           ...
        # wrapt                         1.12.1
        # zipp                          3.5.0

        requirements = []

        # splitlines handles '\r\n' line endings as well, and the slice skips the
        # 'Package ... Version' and '------- ... -------' lines without failing on short output
        for line in pip_list_output.splitlines()[2:]:
            # split on any whitespace so that stray '\r' or tabs never end up in a token
            tokens = line.split()
            if len(tokens) < 2:
                # blank or malformed line, there is no name/version pair to record
                continue

            name, version = tokens[0], tokens[1]
            requirements.append(f"{name}=={version}\n")

        return "".join(requirements)

    @staticmethod
    def _identifier(index: int, name: str) -> str:
        """Creates the file name stem of an exported operation graph.

        Args:
            index (int): zero-based position of the graph among the added graphs
            name (str): name of the graph

        Returns:
            str: the stem in `{index + 1}.{name}.graph` form
        """

        return f"{index + 1}.{name}.graph"
|