mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat: implement compilation module
This commit is contained in:
9
concrete/numpy/compilation/__init__.py
Normal file
9
concrete/numpy/compilation/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Declaration of `concrete.numpy.compilation` namespace.
|
||||
"""
|
||||
|
||||
from .artifacts import CompilationArtifacts
|
||||
from .circuit import Circuit
|
||||
from .compiler import Compiler, EncryptionStatus
|
||||
from .configuration import CompilationConfiguration
|
||||
from .decorator import compiler
|
||||
207
concrete/numpy/compilation/artifacts.py
Normal file
207
concrete/numpy/compilation/artifacts.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""
|
||||
Declaration of `CompilationArtifacts` class.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import platform
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from ..representation import Graph, Node
|
||||
|
||||
# default location that `CompilationArtifacts.export` writes into
DEFAULT_OUTPUT_DIRECTORY: Path = Path(".artifacts")


class CompilationArtifacts:
    """
    CompilationArtifacts class, to export information about the compilation process.

    Accumulates the source code of the compiled function, parameter encryption
    statuses, textual/drawn representations of the computation graphs, measured
    bounds of the final graph, and the MLIR to compile. `export` dumps everything
    into `self.output_directory`.
    """

    # directory written by `export` (wiped and recreated on every export)
    output_directory: Path

    # source code of the function being compiled, None until added
    source_code: Optional[str]
    # parameter name -> encryption status string
    parameter_encryption_statuses: Dict[str, str]

    # graph name -> path of its drawing (a PNG file produced by `Graph.draw`)
    drawings_of_graphs: Dict[str, str]
    # graph name -> textual representation of the graph
    textual_representations_of_graphs: Dict[str, str]

    # the most recently added graph (via `add_graph`)
    final_graph: Optional[Graph]
    # bounds of `final_graph`, keyed by node; each entry holds 'min' and 'max'
    bounds_of_the_final_graph: Optional[Dict[Node, Dict[str, Any]]]

    # textual representation of the MLIR passed to the compiler
    mlir_to_compile: Optional[str]

    def __init__(self, output_directory: Union[str, Path] = DEFAULT_OUTPUT_DIRECTORY):
        self.output_directory = Path(output_directory)

        self.source_code = None
        self.parameter_encryption_statuses = {}

        self.drawings_of_graphs = {}
        self.textual_representations_of_graphs = {}

        self.final_graph = None
        self.bounds_of_the_final_graph = None

        self.mlir_to_compile = None

    def add_source_code(self, function: Union[str, Callable]):
        """
        Add source code of the function being compiled.

        Args:
            function (Union[str, Callable]):
                either the source code of the function or the function itself
        """

        try:
            self.source_code = (
                function if isinstance(function, str) else inspect.getsource(function)
            )
        except OSError:  # pragma: no cover
            # e.g., functions defined interactively have no retrievable source
            self.source_code = "unavailable"

    def add_parameter_encryption_status(self, name: str, encryption_status: str):
        """
        Add parameter encryption status of a parameter of the function being compiled.

        Args:
            name (str):
                name of the parameter

            encryption_status (str):
                encryption status of the parameter
        """

        self.parameter_encryption_statuses[name] = encryption_status

    def add_graph(self, name: str, graph: Graph):
        """
        Add a representation of the function being compiled.

        Every call also updates `self.final_graph`, so the last graph added is
        the one considered final.

        Args:
            name (str):
                name of the graph (e.g., initial, optimized, final)

            graph (Graph):
                a representation of the function being compiled
        """

        textual_representation = graph.format()
        self.textual_representations_of_graphs[name] = textual_representation

        try:
            drawing = graph.draw()
            self.drawings_of_graphs[name] = str(drawing)
        except ImportError as error:  # pragma: no cover
            # drawing requires pygraphviz, which is optional; skip silently if absent
            if "pygraphviz" in str(error):
                pass
            else:
                raise error

        self.final_graph = graph

    def add_final_graph_bounds(self, bounds: Dict[Node, Dict[str, Any]]):
        """
        Add bounds of the latest computation graph.

        Args:
            bounds (Dict[Node, Dict[str, Any]]):
                bounds of the latest computation graph
        """

        # a graph must have been added before its bounds can be recorded
        assert self.final_graph is not None
        self.bounds_of_the_final_graph = bounds

    def add_mlir_to_compile(self, mlir: str):
        """
        Add textual representation of the resulting MLIR.

        Args:
            mlir (str):
                textual representation of the resulting MLIR
        """

        self.mlir_to_compile = mlir

    def export(self):
        """
        Export the collected information to `self.output_directory`.

        The directory is deleted and recreated first, so each export starts clean.
        """

        output_directory = self.output_directory
        if output_directory.exists():
            shutil.rmtree(output_directory)
        output_directory.mkdir(parents=True)

        # platform and interpreter information
        with open(output_directory.joinpath("environment.txt"), "w", encoding="utf-8") as f:
            f.write(f"{platform.platform()} {platform.version()}\n")
            f.write(f"Python {platform.python_version()}\n")

        # installed packages, converted from `pip list` output to requirements format
        with open(output_directory.joinpath("requirements.txt"), "w", encoding="utf-8") as f:
            # example `pip list` output

            # Package                       Version
            # ----------------------------- ---------
            # alabaster                     0.7.12
            # appdirs                       1.4.4
            # ...                           ...
            # ...                           ...
            # wrapt                         1.12.1
            # zipp                          3.5.0

            pip_process = subprocess.run(
                ["pip", "--disable-pip-version-check", "list"], stdout=subprocess.PIPE, check=True
            )
            dependencies = iter(pip_process.stdout.decode("utf-8").split("\n"))

            # skip 'Package ... Version' line
            next(dependencies)

            # skip '------- ... -------' line
            next(dependencies)

            for dependency in dependencies:
                tokens = [token for token in dependency.split(" ") if token != ""]
                if len(tokens) == 0:
                    continue

                name = tokens[0]
                version = tokens[1]

                f.write(f"{name}=={version}\n")

        if self.source_code is not None:
            with open(output_directory.joinpath("function.txt"), "w", encoding="utf-8") as f:
                f.write(self.source_code)

        if len(self.parameter_encryption_statuses) > 0:
            with open(output_directory.joinpath("parameters.txt"), "w", encoding="utf-8") as f:
                for name, parameter in self.parameter_encryption_statuses.items():
                    f.write(f"{name} :: {parameter}\n")

        # drawings and textual representations are numbered in insertion order
        drawings = self.drawings_of_graphs.items()
        for index, (name, drawing_filename) in enumerate(drawings):
            identifier = f"{index + 1}.{name}.graph"
            shutil.copy(drawing_filename, output_directory.joinpath(f"{identifier}.png"))

        textual_representations = self.textual_representations_of_graphs.items()
        for index, (name, representation) in enumerate(textual_representations):
            identifier = f"{index + 1}.{name}.graph"
            with open(output_directory.joinpath(f"{identifier}.txt"), "w", encoding="utf-8") as f:
                f.write(f"{representation}\n")

        if self.bounds_of_the_final_graph is not None:
            assert self.final_graph is not None
            with open(output_directory.joinpath("bounds.txt"), "w", encoding="utf-8") as f:
                for index, node in enumerate(nx.topological_sort(self.final_graph.graph)):
                    # NOTE(review): assumes every node of the final graph has recorded
                    # bounds; `.get` would return None otherwise and the write below
                    # would fail — confirm against `measure_bounds`
                    bounds = self.bounds_of_the_final_graph.get(node)
                    f.write(f"%{index} :: [{bounds['min']}, {bounds['max']}]\n")

        if self.mlir_to_compile is not None:
            assert self.final_graph is not None
            with open(output_directory.joinpath("mlir.txt"), "w", encoding="utf-8") as f:
                f.write(f"{self.mlir_to_compile}\n")
||||
146
concrete/numpy/compilation/circuit.py
Normal file
146
concrete/numpy/compilation/circuit.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
Declaration of `Circuit` class.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union, cast
|
||||
|
||||
import numpy as np
|
||||
from concrete.compiler import CompilerEngine
|
||||
|
||||
from ..dtypes import Integer
|
||||
from ..internal.utils import assert_that
|
||||
from ..representation import Graph
|
||||
from ..values import Value
|
||||
|
||||
|
||||
class Circuit:
    """
    Circuit class, to combine computation graph and compiler engine into a single object.
    """

    # computation graph the circuit was compiled from
    graph: Graph
    # compiled engine that performs the homomorphic evaluation
    engine: CompilerEngine

    def __init__(self, graph: Graph, engine: CompilerEngine):
        self.graph = graph
        self.engine = engine

    def __str__(self):
        # textual representation of the underlying computation graph
        return self.graph.format()

    def draw(
        self,
        show: bool = False,
        horizontal: bool = False,
        save_to: Optional[Union[Path, str]] = None,
    ) -> Path:
        """
        Draw the `self.graph` and optionally save/show the drawing.

        note that this function requires the python `pygraphviz` package
        which itself requires the installation of `graphviz` packages
        see https://pygraphviz.github.io/documentation/stable/install.html

        Args:
            show (bool, default = False):
                whether to show the drawing using matplotlib or not

            horizontal (bool, default = False):
                whether to draw horizontally or not

            save_to (Optional[Path], default = None):
                path to save the drawing
                a temporary file will be used if it's None

        Returns:
            Path:
                path to the saved drawing
        """

        return self.graph.draw(show, horizontal, save_to)

    def run(
        self,
        *args: Union[int, np.ndarray],
    ) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
        """
        Encrypt inputs, evaluate the circuit, and decrypt the outputs in one go.

        Args:
            *args (List[Union[int, numpy.ndarray]]):
                inputs to the engine

        Returns:
            Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
                result of the homomorphic evaluation

        Raises:
            ValueError:
                if the number of arguments is wrong, or an argument's type,
                range, or shape does not match what the circuit expects
        """

        if len(args) != len(self.graph.input_nodes):
            raise ValueError(f"Expected {len(self.graph.input_nodes)} inputs but got {len(args)}")

        sanitized_args = {}

        for index, node in self.graph.input_nodes.items():
            arg = args[index]
            # scalars must be python/numpy integers; arrays must have an integer dtype
            is_valid = isinstance(arg, (int, np.integer)) or (
                isinstance(arg, np.ndarray) and np.issubdtype(arg.dtype, np.integer)
            )

            expected_value = node.output

            assert_that(isinstance(expected_value.dtype, Integer))
            expected_dtype = cast(Integer, expected_value.dtype)

            if is_valid:
                expected_min = expected_dtype.min()
                expected_max = expected_dtype.max()
                expected_shape = expected_value.shape

                actual_min = arg if isinstance(arg, int) else arg.min()
                actual_max = arg if isinstance(arg, int) else arg.max()
                actual_shape = () if isinstance(arg, int) else arg.shape

                # the argument must fit the range and shape the circuit was compiled for
                is_valid = (
                    actual_min >= expected_min
                    and actual_max <= expected_max
                    and actual_shape == expected_shape
                )

                if is_valid:
                    # NOTE(review): array inputs are cast to uint8 — presumably the
                    # engine currently accepts at most 8-bit inputs; confirm
                    sanitized_args[index] = arg if isinstance(arg, int) else arg.astype(np.uint8)

            if not is_valid:
                actual_value = Value.of(arg, is_encrypted=expected_value.is_encrypted)
                raise ValueError(
                    f"Expected argument {index} to be {expected_value} but it's {actual_value}"
                )

        results = self.engine.run(*[sanitized_args[i] for i in range(len(sanitized_args))])
        if not isinstance(results, tuple):
            results = (results,)

        sanitized_results: List[Union[int, np.ndarray]] = []

        for index, node in self.graph.output_nodes.items():
            expected_value = node.output
            assert_that(isinstance(expected_value.dtype, Integer))

            expected_dtype = cast(Integer, expected_value.dtype)
            n = expected_dtype.bit_width

            # reduce modulo 2^n, then reinterpret as signed two's complement if needed
            result = results[index] % (2 ** n)
            if expected_dtype.is_signed:
                if isinstance(result, int):
                    sanititzed_result = result if result < (2 ** (n - 1)) else result - (2 ** n)
                    sanitized_results.append(sanititzed_result)
                else:
                    result = result.astype(np.longlong)  # to prevent overflows in numpy
                    sanititzed_result = np.where(result < (2 ** (n - 1)), result, result - (2 ** n))
                    sanitized_results.append(sanititzed_result.astype(np.int8))
            else:
                sanitized_results.append(
                    result if isinstance(result, int) else result.astype(np.uint8)
                )

        # single-output circuits return the value directly instead of a 1-tuple
        return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results)
||||
352
concrete/numpy/compilation/compiler.py
Normal file
352
concrete/numpy/compilation/compiler.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""
|
||||
Declaration of `Compiler` class.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import traceback
|
||||
from enum import Enum, unique
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from concrete.compiler import CompilerEngine
|
||||
|
||||
from ..mlir import GraphConverter
|
||||
from ..representation import Graph
|
||||
from ..tracing import Tracer
|
||||
from ..values import Value
|
||||
from .artifacts import CompilationArtifacts
|
||||
from .circuit import Circuit
|
||||
from .configuration import CompilationConfiguration
|
||||
from .utils import fuse
|
||||
|
||||
|
||||
@unique
class EncryptionStatus(str, Enum):
    """
    EncryptionStatus enum, to represent encryption status of parameters.

    Inherits from `str` as well, so members compare equal to their lowercase
    string values (e.g., `EncryptionStatus.CLEAR == "clear"`).
    """

    CLEAR = "clear"
    ENCRYPTED = "encrypted"
|
||||
|
||||
|
||||
class Compiler:
    """
    Compiler class, to glue the compilation pipeline.
    """

    # function being compiled
    function: Callable
    # encryption status of each parameter of `function`
    parameter_encryption_statuses: Dict[str, EncryptionStatus]

    # configuration used during compilation
    configuration: CompilationConfiguration
    # artifacts collected during compilation
    artifacts: CompilationArtifacts

    # accumulated sample inputs, used for bounds measurement
    inputset: List[Any]
    # traced (and fused) computation graph; None until the first trace
    graph: Optional[Graph]

    def __init__(
        self,
        function: Callable,
        parameter_encryption_statuses: Dict[str, Union[str, EncryptionStatus]],
        configuration: Optional[CompilationConfiguration] = None,
        artifacts: Optional[CompilationArtifacts] = None,
    ):
        """
        Initialize the compiler.

        Raises:
            ValueError:
                if any parameter of `function` is missing an encryption status
        """

        signature = inspect.signature(function)

        # every parameter of `function` must have an encryption status
        missing_args = list(signature.parameters)
        for arg in parameter_encryption_statuses.keys():
            if arg in signature.parameters:
                missing_args.remove(arg)

        if len(missing_args) != 0:
            # build a human-readable list such as 'a', 'b' and 'c'
            parameter_str = repr(missing_args[0])
            for arg in missing_args[1:-1]:
                parameter_str += f", {repr(arg)}"
            if len(missing_args) != 1:
                parameter_str += f" and {repr(missing_args[-1])}"

            raise ValueError(
                f"Encryption status{'es' if len(missing_args) > 1 else ''} "
                f"of parameter{'s' if len(missing_args) > 1 else ''} "
                f"{parameter_str} of function '{function.__name__}' "
                f"{'are' if len(missing_args) > 1 else 'is'} not provided"
            )

        # statuses for names that are not parameters of `function` are dropped
        # (note: this mutates the caller's dict)
        additional_args = parameter_encryption_statuses.keys() - signature.parameters.keys()
        for arg in additional_args:
            del parameter_encryption_statuses[arg]

        self.function = function  # type: ignore
        self.parameter_encryption_statuses = {
            param: EncryptionStatus(status.lower())
            for param, status in parameter_encryption_statuses.items()
        }

        self.configuration = (
            configuration if configuration is not None else CompilationConfiguration()
        )
        self.artifacts = artifacts if artifacts is not None else CompilationArtifacts()

        self.inputset = []
        self.graph = None

        self.artifacts.add_source_code(function)
        for param, encryption_status in parameter_encryption_statuses.items():
            self.artifacts.add_parameter_encryption_status(param, encryption_status)

    def __call__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Union[
        np.bool_,
        np.integer,
        np.floating,
        np.ndarray,
        Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray], ...],
    ]:
        """
        Record the sample into the inputset and evaluate the traced graph on it.

        Tracing happens lazily on the first call.

        Raises:
            RuntimeError:
                if called with keyword arguments (not supported)
        """
        if len(kwargs) != 0:
            raise RuntimeError(
                f"Calling function '{self.function.__name__}' with kwargs is not supported"
            )

        # a single argument is stored as-is; multiple arguments as a tuple
        sample = args[0] if len(args) == 1 else args

        if self.graph is None:
            self._trace(sample)
        assert self.graph is not None

        self.inputset.append(sample)
        return self.graph(*args)

    def _trace(self, sample: Union[Any, Tuple[Any, ...]]):
        """
        Trace the function and fuse the resulting graph with a sample input.

        Args:
            sample (Union[Any, Tuple[Any, ...]]):
                sample to use for tracing
        """

        parameters = {
            param: Value.of(arg, is_encrypted=(status == EncryptionStatus.ENCRYPTED))
            for arg, (param, status) in zip(
                # a single-parameter function receives the sample itself, not a tuple
                sample if len(self.parameter_encryption_statuses) > 1 else (sample,),
                self.parameter_encryption_statuses.items(),
            )
        }

        self.graph = Tracer.trace(self.function, parameters)
        self.artifacts.add_graph("initial", self.graph)

        fuse(self.graph, self.artifacts)

    def _evaluate(
        self,
        action: str,
        inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]],
    ):
        """
        Trace, fuse, measure bounds, and update values in the resulting graph in one go.

        Args:
            action (str):
                action being performed (e.g., "trace", "compile"); used in error messages

            inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
                optional inputset to extend accumulated inputset before bounds measurement

        Raises:
            RuntimeError:
                if no graph exists yet and the accumulated inputset is empty
        """

        if inputset is not None:
            for input_ in inputset:
                self.inputset.append(input_)

        if self.graph is None:
            try:
                first_sample = next(iter(self.inputset))
            except StopIteration as error:
                raise RuntimeError(
                    f"{action} function '{self.function.__name__}' "
                    f"without an inputset is not supported"
                ) from error

            self._trace(first_sample)
        assert self.graph is not None

        bounds = self.graph.measure_bounds(self.inputset)
        self.artifacts.add_final_graph_bounds(bounds)

        self.graph.update_with_bounds(bounds)
        self.artifacts.add_graph("final", self.graph)

    def trace(
        self,
        inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
        show_graph: bool = False,
    ) -> Graph:
        """
        Trace the function using an inputset.

        Args:
            inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
                optional inputset to extend accumulated inputset before bounds measurement

            show_graph (bool, default = False):
                whether to print the computation graph

        Returns:
            Graph:
                computation graph representing the function prior to MLIR conversion
        """

        try:
            self._evaluate("Tracing", inputset)
            assert self.graph is not None

            if show_graph:
                graph = self.graph.format()
                longest_line = max([len(line) for line in graph.split("\n")])

                # pick a separator width fitting both the graph and the terminal
                try:  # pragma: no cover

                    # this branch cannot be covered
                    # because `os.get_terminal_size()`
                    # raises an exception during tests

                    columns, _ = os.get_terminal_size()
                    if columns == 0:
                        columns = min(longest_line, 80)
                    else:
                        columns = min(longest_line, columns)
                except OSError:
                    columns = min(longest_line, 80)

                print()

                print("Computation Graph")
                print("-" * columns)
                print(graph)
                print("-" * columns)

                print()

            return self.graph

        except Exception:  # pragma: no cover

            # this branch is reserved for unexpected issues and hence it shouldn't be tested
            # if it could be tested, we would have fixed the underlying issue

            # if the user desires so,
            # we need to export all the information we have about the compilation

            if self.configuration.dump_artifacts_on_unexpected_failures:
                self.artifacts.export()

                traceback_path = self.artifacts.output_directory.joinpath("traceback.txt")
                with open(traceback_path, "w", encoding="utf-8") as f:
                    f.write(traceback.format_exc())

            raise

    def compile(
        self,
        inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
        show_graph: bool = False,
        show_mlir: bool = False,
    ) -> Circuit:
        """
        Compile the function using an inputset.

        Args:
            inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
                optional inputset to extend accumulated inputset before bounds measurement

            show_graph (bool, default = False):
                whether to print the computation graph

            show_mlir (bool, default = False):
                whether to print the compiled mlir

        Returns:
            Circuit:
                compiled circuit
        """

        try:
            self._evaluate("Compiling", inputset)
            assert self.graph is not None

            mlir = GraphConverter.convert(self.graph)
            self.artifacts.add_mlir_to_compile(mlir)

            if show_graph or show_mlir:
                graph = self.graph.format() if show_graph else ""

                # separator width fits the widest of graph, MLIR, and terminal
                longest_graph_line = max([len(line) for line in graph.split("\n")])
                longest_mlir_line = max([len(line) for line in mlir.split("\n")])
                longest_line = max(longest_graph_line, longest_mlir_line)

                try:  # pragma: no cover

                    # this branch cannot be covered
                    # because `os.get_terminal_size()`
                    # raises an exception during tests

                    columns, _ = os.get_terminal_size()
                    if columns == 0:
                        columns = min(longest_line, 80)
                    else:
                        columns = min(longest_line, columns)
                except OSError:
                    columns = min(longest_line, 80)

                if show_graph:
                    print()

                    print("Computation Graph")
                    print("-" * columns)
                    print(graph)
                    print("-" * columns)

                    print()

                if show_mlir:
                    print("\n" if not show_graph else "", end="")

                    print("MLIR")
                    print("-" * columns)
                    print(mlir)
                    print("-" * columns)

                    print()

            engine = CompilerEngine()

            if self.configuration.use_insecure_key_cache:
                # the insecure key cache is gated behind unsafe features
                assert self.configuration.enable_unsafe_features
                location = CompilationConfiguration.insecure_key_cache_location()
                engine.compile_fhe(mlir, unsecure_key_set_cache_path=location)
            else:
                # this branch is not covered because all tests use key cache to speed up tests
                engine.compile_fhe(mlir)  # pragma: no cover

            return Circuit(self.graph, engine)

        except Exception:  # pragma: no cover

            # this branch is reserved for unexpected issues and hence it shouldn't be tested
            # if it could be tested, we would have fixed the underlying issue

            # if the user desires so,
            # we need to export all the information we have about the compilation

            if self.configuration.dump_artifacts_on_unexpected_failures:
                self.artifacts.export()

                traceback_path = self.artifacts.output_directory.joinpath("traceback.txt")
                with open(traceback_path, "w", encoding="utf-8") as f:
                    f.write(traceback.format_exc())

            raise
|
||||
42
concrete/numpy/compilation/configuration.py
Normal file
42
concrete/numpy/compilation/configuration.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
Declaration of `CompilationConfiguration` class.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
# module-level location of the insecure key cache; None unless configured elsewhere
_INSECURE_KEY_CACHE_LOCATION: Optional[str] = None


class CompilationConfiguration:
    """
    CompilationConfiguration class, to allow the compilation process to be customized.
    """

    # whether to dump artifacts when an unexpected failure happens during compilation
    dump_artifacts_on_unexpected_failures: bool
    # whether features that are unsafe for production are allowed
    enable_unsafe_features: bool
    # whether to use the (unsafe) key cache; requires `enable_unsafe_features`
    use_insecure_key_cache: bool

    def __init__(
        self,
        dump_artifacts_on_unexpected_failures: bool = True,
        enable_unsafe_features: bool = False,
        use_insecure_key_cache: bool = False,
    ):
        self.dump_artifacts_on_unexpected_failures = dump_artifacts_on_unexpected_failures
        self.enable_unsafe_features = enable_unsafe_features
        self.use_insecure_key_cache = use_insecure_key_cache

        # the key cache is an unsafe feature, so it must be explicitly opted into
        if use_insecure_key_cache and not enable_unsafe_features:
            raise RuntimeError("Insecure key cache cannot be used without enabling unsafe features")

    @staticmethod
    def insecure_key_cache_location() -> Optional[str]:
        """
        Get insecure key cache location.

        Returns:
            Optional[str]:
                insecure key cache location if configured, None otherwise
        """

        return _INSECURE_KEY_CACHE_LOCATION
|
||||
105
concrete/numpy/compilation/decorator.py
Normal file
105
concrete/numpy/compilation/decorator.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
Declaration of `compiler` decorator.
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Iterable, Mapping, Optional, Tuple, Union
|
||||
|
||||
from ..representation import Graph
|
||||
from .artifacts import CompilationArtifacts
|
||||
from .circuit import Circuit
|
||||
from .compiler import Compiler, EncryptionStatus
|
||||
from .configuration import CompilationConfiguration
|
||||
|
||||
|
||||
def compiler(
    parameters: Mapping[str, EncryptionStatus],
    configuration: Optional[CompilationConfiguration] = None,
    artifacts: Optional[CompilationArtifacts] = None,
):
    """
    Provide an easy interface for compilation.

    Args:
        parameters (Dict[str, EncryptionStatus]):
            encryption statuses of the parameters of the function to compile

        configuration(Optional[CompilationConfiguration], default = None):
            configuration to use for compilation

        artifacts (Optional[CompilationArtifacts], default = None):
            artifacts to store information about compilation
    """

    def decoration(function: Callable):
        class Compilable:
            """
            Wrapper that pairs a function with a `Compiler` so it can be
            called, traced, and compiled through a single object.
            """

            function: Callable
            compiler: Compiler

            def __init__(self, function: Callable):
                self.function = function  # type: ignore
                self.compiler = Compiler(
                    self.function, dict(parameters), configuration, artifacts
                )

            def __call__(self, *args, **kwargs) -> Any:
                # feed the sample to the compiler (growing its inputset),
                # then return the result of the wrapped function itself
                self.compiler(*args, **kwargs)
                return self.function(*args, **kwargs)

            def trace(
                self,
                inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
                show_graph: bool = False,
            ) -> Graph:
                """
                Trace the wrapped function into a computation graph.

                Args:
                    inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
                        optional inputset to extend accumulated inputset
                        before bounds measurement

                    show_graph (bool, default = False):
                        whether to print the computation graph

                Returns:
                    Graph:
                        computation graph of the function, prior to MLIR conversion
                """

                return self.compiler.trace(inputset, show_graph)

            def compile(
                self,
                inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
                show_graph: bool = False,
                show_mlir: bool = False,
            ) -> Circuit:
                """
                Compile the wrapped function into a circuit.

                Args:
                    inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
                        optional inputset to extend accumulated inputset
                        before bounds measurement

                    show_graph (bool, default = False):
                        whether to print the computation graph

                    show_mlir (bool, default = False):
                        whether to print the compiled mlir

                Returns:
                    Circuit:
                        compiled circuit
                """

                return self.compiler.compile(inputset, show_graph, show_mlir)

        return Compilable(function)

    return decoration
|
||||
395
concrete/numpy/compilation/utils.py
Normal file
395
concrete/numpy/compilation/utils.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
Declaration of various functions and constants related to compilation.
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from ..dtypes import Float, Integer
|
||||
from ..representation import Graph, Node, Operation
|
||||
from .artifacts import CompilationArtifacts
|
||||
|
||||
|
||||
def fuse(graph: Graph, artifacts: Optional[CompilationArtifacts] = None):
    """
    Fuse appropriate subgraphs in a graph to a single Operation.Generic node.

    Repeatedly searches for float subgraphs that collapse back to an integer
    output, converts each into a single fused node, and rewires the graph
    around it, until no more fusable subgraphs are found.

    Args:
        graph (Graph):
            graph to search and update

        artifacts (Optional[CompilationArtifacts], default = None):
            compilation artifacts to store information about the fusing process
    """

    nx_graph = graph.graph
    processed_terminal_nodes: Set[Node] = set()

    while True:
        float_subgraph_to_fuse = find_float_subgraph_with_unique_terminal_node(
            nx_graph,
            processed_terminal_nodes,
        )
        if float_subgraph_to_fuse is None:
            break

        all_nodes, start_nodes, terminal_node = float_subgraph_to_fuse
        # mark as processed so the same terminal node is never searched again,
        # even if this subgraph turns out to be unconvertible below
        processed_terminal_nodes.add(terminal_node)

        subgraph_conversion_result = convert_subgraph_to_subgraph_node(
            nx_graph,
            all_nodes,
            start_nodes,
            terminal_node,
        )
        if subgraph_conversion_result is None:
            continue

        fused_node, node_before_subgraph = subgraph_conversion_result
        nx_graph.add_node(fused_node)

        if terminal_node in graph.output_nodes.values():
            # redirect every output index that pointed at the old terminal node
            # to the fused node (a node may be referenced by several outputs)
            output_node_to_idx: Dict[Node, List[int]] = {
                out_node: [] for out_node in graph.output_nodes.values()
            }
            for output_idx, output_node in graph.output_nodes.items():
                output_node_to_idx[output_node].append(output_idx)

            for output_idx in output_node_to_idx.get(terminal_node, []):
                graph.output_nodes[output_idx] = fused_node

        # move the terminal node's outgoing edges onto the fused node,
        # preserving each edge's key and data
        terminal_node_succ = list(nx_graph.successors(terminal_node))
        for succ in terminal_node_succ:
            succ_edge_data = deepcopy(nx_graph.get_edge_data(terminal_node, succ))
            for edge_key, edge_data in succ_edge_data.items():
                nx_graph.remove_edge(terminal_node, succ, key=edge_key)
                new_edge_data = deepcopy(edge_data)
                nx_graph.add_edge(fused_node, succ, key=edge_key, **new_edge_data)

        # the fused node takes its single input from the node before the subgraph
        nx_graph.add_edge(node_before_subgraph, fused_node, input_idx=0)

        # the original subgraph nodes are now unreachable and can be dropped
        graph.prune_useless_nodes()
        if artifacts is not None:
            artifacts.add_graph("after-fusing-float-operations", graph)
|
||||
|
||||
|
||||
def find_float_subgraph_with_unique_terminal_node(
    nx_graph: nx.MultiDiGraph,
    processed_terminal_nodes: Set[Node],
) -> Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
    """
    Find a subgraph with float computations that end with an integer output.

    Args:
        nx_graph (nx.MultiDiGraph):
            graph to search

        processed_terminal_nodes (Set[Node]):
            set of terminal nodes which have already been searched for float subgraphs

    Returns:
        Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
            None if there are no such subgraphs,
            tuple containing all nodes in the subgraph, start nodes of the subgraph,
            and terminal node of the subgraph otherwise
    """

    # a terminal node is a not-yet-processed node that consumes at least one float input
    # and produces an integer output, i.e. it is where a float computation ends
    terminal_nodes = (
        node
        for node in nx_graph.nodes()
        if (
            node not in processed_terminal_nodes
            and any(isinstance(input.dtype, Float) for input in node.inputs)
            and isinstance(node.output.dtype, Integer)
        )
    )
    try:
        # only the first matching terminal node is handled per call;
        # the caller loops until this function returns None
        terminal_node = next(terminal_nodes)
    except StopIteration:
        # no float subgraph is left to fuse
        return None

    # networkx does not implement lowest common ancestor search for multidigraph, but we only care
    # about parent relationship here and not the meaning of edges, so we can convert our
    # multidigraph to a digraph and use the lca search algorithm (if needed), we create the
    # equivalent digraph here as it will avoid recreating it in a loop. Constant nodes could cause
    # issues in our search, so we remove them.
    equivalent_subgraph_without_constants = nx.DiGraph(nx_graph)
    constant_nodes = [
        node
        for node in equivalent_subgraph_without_constants.nodes()
        if node.operation == Operation.Constant
    ]
    equivalent_subgraph_without_constants.remove_nodes_from(constant_nodes)

    # `all_nodes` accumulates every node of the subgraph; a Dict[Node, None] is used
    # (instead of a set) so insertion order is preserved
    all_nodes: Dict[Node, None] = {}

    # expand the subgraph upstream from the terminal node until it has exactly one
    # variable (non-constant) start node with an integer output
    start_single_int_output_nodes_search_from = terminal_node
    while True:
        all_nodes, start_nodes = find_closest_integer_output_nodes(
            nx_graph,
            [start_single_int_output_nodes_search_from],
            all_nodes,
        )

        variable_start_nodes = [
            start_node for start_node in start_nodes if start_node.operation != Operation.Constant
        ]
        if len(variable_start_nodes) == 1:
            break

        # find a common ancestor as we need a single variable input node
        # lca == lowest common ancestor
        # lca search only works for node pairs in networkx, so we progressively find the ancestors
        # setting the lca by default to one of the nodes we are searching the lca for
        lca = variable_start_nodes.pop()
        while len(variable_start_nodes) > 0 and lca is not None:
            node_to_find_new_lca = variable_start_nodes.pop()

            ancestors_of_lca = nx.ancestors(
                equivalent_subgraph_without_constants,
                lca,
            )
            ancestors_of_node_to_find_new_lca = nx.ancestors(
                equivalent_subgraph_without_constants,
                node_to_find_new_lca,
            )

            lca_is_ancestor_of_node_to_find_new_lca = lca in ancestors_of_node_to_find_new_lca
            node_to_find_new_lca_is_ancestor_of_lca = node_to_find_new_lca in ancestors_of_lca

            # when one of the two nodes is an ancestor of the other, the ancestor itself
            # is the lca of the pair, so the networkx search can be skipped
            if lca_is_ancestor_of_node_to_find_new_lca or node_to_find_new_lca_is_ancestor_of_lca:
                lca = lca if lca_is_ancestor_of_node_to_find_new_lca else node_to_find_new_lca
                continue

            lca = nx.algorithms.lowest_common_ancestors.lowest_common_ancestor(
                equivalent_subgraph_without_constants, lca, node_to_find_new_lca, default=None
            )

        # if subgraph cannot be fused because there is no way to find a common ancestor, break
        # (the tuple below is still returned; the caller's conversion step is then expected to
        # reject a subgraph with multiple variable inputs — see convert_subgraph_to_subgraph_node)
        if lca is None:
            break

        # add the nodes from the `start_nodes` to `lca`, to `all_nodes`
        all_nodes = add_nodes_from_to(nx_graph, start_nodes, {lca: None}, all_nodes)

        # if `lca` is a valid starting node for fusing break
        if isinstance(lca.output.dtype, Integer):
            # `lca` is the new start node
            start_nodes = {lca: None}
            break

        # otherwise, push a little further
        # (e.g., if there is a node just before, which has an integer output)
        start_single_int_output_nodes_search_from = lca

    return all_nodes, start_nodes, terminal_node
|
||||
|
||||
|
||||
def find_closest_integer_output_nodes(
    nx_graph: nx.MultiDiGraph,
    start_nodes: List[Node],
    all_nodes: Dict[Node, None],
) -> Tuple[Dict[Node, None], Dict[Node, None]]:
    """
    Find the closest upstream integer output nodes to a set of start nodes in a graph.

    Args:
        nx_graph (nx.MultiDiGraph):
            graph to search

        start_nodes (List[Node]):
            nodes from which to start the search

        all_nodes (Dict[Node, None]):
            set of nodes to be extended with visited nodes during the search

    Returns:
        Tuple[Dict[Node, None], Dict[Node, None]]:
            tuple containing extended `all_nodes` and integer output nodes closest to `start_nodes`
    """

    closest: Dict[Node, None] = {}
    seen: Set[Node] = set()

    # breadth-first walk towards predecessors, one frontier at a time
    frontier = dict.fromkeys(start_nodes)
    while frontier:
        upcoming: Dict[Node, None] = {}
        for node in frontier:
            if node in seen:
                continue
            seen.add(node)

            all_nodes[node] = None
            for predecessor in nx_graph.predecessors(node):
                if isinstance(predecessor.output.dtype, Integer):
                    # integer-output predecessors terminate the search on their branch
                    closest[predecessor] = None
                    all_nodes[predecessor] = None
                else:
                    # float-output predecessors are explored in the next frontier
                    upcoming[predecessor] = None
        frontier = upcoming

    return all_nodes, closest
|
||||
|
||||
|
||||
def add_nodes_from_to(
    nx_graph: nx.MultiDiGraph,
    from_nodes: Iterable[Node],
    to_nodes: Dict[Node, None],
    all_nodes: Dict[Node, None],
) -> Dict[Node, None]:
    """
    Add nodes from `from_nodes` to `to_nodes`, to `all_nodes`.

    Args:
        nx_graph (nx.MultiDiGraph):
            graph to traverse

        from_nodes (Iterable[Node]):
            nodes from which extending `all_nodes` start

        to_nodes (Dict[Node, None]):
            nodes to which extending `all_nodes` stop

        all_nodes (Dict[Node, None]):
            nodes to be extended

    Returns:
        Dict[Node, None]:
            extended `all_nodes`
    """

    all_nodes.update(to_nodes)

    seen: Set[Node] = set()

    # breadth-first walk towards predecessors, stopping at `to_nodes`
    frontier = dict.fromkeys(from_nodes)
    while frontier:
        upcoming: Dict[Node, None] = {}
        for node in frontier:
            if node in seen:
                continue
            seen.add(node)

            all_nodes[node] = None
            if node in to_nodes:
                # stop nodes are recorded but never expanded
                continue

            for predecessor in nx_graph.predecessors(node):
                if predecessor not in to_nodes:
                    upcoming[predecessor] = None
        frontier = upcoming

    return all_nodes
|
||||
|
||||
|
||||
def convert_subgraph_to_subgraph_node(
    nx_graph: nx.MultiDiGraph,
    all_nodes: Dict[Node, None],
    start_nodes: Dict[Node, None],
    terminal_node: Node,
) -> Optional[Tuple[Node, Node]]:
    """
    Convert a subgraph to Operation.Generic node.

    Args:
        nx_graph (nx.MultiDiGraph):
            original networkx graph

        all_nodes (Dict[Node, None]):
            all nodes in the subgraph

        start_nodes (Dict[Node, None]):
            start nodes of the subgraph

        terminal_node (Node):
            terminal node of the subgraph

    Returns:
        Optional[Tuple[Node, Node]]:
            None if the subgraph cannot be fused,
            subgraph node and its predecessor otherwise
    """

    # fusing requires exactly one variable (non-constant) input to the subgraph
    variable_input_nodes = [node for node in start_nodes if node.operation != Operation.Constant]
    if len(variable_input_nodes) != 1:
        return None

    variable_input_node = variable_input_nodes[0]
    if not subgraph_can_be_fused(all_nodes, variable_input_node):
        return None

    # copy the whole graph, then prune everything that is not part of the subgraph
    nx_subgraph = nx.MultiDiGraph(nx_graph)
    nodes_to_remove = [node for node in nx_subgraph.nodes() if node not in all_nodes]
    nx_subgraph.remove_nodes_from(nodes_to_remove)

    # fresh input node that will stand in for the variable input inside the subgraph;
    # its value specification is deep-copied so the subgraph never aliases the original
    subgraph_variable_input_node = Node.input("input", deepcopy(variable_input_node.output))
    nx_subgraph.add_node(subgraph_variable_input_node)

    # successors of the variable input that belong to the subgraph
    # (iterated in `all_nodes` insertion order)
    variable_input_node_successors = {
        node: None for node in all_nodes if node in nx_graph.succ[variable_input_node]
    }
    for successor in variable_input_node_successors:
        # deepcopy the edge data first: edges are removed while we iterate over this data
        edges = deepcopy(nx_subgraph.get_edge_data(variable_input_node, successor))
        for edge_key, edge_data in edges.items():
            # redirect each (keyed) edge from the original variable input
            # to the fresh input node, preserving the edge key and attributes
            nx_subgraph.remove_edge(variable_input_node, successor, key=edge_key)
            new_edge_data = deepcopy(edge_data)
            nx_subgraph.add_edge(
                subgraph_variable_input_node,
                successor,
                key=edge_key,
                **new_edge_data,
            )

    subgraph = Graph(nx_subgraph, {0: subgraph_variable_input_node}, {0: terminal_node})
    subgraph_node = Node.generic(
        "subgraph",
        subgraph_variable_input_node.inputs,
        terminal_node.output,
        # evaluation runs the whole subgraph on `x` and extracts the terminal node's value;
        # `subgraph` and `terminal_node` are bound through `kwargs` below
        lambda x, subgraph, terminal_node: subgraph.evaluate(x)[terminal_node],
        kwargs={
            "subgraph": subgraph,
            "terminal_node": terminal_node,
        },
    )

    # the caller wires `subgraph_node` into the graph after `variable_input_node`
    return subgraph_node, variable_input_node
|
||||
|
||||
|
||||
def subgraph_can_be_fused(
    all_nodes: Dict[Node, None],
    variable_input_node: Node,
) -> bool:
    """
    Determine if a subgraph can be fused.

    e.g.,

    shuffling or reshaping a tensor make fusing impossible as there should be a one-to-one mapping
    between each cell of the input and each cell of the output for table lookups

    Args:
        all_nodes (Dict[Node, None]):
            all nodes in the subgraph

        variable_input_node (Node):
            variable input node to the subgraph

    Returns:
        bool:
            True if subgraph can be fused,
            False otherwise
    """

    # no constant in the subgraph may be bigger than the variable input
    input_size = variable_input_node.output.size
    for node in all_nodes:
        if node.operation == Operation.Constant and node.output.size > input_size:
            return False

    # every non-constant node must keep the exact shape of the variable input
    # (a one-to-one cell mapping is required for table lookups)
    input_shape = variable_input_node.output.shape
    for node in all_nodes:
        if node.operation != Operation.Constant and node.output.shape != input_shape:
            return False

    return True
|
||||
Reference in New Issue
Block a user