mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(compilation): create compilation artifacts and provide a way to export them in a textual format
This commit is contained in:
@@ -77,6 +77,7 @@ In this section, we will discuss the module structure of hdk briefly. You are en
|
||||
- hdk
|
||||
- common: types and utilities that can be used by multiple frontends (e.g., numpy, torch)
|
||||
- bounds_measurement: utilities for determining bounds of intermediate representation
|
||||
- compilation: type definitions related to compilation (e.g., compilation config, compilation artifacts)
|
||||
- data_types: type definitions of typing information of intermediate representation
|
||||
- debugging: utilities for printing/displaying intermediate representation
|
||||
- extensions: utilities that provide special functionality to our users
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
"""Module for shared data structures and code"""
|
||||
from . import data_types, debugging, representation
|
||||
from . import compilation, data_types, debugging, representation
|
||||
from .common_helpers import check_op_graph_is_integer_program, is_a_power_of_2
|
||||
|
||||
3
hdk/common/compilation/__init__.py
Normal file
3
hdk/common/compilation/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Module for compilation related types"""
|
||||
|
||||
from .artifacts import CompilationArtifacts
|
||||
83
hdk/common/compilation/artifacts.py
Normal file
83
hdk/common/compilation/artifacts.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Module for compilation artifacts"""
|
||||
|
||||
import platform
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from ..debugging.draw_graph import get_printable_graph
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation import intermediate as ir
|
||||
|
||||
|
||||
class CompilationArtifacts:
    """Class that conveys information about compilation process"""

    # operation graph produced during compilation, or None until the compiler fills it
    operation_graph: Optional[OPGraph]
    # per-node bound information ({"min": ..., "max": ...}), or None until measured
    bounds: Optional[Dict[ir.IntermediateNode, Dict[str, Any]]]

    def __init__(self):
        self.operation_graph = None
        self.bounds = None

    def export(self, output_directory: Path):
        """Exports the artifacts in a textual format

        Args:
            output_directory (Path): the directory to save the artifacts

        Returns:
            None
        """

        with open(output_directory.joinpath("environment.txt"), "w", encoding="utf-8") as f:
            f.write(f"{platform.platform()} {platform.version()}\n")
            f.write(f"Python {platform.python_version()}\n")

        with open(output_directory.joinpath("requirements.txt"), "w", encoding="utf-8") as f:
            # example `pip list` output

            # Package                       Version
            # ----------------------------- ---------
            # alabaster                     0.7.12
            # appdirs                       1.4.4
            # ...                           ...
            # ...                           ...
            # wrapt                         1.12.1
            # zipp                          3.5.0

            pip_process = subprocess.run(["pip", "list"], stdout=subprocess.PIPE, check=True)
            dependency_lines = pip_process.stdout.decode("utf-8").splitlines()

            # skip the 'Package ... Version' header line and the '------' separator line;
            # slicing (instead of bare `next()` calls) avoids StopIteration on empty output
            for dependency in dependency_lines[2:]:
                tokens = dependency.split()
                if len(tokens) < 2:
                    # blank line or malformed entry without a version; nothing to record
                    continue

                name, version = tokens[0], tokens[1]
                f.write(f"{name}=={version}\n")

        if self.operation_graph is not None:
            with open(output_directory.joinpath("graph.txt"), "w", encoding="utf-8") as f:
                f.write(f"{get_printable_graph(self.operation_graph)[1:]}\n")

        # bounds are indexed by graph node, so they can only be exported
        # together with the operation graph (guards against AttributeError
        # when bounds are set but the graph is not)
        if self.bounds is not None and self.operation_graph is not None:
            with open(output_directory.joinpath("bounds.txt"), "w", encoding="utf-8") as f:
                # TODO:
                # if nx.topological_sort is not deterministic between calls,
                # the lines below will not work properly
                # thus, we may want to change this in the future
                for index, node in enumerate(nx.topological_sort(self.operation_graph.graph)):
                    bounds = self.bounds.get(node)
                    assert bounds is not None
                    f.write(f"%{index} :: [{bounds.get('min')}, {bounds.get('max')}]\n")
|
||||
@@ -1,10 +1,11 @@
|
||||
"""hnumpy compilation function"""
|
||||
|
||||
from typing import Any, Callable, Dict, Iterator, Tuple
|
||||
from typing import Any, Callable, Dict, Iterator, Optional, Tuple
|
||||
|
||||
from hdk.common.bounds_measurement.dataset_eval import eval_op_graph_bounds_on_dataset
|
||||
from hdk.hnumpy.tracing import trace_numpy_function
|
||||
|
||||
from ..common.compilation import CompilationArtifacts
|
||||
from ..common.data_types import BaseValue
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..hnumpy.tracing import trace_numpy_function
|
||||
@@ -14,6 +15,7 @@ def compile_numpy_function(
|
||||
function_to_trace: Callable,
|
||||
function_parameters: Dict[str, BaseValue],
|
||||
dataset: Iterator[Tuple[Any, ...]],
|
||||
compilation_artifacts: Optional[CompilationArtifacts] = None,
|
||||
) -> OPGraph:
|
||||
"""Main API of hnumpy, to be able to compile an homomorphic program
|
||||
|
||||
@@ -24,6 +26,8 @@ def compile_numpy_function(
|
||||
dataset (Iterator[Tuple[Any, ...]]): The dataset over which op_graph is evaluated. It
|
||||
needs to be an iterator on tuples which are of the same length than the number of
|
||||
parameters in the function, and in the same order than these same parameters
|
||||
compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
|
||||
during compilation
|
||||
|
||||
Returns:
|
||||
OPGraph: currently returns a compilable graph, but later, it will return an MLIR compatible
|
||||
@@ -39,4 +43,9 @@ def compile_numpy_function(
|
||||
# Update the graph accordingly: after that, we have the compilable graph
|
||||
op_graph.update_values_with_bounds(node_bounds)
|
||||
|
||||
# Fill compilation artifacts
|
||||
if compilation_artifacts is not None:
|
||||
compilation_artifacts.operation_graph = op_graph
|
||||
compilation_artifacts.bounds = node_bounds
|
||||
|
||||
return op_graph
|
||||
|
||||
36
tests/common/compilation/test_artifacts.py
Normal file
36
tests/common/compilation/test_artifacts.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Test file for compilation artifacts"""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from hdk.common.compilation import CompilationArtifacts
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import EncryptedValue
|
||||
from hdk.hnumpy.compile import compile_numpy_function
|
||||
|
||||
|
||||
def test_artifacts_export():
    """Check that compilation artifacts can be exported to a directory."""

    def add_forty_two(x):
        return x + 42

    artifacts = CompilationArtifacts()
    dataset = [(-2,), (-1,), (0,), (1,), (2,)]
    compile_numpy_function(
        add_forty_two,
        {"x": EncryptedValue(Integer(7, True))},
        iter(dataset),
        artifacts,
    )

    with tempfile.TemporaryDirectory() as tmp:
        output_directory = Path(tmp)
        artifacts.export(output_directory)

        # the format of these files might change in the future,
        # so checking their existence is sufficient
        expected_files = [
            "environment.txt",
            "requirements.txt",
            "graph.txt",
            "bounds.txt",
        ]
        for file_name in expected_files:
            assert output_directory.joinpath(file_name).exists()
|
||||
Reference in New Issue
Block a user