feat: implement compilation module

This commit is contained in:
Umut
2022-04-04 13:29:27 +02:00
parent 92651a12ee
commit 4a6c728f8f
7 changed files with 1256 additions and 0 deletions

View File

@@ -0,0 +1,9 @@
"""
Declaration of `concrete.numpy.compilation` namespace.
"""
from .artifacts import CompilationArtifacts
from .circuit import Circuit
from .compiler import Compiler, EncryptionStatus
from .configuration import CompilationConfiguration
from .decorator import compiler

View File

@@ -0,0 +1,207 @@
"""
Declaration of `CompilationArtifacts` class.
"""
import inspect
import platform
import shutil
import subprocess
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union
import networkx as nx
from ..representation import Graph, Node
# default location used when no output directory is given to `CompilationArtifacts`
DEFAULT_OUTPUT_DIRECTORY: Path = Path(".artifacts")


class CompilationArtifacts:
    """
    CompilationArtifacts class, to export information about the compilation process.
    """

    # directory into which `export` writes its files (wiped and recreated on each export)
    output_directory: Path
    # source code of the function being compiled, or "unavailable" if it couldn't be retrieved
    source_code: Optional[str]
    # parameter name -> encryption status (e.g., "clear" or "encrypted")
    parameter_encryption_statuses: Dict[str, str]
    # graph name -> path of the drawing produced by `Graph.draw` (stringified)
    drawings_of_graphs: Dict[str, str]
    # graph name -> textual representation produced by `Graph.format`
    textual_representations_of_graphs: Dict[str, str]
    # the most recently added graph (each `add_graph` call overwrites it)
    final_graph: Optional[Graph]
    # bounds of `final_graph`, keyed by node; each value has 'min'/'max' entries
    bounds_of_the_final_graph: Optional[Dict[Node, Dict[str, Any]]]
    # textual MLIR representation that will be fed to the compiler
    mlir_to_compile: Optional[str]

    def __init__(self, output_directory: Union[str, Path] = DEFAULT_OUTPUT_DIRECTORY):
        self.output_directory = Path(output_directory)
        self.source_code = None
        self.parameter_encryption_statuses = {}
        self.drawings_of_graphs = {}
        self.textual_representations_of_graphs = {}
        self.final_graph = None
        self.bounds_of_the_final_graph = None
        self.mlir_to_compile = None

    def add_source_code(self, function: Union[str, Callable]):
        """
        Add source code of the function being compiled.

        Args:
            function (Union[str, Callable]):
                either the source code of the function or the function itself
        """
        try:
            self.source_code = (
                function if isinstance(function, str) else inspect.getsource(function)
            )
        except OSError:  # pragma: no cover
            # e.g., functions defined interactively have no retrievable source
            self.source_code = "unavailable"

    def add_parameter_encryption_status(self, name: str, encryption_status: str):
        """
        Add parameter encryption status of a parameter of the function being compiled.

        Args:
            name (str):
                name of the parameter

            encryption_status (str):
                encryption status of the parameter
        """
        self.parameter_encryption_statuses[name] = encryption_status

    def add_graph(self, name: str, graph: Graph):
        """
        Add a representation of the function being compiled.

        Args:
            name (str):
                name of the graph (e.g., initial, optimized, final)

            graph (Graph):
                a representation of the function being compiled
        """
        textual_representation = graph.format()
        self.textual_representations_of_graphs[name] = textual_representation
        try:
            # drawing is best-effort: it requires the optional pygraphviz package
            drawing = graph.draw()
            self.drawings_of_graphs[name] = str(drawing)
        except ImportError as error:  # pragma: no cover
            if "pygraphviz" in str(error):
                # pygraphviz is not installed; skip the drawing silently
                pass
            else:
                raise error
        # the latest added graph is considered the final one
        self.final_graph = graph

    def add_final_graph_bounds(self, bounds: Dict[Node, Dict[str, Any]]):
        """
        Add bounds of the latest computation graph.

        Args:
            bounds (Dict[Node, Dict[str, Any]]):
                bounds of the latest computation graph
        """
        # a graph must have been registered via `add_graph` before bounds can be attached
        assert self.final_graph is not None
        self.bounds_of_the_final_graph = bounds

    def add_mlir_to_compile(self, mlir: str):
        """
        Add textual representation of the resulting MLIR.

        Args:
            mlir (str):
                textual representation of the resulting MLIR
        """
        self.mlir_to_compile = mlir

    def export(self):
        """
        Export the collected information to `self.output_directory`.
        """
        # start from a clean directory on every export
        output_directory = self.output_directory
        if output_directory.exists():
            shutil.rmtree(output_directory)
        output_directory.mkdir(parents=True)

        with open(output_directory.joinpath("environment.txt"), "w", encoding="utf-8") as f:
            f.write(f"{platform.platform()} {platform.version()}\n")
            f.write(f"Python {platform.python_version()}\n")

        with open(output_directory.joinpath("requirements.txt"), "w", encoding="utf-8") as f:
            # example `pip list` output
            # Package                       Version
            # ----------------------------- ---------
            # alabaster                     0.7.12
            # appdirs                       1.4.4
            # ...                           ...
            # wrapt                         1.12.1
            # zipp                          3.5.0
            pip_process = subprocess.run(
                ["pip", "--disable-pip-version-check", "list"], stdout=subprocess.PIPE, check=True
            )
            dependencies = iter(pip_process.stdout.decode("utf-8").split("\n"))
            # skip 'Package ... Version' line
            next(dependencies)
            # skip '------- ... -------' line
            next(dependencies)
            for dependency in dependencies:
                # columns are separated by runs of spaces; drop empty tokens
                tokens = [token for token in dependency.split(" ") if token != ""]
                if len(tokens) == 0:
                    continue
                # assumes each remaining line has at least a name and a version — TODO confirm
                name = tokens[0]
                version = tokens[1]
                f.write(f"{name}=={version}\n")

        if self.source_code is not None:
            with open(output_directory.joinpath("function.txt"), "w", encoding="utf-8") as f:
                f.write(self.source_code)

        if len(self.parameter_encryption_statuses) > 0:
            with open(output_directory.joinpath("parameters.txt"), "w", encoding="utf-8") as f:
                for name, parameter in self.parameter_encryption_statuses.items():
                    f.write(f"{name} :: {parameter}\n")

        # copy graph drawings, numbered in insertion order (1.initial.graph.png, ...)
        drawings = self.drawings_of_graphs.items()
        for index, (name, drawing_filename) in enumerate(drawings):
            identifier = f"{index + 1}.{name}.graph"
            shutil.copy(drawing_filename, output_directory.joinpath(f"{identifier}.png"))

        textual_representations = self.textual_representations_of_graphs.items()
        for index, (name, representation) in enumerate(textual_representations):
            identifier = f"{index + 1}.{name}.graph"
            with open(output_directory.joinpath(f"{identifier}.txt"), "w", encoding="utf-8") as f:
                f.write(f"{representation}\n")

        if self.bounds_of_the_final_graph is not None:
            assert self.final_graph is not None
            with open(output_directory.joinpath("bounds.txt"), "w", encoding="utf-8") as f:
                for index, node in enumerate(nx.topological_sort(self.final_graph.graph)):
                    # NOTE(review): assumes every node of the final graph has a bounds
                    # entry; `.get` returns None otherwise and the subscript would fail
                    bounds = self.bounds_of_the_final_graph.get(node)
                    f.write(f"%{index} :: [{bounds['min']}, {bounds['max']}]\n")

        if self.mlir_to_compile is not None:
            assert self.final_graph is not None
            with open(output_directory.joinpath("mlir.txt"), "w", encoding="utf-8") as f:
                f.write(f"{self.mlir_to_compile}\n")

View File

@@ -0,0 +1,146 @@
"""
Declaration of `Circuit` class.
"""
from pathlib import Path
from typing import List, Optional, Tuple, Union, cast
import numpy as np
from concrete.compiler import CompilerEngine
from ..dtypes import Integer
from ..internal.utils import assert_that
from ..representation import Graph
from ..values import Value
class Circuit:
    """
    Circuit class, to combine computation graph and compiler engine into a single object.
    """

    # computation graph representing the compiled function
    graph: Graph
    # engine that performs the homomorphic evaluation
    engine: CompilerEngine

    def __init__(self, graph: Graph, engine: CompilerEngine):
        self.graph = graph
        self.engine = engine

    def __str__(self):
        return self.graph.format()

    def draw(
        self,
        show: bool = False,
        horizontal: bool = False,
        save_to: Optional[Union[Path, str]] = None,
    ) -> Path:
        """
        Draw the `self.graph` and optionally save/show the drawing.

        note that this function requires the python `pygraphviz` package
        which itself requires the installation of `graphviz` packages
        see https://pygraphviz.github.io/documentation/stable/install.html

        Args:
            show (bool, default = False):
                whether to show the drawing using matplotlib or not

            horizontal (bool, default = False):
                whether to draw horizontally or not

            save_to (Optional[Path], default = None):
                path to save the drawing
                a temporary file will be used if it's None

        Returns:
            Path:
                path to the saved drawing
        """
        return self.graph.draw(show, horizontal, save_to)

    def run(
        self,
        *args: Union[int, np.ndarray],
    ) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
        """
        Encrypt inputs, evaluate the circuit, and decrypt the outputs in one go.

        Args:
            *args (List[Union[int, numpy.ndarray]]):
                inputs to the engine

        Returns:
            Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
                result of the homomorphic evaluation

        Raises:
            ValueError:
                if the number of arguments is wrong, or if an argument is not an
                integer scalar/tensor within the expected range and of the expected shape
        """
        if len(args) != len(self.graph.input_nodes):
            raise ValueError(f"Expected {len(self.graph.input_nodes)} inputs but got {len(args)}")

        sanitized_args = {}
        for index, node in self.graph.input_nodes.items():
            arg = args[index]

            # arguments can only be integer scalars or integer tensors
            is_valid = isinstance(arg, (int, np.integer)) or (
                isinstance(arg, np.ndarray) and np.issubdtype(arg.dtype, np.integer)
            )

            expected_value = node.output
            assert_that(isinstance(expected_value.dtype, Integer))
            expected_dtype = cast(Integer, expected_value.dtype)

            if is_valid:
                expected_min = expected_dtype.min()
                expected_max = expected_dtype.max()
                expected_shape = expected_value.shape

                actual_min = arg if isinstance(arg, int) else arg.min()
                actual_max = arg if isinstance(arg, int) else arg.max()
                actual_shape = () if isinstance(arg, int) else arg.shape

                is_valid = (
                    actual_min >= expected_min
                    and actual_max <= expected_max
                    and actual_shape == expected_shape
                )

                if is_valid:
                    # NOTE(review): tensors are cast to uint8 unconditionally, which
                    # assumes the engine only accepts up to 8-bit inputs — confirm
                    sanitized_args[index] = arg if isinstance(arg, int) else arg.astype(np.uint8)

            if not is_valid:
                actual_value = Value.of(arg, is_encrypted=expected_value.is_encrypted)
                raise ValueError(
                    f"Expected argument {index} to be {expected_value} but it's {actual_value}"
                )

        results = self.engine.run(*[sanitized_args[i] for i in range(len(sanitized_args))])
        if not isinstance(results, tuple):
            results = (results,)

        sanitized_results: List[Union[int, np.ndarray]] = []
        for index, node in self.graph.output_nodes.items():
            expected_value = node.output
            assert_that(isinstance(expected_value.dtype, Integer))
            expected_dtype = cast(Integer, expected_value.dtype)

            # reduce the raw result modulo 2^n where n is the expected bit width
            n = expected_dtype.bit_width
            result = results[index] % (2 ** n)

            if expected_dtype.is_signed:
                # map values in [2^(n-1), 2^n) back to their negative two's complement form
                if isinstance(result, int):
                    sanitized_result = result if result < (2 ** (n - 1)) else result - (2 ** n)
                    sanitized_results.append(sanitized_result)
                else:
                    result = result.astype(np.longlong)  # to prevent overflows in numpy
                    sanitized_result = np.where(result < (2 ** (n - 1)), result, result - (2 ** n))
                    sanitized_results.append(sanitized_result.astype(np.int8))
            else:
                sanitized_results.append(
                    result if isinstance(result, int) else result.astype(np.uint8)
                )

        # single-output circuits return a scalar/tensor, multi-output circuits a tuple
        return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results)

View File

@@ -0,0 +1,352 @@
"""
Declaration of `Compiler` class.
"""
import inspect
import os
import traceback
from enum import Enum, unique
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from concrete.compiler import CompilerEngine
from ..mlir import GraphConverter
from ..representation import Graph
from ..tracing import Tracer
from ..values import Value
from .artifacts import CompilationArtifacts
from .circuit import Circuit
from .configuration import CompilationConfiguration
from .utils import fuse
@unique
class EncryptionStatus(str, Enum):
    """
    EncryptionStatus enum, to represent encryption status of parameters.
    """

    # parameter is passed to the circuit in the clear
    CLEAR = "clear"
    # parameter is encrypted before evaluation
    ENCRYPTED = "encrypted"
class Compiler:
    """
    Compiler class, to glue the compilation pipeline.
    """

    # function being compiled
    function: Callable
    # parameter name -> encryption status, normalized to `EncryptionStatus`
    parameter_encryption_statuses: Dict[str, EncryptionStatus]
    # configuration used during compilation
    configuration: CompilationConfiguration
    # artifacts collecting information about the compilation process
    artifacts: CompilationArtifacts
    # samples accumulated so far, used for bounds measurement
    inputset: List[Any]
    # traced computation graph (None until the function is traced)
    graph: Optional[Graph]

    def __init__(
        self,
        function: Callable,
        parameter_encryption_statuses: Dict[str, Union[str, EncryptionStatus]],
        configuration: Optional[CompilationConfiguration] = None,
        artifacts: Optional[CompilationArtifacts] = None,
    ):
        """
        Initialize a `Compiler`.

        Args:
            function (Callable):
                function to compile

            parameter_encryption_statuses (Dict[str, Union[str, EncryptionStatus]]):
                encryption statuses of the parameters of `function`

            configuration (Optional[CompilationConfiguration], default = None):
                configuration to use for compilation (a default one is created if None)

            artifacts (Optional[CompilationArtifacts], default = None):
                artifacts to store information about compilation (created if None)

        Raises:
            ValueError:
                if an encryption status is missing for one of the function's parameters
        """
        signature = inspect.signature(function)

        # every parameter of the function needs an encryption status
        missing_args = list(signature.parameters)
        for arg in parameter_encryption_statuses.keys():
            if arg in signature.parameters:
                missing_args.remove(arg)

        if len(missing_args) != 0:
            # build a human readable listing of the missing names (e.g., "'x', 'y' and 'z'")
            parameter_str = repr(missing_args[0])
            for arg in missing_args[1:-1]:
                parameter_str += f", {repr(arg)}"
            if len(missing_args) != 1:
                parameter_str += f" and {repr(missing_args[-1])}"
            raise ValueError(
                f"Encryption status{'es' if len(missing_args) > 1 else ''} "
                f"of parameter{'s' if len(missing_args) > 1 else ''} "
                f"{parameter_str} of function '{function.__name__}' "
                f"{'are' if len(missing_args) > 1 else 'is'} not provided"
            )

        # NOTE(review): statuses for names that are not parameters of the function are
        # silently dropped, and this mutates the dictionary passed by the caller
        additional_args = parameter_encryption_statuses.keys() - signature.parameters.keys()
        for arg in additional_args:
            del parameter_encryption_statuses[arg]

        self.function = function  # type: ignore
        self.parameter_encryption_statuses = {
            param: EncryptionStatus(status.lower())
            for param, status in parameter_encryption_statuses.items()
        }
        self.configuration = (
            configuration if configuration is not None else CompilationConfiguration()
        )
        self.artifacts = artifacts if artifacts is not None else CompilationArtifacts()
        self.inputset = []
        self.graph = None

        self.artifacts.add_source_code(function)
        for param, encryption_status in parameter_encryption_statuses.items():
            self.artifacts.add_parameter_encryption_status(param, encryption_status)

    def __call__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Union[
        np.bool_,
        np.integer,
        np.floating,
        np.ndarray,
        Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray], ...],
    ]:
        """
        Evaluate the function on a sample while recording the sample into the inputset.

        Args:
            *args (Any):
                positional arguments of the function

        Returns:
            evaluation of the traced computation graph on `args`

        Raises:
            RuntimeError:
                if keyword arguments are used
        """
        if len(kwargs) != 0:
            raise RuntimeError(
                f"Calling function '{self.function.__name__}' with kwargs is not supported"
            )

        # a multi-parameter call is recorded as a tuple sample
        sample = args[0] if len(args) == 1 else args

        if self.graph is None:
            # the very first call traces the function using this sample
            self._trace(sample)
        assert self.graph is not None

        self.inputset.append(sample)
        return self.graph(*args)

    def _trace(self, sample: Union[Any, Tuple[Any, ...]]):
        """
        Trace the function and fuse the resulting graph with a sample input.

        Args:
            sample (Union[Any, Tuple[Any, ...]]):
                sample to use for tracing
        """
        # pair each element of the sample with its parameter name and status;
        # single-parameter functions store a bare sample, so wrap it in a 1-tuple
        parameters = {
            param: Value.of(arg, is_encrypted=(status == EncryptionStatus.ENCRYPTED))
            for arg, (param, status) in zip(
                sample if len(self.parameter_encryption_statuses) > 1 else (sample,),
                self.parameter_encryption_statuses.items(),
            )
        }

        self.graph = Tracer.trace(self.function, parameters)
        self.artifacts.add_graph("initial", self.graph)

        fuse(self.graph, self.artifacts)

    def _evaluate(
        self,
        action: str,
        inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]],
    ):
        """
        Trace, fuse, measure bounds, and update values in the resulting graph in one go.

        Args:
            action (str):
                action being performed (e.g., "trace", "compile")

            inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
                optional inputset to extend accumulated inputset before bounds measurement

        Raises:
            RuntimeError:
                if there is no sample at all to trace the function with
        """
        if inputset is not None:
            for input_ in inputset:
                self.inputset.append(input_)

        if self.graph is None:
            # the function hasn't been traced yet; use the first sample to do so
            try:
                first_sample = next(iter(self.inputset))
            except StopIteration as error:
                raise RuntimeError(
                    f"{action} function '{self.function.__name__}' "
                    f"without an inputset is not supported"
                ) from error
            self._trace(first_sample)
        assert self.graph is not None

        bounds = self.graph.measure_bounds(self.inputset)
        self.artifacts.add_final_graph_bounds(bounds)

        self.graph.update_with_bounds(bounds)
        self.artifacts.add_graph("final", self.graph)

    def trace(
        self,
        inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
        show_graph: bool = False,
    ) -> Graph:
        """
        Trace the function using an inputset.

        Args:
            inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
                optional inputset to extend accumulated inputset before bounds measurement

            show_graph (bool, default = False):
                whether to print the computation graph

        Returns:
            Graph:
                computation graph representing the function prior to MLIR conversion
        """
        try:
            self._evaluate("Tracing", inputset)
            assert self.graph is not None

            if show_graph:
                graph = self.graph.format()
                longest_line = max([len(line) for line in graph.split("\n")])

                # clamp the separator width to the terminal (or 80 columns without one)
                try:  # pragma: no cover
                    # this branch cannot be covered
                    # because `os.get_terminal_size()`
                    # raises an exception during tests
                    columns, _ = os.get_terminal_size()
                    if columns == 0:
                        columns = min(longest_line, 80)
                    else:
                        columns = min(longest_line, columns)
                except OSError:
                    columns = min(longest_line, 80)

                print()
                print("Computation Graph")
                print("-" * columns)
                print(graph)
                print("-" * columns)
                print()

            return self.graph
        except Exception:  # pragma: no cover
            # this branch is reserved for unexpected issues and hence it shouldn't be tested
            # if it could be tested, we would have fixed the underlying issue

            # if the user desires so,
            # we need to export all the information we have about the compilation
            if self.configuration.dump_artifacts_on_unexpected_failures:
                self.artifacts.export()
                traceback_path = self.artifacts.output_directory.joinpath("traceback.txt")
                with open(traceback_path, "w", encoding="utf-8") as f:
                    f.write(traceback.format_exc())
            raise

    def compile(
        self,
        inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
        show_graph: bool = False,
        show_mlir: bool = False,
    ) -> Circuit:
        """
        Compile the function using an inputset.

        Args:
            inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
                optional inputset to extend accumulated inputset before bounds measurement

            show_graph (bool, default = False):
                whether to print the computation graph

            show_mlir (bool, default = False):
                whether to print the compiled mlir

        Returns:
            Circuit:
                compiled circuit
        """
        try:
            self._evaluate("Compiling", inputset)
            assert self.graph is not None

            mlir = GraphConverter.convert(self.graph)
            self.artifacts.add_mlir_to_compile(mlir)

            if show_graph or show_mlir:
                # both sections share a separator sized to the longest printed line
                graph = self.graph.format() if show_graph else ""
                longest_graph_line = max([len(line) for line in graph.split("\n")])
                longest_mlir_line = max([len(line) for line in mlir.split("\n")])
                longest_line = max(longest_graph_line, longest_mlir_line)

                try:  # pragma: no cover
                    # this branch cannot be covered
                    # because `os.get_terminal_size()`
                    # raises an exception during tests
                    columns, _ = os.get_terminal_size()
                    if columns == 0:
                        columns = min(longest_line, 80)
                    else:
                        columns = min(longest_line, columns)
                except OSError:
                    columns = min(longest_line, 80)

                if show_graph:
                    print()
                    print("Computation Graph")
                    print("-" * columns)
                    print(graph)
                    print("-" * columns)
                    print()

                if show_mlir:
                    print("\n" if not show_graph else "", end="")
                    print("MLIR")
                    print("-" * columns)
                    print(mlir)
                    print("-" * columns)
                    print()

            engine = CompilerEngine()
            if self.configuration.use_insecure_key_cache:
                # the insecure key cache requires unsafe features to be enabled explicitly
                assert self.configuration.enable_unsafe_features
                location = CompilationConfiguration.insecure_key_cache_location()
                engine.compile_fhe(mlir, unsecure_key_set_cache_path=location)
            else:
                # this branch is not covered because all tests use key cache to speed up tests
                engine.compile_fhe(mlir)  # pragma: no cover

            return Circuit(self.graph, engine)
        except Exception:  # pragma: no cover
            # this branch is reserved for unexpected issues and hence it shouldn't be tested
            # if it could be tested, we would have fixed the underlying issue

            # if the user desires so,
            # we need to export all the information we have about the compilation
            if self.configuration.dump_artifacts_on_unexpected_failures:
                self.artifacts.export()
                traceback_path = self.artifacts.output_directory.joinpath("traceback.txt")
                with open(traceback_path, "w", encoding="utf-8") as f:
                    f.write(traceback.format_exc())
            raise

View File

@@ -0,0 +1,42 @@
"""
Declaration of `CompilationConfiguration` class.
"""
from typing import Optional
# module-level override for the insecure key cache location (None means not configured)
_INSECURE_KEY_CACHE_LOCATION: Optional[str] = None


class CompilationConfiguration:
    """
    CompilationConfiguration class, to allow the compilation process to be customized.
    """

    # whether to export compilation artifacts when an unexpected failure happens
    dump_artifacts_on_unexpected_failures: bool
    # whether unsafe features (e.g., the insecure key cache) may be used
    enable_unsafe_features: bool
    # whether to use the insecure key cache (requires `enable_unsafe_features`)
    use_insecure_key_cache: bool

    def __init__(
        self,
        dump_artifacts_on_unexpected_failures: bool = True,
        enable_unsafe_features: bool = False,
        use_insecure_key_cache: bool = False,
    ):
        """
        Initialize a `CompilationConfiguration`.

        Raises:
            RuntimeError:
                if `use_insecure_key_cache` is requested without `enable_unsafe_features`
        """
        # validate before assigning anything, so an invalid combination
        # cannot leave a partially configured object behind
        if not enable_unsafe_features and use_insecure_key_cache:
            raise RuntimeError("Insecure key cache cannot be used without enabling unsafe features")

        self.dump_artifacts_on_unexpected_failures = dump_artifacts_on_unexpected_failures
        self.enable_unsafe_features = enable_unsafe_features
        self.use_insecure_key_cache = use_insecure_key_cache

    @staticmethod
    def insecure_key_cache_location() -> Optional[str]:
        """
        Get insecure key cache location.

        Returns:
            Optional[str]:
                insecure key cache location if configured, None otherwise
        """
        return _INSECURE_KEY_CACHE_LOCATION

View File

@@ -0,0 +1,105 @@
"""
Declaration of `compiler` decorator.
"""
from typing import Any, Callable, Iterable, Mapping, Optional, Tuple, Union
from ..representation import Graph
from .artifacts import CompilationArtifacts
from .circuit import Circuit
from .compiler import Compiler, EncryptionStatus
from .configuration import CompilationConfiguration
def compiler(
    parameters: Mapping[str, EncryptionStatus],
    configuration: Optional[CompilationConfiguration] = None,
    artifacts: Optional[CompilationArtifacts] = None,
):
    """
    Provide an easy interface for compilation.

    Args:
        parameters (Dict[str, EncryptionStatus]):
            encryption statuses of the parameters of the function to compile

        configuration(Optional[CompilationConfiguration], default = None):
            configuration to use for compilation

        artifacts (Optional[CompilationArtifacts], default = None):
            artifacts to store information about compilation

    Returns:
        a decorator that wraps a function into a `Compilable` object
    """

    def decoration(function: Callable):
        class Compilable:
            """
            Compilable class, to wrap a function and provide methods to trace and compile it.
            """

            # the wrapped function
            function: Callable
            # compiler bound to the wrapped function
            compiler: Compiler

            def __init__(self, function: Callable):
                self.function = function  # type: ignore
                # `dict(parameters)` takes a copy so the caller's mapping stays usable
                self.compiler = Compiler(
                    self.function,
                    dict(parameters),
                    configuration,
                    artifacts,
                )

            def __call__(self, *args, **kwargs) -> Any:
                # feed the sample to the compiler (which accumulates the inputset),
                # then return the direct evaluation of the wrapped function
                self.compiler(*args, **kwargs)
                return self.function(*args, **kwargs)

            def trace(
                self,
                inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
                show_graph: bool = False,
            ) -> Graph:
                """
                Trace the function into computation graph.

                Args:
                    inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
                        optional inputset to extend accumulated inputset before bounds measurement

                    show_graph (bool, default = False):
                        whether to print the computation graph

                Returns:
                    Graph:
                        computation graph representing the function prior to MLIR conversion
                """
                return self.compiler.trace(inputset, show_graph)

            def compile(
                self,
                inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
                show_graph: bool = False,
                show_mlir: bool = False,
            ) -> Circuit:
                """
                Compile the function into a circuit.

                Args:
                    inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
                        optional inputset to extend accumulated inputset before bounds measurement

                    show_graph (bool, default = False):
                        whether to print the computation graph

                    show_mlir (bool, default = False):
                        whether to print the compiled mlir

                Returns:
                    Circuit:
                        compiled circuit
                """
                return self.compiler.compile(inputset, show_graph, show_mlir)

        return Compilable(function)

    return decoration

View File

@@ -0,0 +1,395 @@
"""
Declaration of various functions and constants related to compilation.
"""
from copy import deepcopy
from typing import Dict, Iterable, List, Optional, Set, Tuple
import networkx as nx
from ..dtypes import Float, Integer
from ..representation import Graph, Node, Operation
from .artifacts import CompilationArtifacts
def fuse(graph: Graph, artifacts: Optional[CompilationArtifacts] = None):
    """
    Fuse appropriate subgraphs in a graph to a single Operation.Generic node.

    Args:
        graph (Graph):
            graph to search and update

        artifacts (Optional[CompilationArtifacts], default = None):
            compilation artifacts to store information about the fusing process
    """
    nx_graph = graph.graph
    processed_terminal_nodes: Set[Node] = set()
    while True:
        # look for the next float subgraph ending in a single integer-output node
        float_subgraph_to_fuse = find_float_subgraph_with_unique_terminal_node(
            nx_graph,
            processed_terminal_nodes,
        )
        if float_subgraph_to_fuse is None:
            break

        all_nodes, start_nodes, terminal_node = float_subgraph_to_fuse
        # remember the terminal node so the search won't return the same subgraph again
        processed_terminal_nodes.add(terminal_node)

        subgraph_conversion_result = convert_subgraph_to_subgraph_node(
            nx_graph,
            all_nodes,
            start_nodes,
            terminal_node,
        )
        if subgraph_conversion_result is None:
            # this subgraph cannot be fused; try the next terminal node
            continue

        fused_node, node_before_subgraph = subgraph_conversion_result
        nx_graph.add_node(fused_node)

        if terminal_node in graph.output_nodes.values():
            # the terminal node is a graph output, so the output mapping
            # must be redirected to the newly created fused node
            output_node_to_idx: Dict[Node, List[int]] = {
                out_node: [] for out_node in graph.output_nodes.values()
            }
            for output_idx, output_node in graph.output_nodes.items():
                output_node_to_idx[output_node].append(output_idx)

            for output_idx in output_node_to_idx.get(terminal_node, []):
                graph.output_nodes[output_idx] = fused_node

        # move the outgoing edges of the terminal node onto the fused node
        terminal_node_succ = list(nx_graph.successors(terminal_node))
        for succ in terminal_node_succ:
            succ_edge_data = deepcopy(nx_graph.get_edge_data(terminal_node, succ))
            for edge_key, edge_data in succ_edge_data.items():
                nx_graph.remove_edge(terminal_node, succ, key=edge_key)
                new_edge_data = deepcopy(edge_data)
                nx_graph.add_edge(fused_node, succ, key=edge_key, **new_edge_data)

        # connect the single variable input of the subgraph to the fused node
        nx_graph.add_edge(node_before_subgraph, fused_node, input_idx=0)

        # the nodes that were fused are now unreachable; drop them
        graph.prune_useless_nodes()
        if artifacts is not None:
            artifacts.add_graph("after-fusing-float-operations", graph)
def find_float_subgraph_with_unique_terminal_node(
    nx_graph: nx.MultiDiGraph,
    processed_terminal_nodes: Set[Node],
) -> Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
    """
    Find a subgraph with float computations that end with an integer output.

    Args:
        nx_graph (nx.MultiDiGraph):
            graph to search

        processed_terminal_nodes (Set[Node]):
            set of terminal nodes which have already been searched for float subgraphs

    Returns:
        Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
            None if there are no such subgraphs,
            tuple containing all nodes in the subgraph, start nodes of the subgraph,
            and terminal node of the subgraph otherwise
    """
    # a terminal node is one that consumes at least one float but produces an integer
    terminal_nodes = (
        node
        for node in nx_graph.nodes()
        if (
            node not in processed_terminal_nodes
            and any(isinstance(input.dtype, Float) for input in node.inputs)
            and isinstance(node.output.dtype, Integer)
        )
    )
    try:
        terminal_node = next(terminal_nodes)
    except StopIteration:
        # no unprocessed terminal node remains, so no subgraph to fuse
        return None

    # networkx does not implement lowest common ancestor search for multidigraph, but we only care
    # about parent relationship here and not the meaning of edges, so we can convert our
    # multidigraph to a digraph and use the lca search algorithm (if needed), we create the
    # equivalent digraph here as it will avoid recreating it in a loop. Constant nodes could cause
    # issues in our search, so we remove them.
    equivalent_subgraph_without_constants = nx.DiGraph(nx_graph)
    constant_nodes = [
        node
        for node in equivalent_subgraph_without_constants.nodes()
        if node.operation == Operation.Constant
    ]
    equivalent_subgraph_without_constants.remove_nodes_from(constant_nodes)

    # `Dict[Node, None]` is used as an insertion-ordered set throughout
    all_nodes: Dict[Node, None] = {}

    start_single_int_output_nodes_search_from = terminal_node
    while True:
        # walk upstream from the current search point to the closest integer-output nodes
        all_nodes, start_nodes = find_closest_integer_output_nodes(
            nx_graph,
            [start_single_int_output_nodes_search_from],
            all_nodes,
        )

        variable_start_nodes = [
            start_node for start_node in start_nodes if start_node.operation != Operation.Constant
        ]
        if len(variable_start_nodes) == 1:
            # exactly one variable entry point: the subgraph is fusable as-is
            break

        # find a common ancestor as we need a single variable input node
        # lca == lowest common ancestor
        # lca search only works for node pairs in networkx, so we progressively find the ancestors
        # setting the lca by default to one of the nodes we are searching the lca for
        lca = variable_start_nodes.pop()
        while len(variable_start_nodes) > 0 and lca is not None:
            node_to_find_new_lca = variable_start_nodes.pop()

            # if one of the two nodes is an ancestor of the other,
            # the ancestor itself is their lowest common ancestor
            ancestors_of_lca = nx.ancestors(
                equivalent_subgraph_without_constants,
                lca,
            )
            ancestors_of_node_to_find_new_lca = nx.ancestors(
                equivalent_subgraph_without_constants,
                node_to_find_new_lca,
            )
            lca_is_ancestor_of_node_to_find_new_lca = lca in ancestors_of_node_to_find_new_lca
            node_to_find_new_lca_is_ancestor_of_lca = node_to_find_new_lca in ancestors_of_lca
            if lca_is_ancestor_of_node_to_find_new_lca or node_to_find_new_lca_is_ancestor_of_lca:
                lca = lca if lca_is_ancestor_of_node_to_find_new_lca else node_to_find_new_lca
                continue

            lca = nx.algorithms.lowest_common_ancestors.lowest_common_ancestor(
                equivalent_subgraph_without_constants, lca, node_to_find_new_lca, default=None
            )

        # if subgraph cannot be fused because there is no way to find a common ancestor, break
        if lca is None:
            break

        # add the nodes from the `start_nodes` to `lca`, to `all_nodes`
        all_nodes = add_nodes_from_to(nx_graph, start_nodes, {lca: None}, all_nodes)

        # if `lca` is a valid starting node for fusing break
        if isinstance(lca.output.dtype, Integer):
            # `lca` is the new start node
            start_nodes = {lca: None}
            break

        # otherwise, push a little further
        # (e.g., if there is a node just before, which has an integer output)
        start_single_int_output_nodes_search_from = lca

    return all_nodes, start_nodes, terminal_node
def find_closest_integer_output_nodes(
    nx_graph: nx.MultiDiGraph,
    start_nodes: List[Node],
    all_nodes: Dict[Node, None],
) -> Tuple[Dict[Node, None], Dict[Node, None]]:
    """
    Find the closest upstream integer output nodes to a set of start nodes in a graph.

    Performs a breadth-first walk over the predecessors of `start_nodes`,
    stopping each branch at the first node whose output is an integer.

    Args:
        nx_graph (nx.MultiDiGraph):
            graph to search

        start_nodes (List[Node]):
            nodes from which to start the search

        all_nodes (Dict[Node, None]):
            set of nodes to be extended with visited nodes during the search

    Returns:
        Tuple[Dict[Node, None], Dict[Node, None]]:
            tuple containing extended `all_nodes` and integer output nodes closest to `start_nodes`
    """
    closest_integer_output_nodes: Dict[Node, None] = {}
    seen: Set[Node] = set()

    frontier: Dict[Node, None] = dict.fromkeys(start_nodes)
    while frontier:
        upcoming: Dict[Node, None] = {}
        for current in frontier:
            if current in seen:
                continue
            seen.add(current)
            all_nodes[current] = None
            for predecessor in nx_graph.predecessors(current):
                if isinstance(predecessor.output.dtype, Integer):
                    # integer output reached: this branch of the walk stops here
                    closest_integer_output_nodes[predecessor] = None
                    all_nodes[predecessor] = None
                else:
                    upcoming[predecessor] = None
        frontier = upcoming

    return all_nodes, closest_integer_output_nodes
def add_nodes_from_to(
    nx_graph: nx.MultiDiGraph,
    from_nodes: Iterable[Node],
    to_nodes: Dict[Node, None],
    all_nodes: Dict[Node, None],
) -> Dict[Node, None]:
    """
    Add nodes from `from_nodes` to `to_nodes`, to `all_nodes`.

    Walks the predecessors of `from_nodes` breadth-first and records every node
    encountered until a node of `to_nodes` is reached.

    Args:
        nx_graph (nx.MultiDiGraph):
            graph to traverse

        from_nodes (Iterable[Node]):
            nodes from which extending `all_nodes` start

        to_nodes (Dict[Node, None]):
            nodes to which extending `all_nodes` stop

        all_nodes (Dict[Node, None]):
            nodes to be extended

    Returns:
        Dict[Node, None]:
            extended `all_nodes`
    """
    # the stop nodes themselves are part of the result
    all_nodes.update(to_nodes)

    seen: Set[Node] = set()
    frontier: Dict[Node, None] = dict.fromkeys(from_nodes)
    while frontier:
        upcoming: Dict[Node, None] = {}
        for current in frontier:
            if current in seen:
                continue
            seen.add(current)
            all_nodes[current] = None
            if current in to_nodes:
                # reached a stop node: do not walk past it
                continue
            for predecessor in nx_graph.predecessors(current):
                if predecessor not in to_nodes:
                    upcoming[predecessor] = None
        frontier = upcoming

    return all_nodes
def convert_subgraph_to_subgraph_node(
    nx_graph: nx.MultiDiGraph,
    all_nodes: Dict[Node, None],
    start_nodes: Dict[Node, None],
    terminal_node: Node,
) -> Optional[Tuple[Node, Node]]:
    """
    Convert a subgraph to Operation.Generic node.

    Args:
        nx_graph (nx.MultiDiGraph):
            original networkx graph

        all_nodes (Dict[Node, None]):
            all nodes in the subgraph

        start_nodes (Dict[Node, None]):
            start nodes of the subgraph

        terminal_node (Node):
            terminal node of the subgraph

    Returns:
        Optional[Tuple[Node, Node]]:
            None if the subgraph cannot be fused,
            subgraph node and its predecessor otherwise
    """
    # fusing requires exactly one non-constant entry point
    variable_input_nodes = [node for node in start_nodes if node.operation != Operation.Constant]
    if len(variable_input_nodes) != 1:
        return None

    variable_input_node = variable_input_nodes[0]
    if not subgraph_can_be_fused(all_nodes, variable_input_node):
        return None

    # copy the original graph and keep only the nodes of the subgraph
    nx_subgraph = nx.MultiDiGraph(nx_graph)
    nodes_to_remove = [node for node in nx_subgraph.nodes() if node not in all_nodes]
    nx_subgraph.remove_nodes_from(nodes_to_remove)

    # replace the variable input by a fresh input node with the same output value
    subgraph_variable_input_node = Node.input("input", deepcopy(variable_input_node.output))
    nx_subgraph.add_node(subgraph_variable_input_node)

    # rewire every edge that left the original variable input (within the subgraph)
    # so it now leaves the fresh input node instead
    variable_input_node_successors = {
        node: None for node in all_nodes if node in nx_graph.succ[variable_input_node]
    }
    for successor in variable_input_node_successors:
        edges = deepcopy(nx_subgraph.get_edge_data(variable_input_node, successor))
        for edge_key, edge_data in edges.items():
            nx_subgraph.remove_edge(variable_input_node, successor, key=edge_key)
            new_edge_data = deepcopy(edge_data)
            nx_subgraph.add_edge(
                subgraph_variable_input_node,
                successor,
                key=edge_key,
                **new_edge_data,
            )

    # wrap the subgraph into a single generic node whose evaluation
    # runs the subgraph and returns the terminal node's result
    subgraph = Graph(nx_subgraph, {0: subgraph_variable_input_node}, {0: terminal_node})
    subgraph_node = Node.generic(
        "subgraph",
        subgraph_variable_input_node.inputs,
        terminal_node.output,
        lambda x, subgraph, terminal_node: subgraph.evaluate(x)[terminal_node],
        kwargs={
            "subgraph": subgraph,
            "terminal_node": terminal_node,
        },
    )
    return subgraph_node, variable_input_node
def subgraph_can_be_fused(
    all_nodes: Dict[Node, None],
    variable_input_node: Node,
) -> bool:
    """
    Determine if a subgraph can be fused.

    e.g.,

    shuffling or reshaping a tensor make fusing impossible as there should be a one-to-one mapping
    between each cell of the input and each cell of the output for table lookups

    Args:
        all_nodes (Dict[Node, None]):
            all nodes in the subgraph

        variable_input_node (Node):
            variable input node to the subgraph

    Returns:
        bool:
            True if subgraph can be fused,
            False otherwise
    """
    input_size = variable_input_node.output.size
    # a constant larger than the variable input would break the one-to-one mapping
    if any(
        node.operation == Operation.Constant and node.output.size > input_size
        for node in all_nodes
    ):
        return False

    # every non-constant node must keep the exact shape of the variable input
    input_shape = variable_input_node.output.shape
    return all(
        node.output.shape == input_shape
        for node in all_nodes
        if node.operation != Operation.Constant
    )