"""
|
|
Declaration of `Compiler` class.
|
|
"""
|
|
|
|
import inspect
|
|
import os
|
|
import traceback
|
|
from copy import deepcopy
|
|
from enum import Enum, unique
|
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
|
|
|
import numpy as np
|
|
|
|
from ..mlir import GraphConverter
|
|
from ..representation import Graph
|
|
from ..tracing import Tracer
|
|
from ..values import Value
|
|
from .artifacts import DebugArtifacts
|
|
from .circuit import Circuit
|
|
from .configuration import Configuration
|
|
from .utils import fuse
|
|
|
|
|
|
@unique
|
|
class EncryptionStatus(str, Enum):
|
|
"""
|
|
EncryptionStatus enum, to represent encryption status of parameters.
|
|
"""
|
|
|
|
CLEAR = "clear"
|
|
ENCRYPTED = "encrypted"
|
|
|
|
|
|
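# A minimal usage sketch (illustrative only; `add`, the encryption statuses, and the
# inputset below are example values, not part of this module):
#
#     def add(x, y):
#         return x + y
#
#     compiler = Compiler(add, {"x": "encrypted", "y": "clear"})
#     circuit = compiler.compile(inputset=[(3, 1), (7, 2)])
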

class Compiler:
    """
    Compiler class, to glue the compilation pipeline.
    """

    function: Callable
    parameter_encryption_statuses: Dict[str, EncryptionStatus]

    configuration: Configuration
    artifacts: Optional[DebugArtifacts]

    inputset: List[Any]
    graph: Optional[Graph]

    def __init__(
        self,
        function: Callable,
        parameter_encryption_statuses: Dict[str, Union[str, EncryptionStatus]],
    ):
        signature = inspect.signature(function)

        # make sure an encryption status is provided for every parameter of the function
        missing_args = list(signature.parameters)
        for arg in parameter_encryption_statuses.keys():
            if arg in signature.parameters:
                missing_args.remove(arg)

        if len(missing_args) != 0:
            parameter_str = repr(missing_args[0])
            for arg in missing_args[1:-1]:
                parameter_str += f", {repr(arg)}"
            if len(missing_args) != 1:
                parameter_str += f" and {repr(missing_args[-1])}"

            raise ValueError(
                f"Encryption status{'es' if len(missing_args) > 1 else ''} "
                f"of parameter{'s' if len(missing_args) > 1 else ''} "
                f"{parameter_str} of function '{function.__name__}' "
                f"{'are' if len(missing_args) > 1 else 'is'} not provided"
            )

        # make sure every provided status corresponds to an actual parameter of the function
        additional_args = list(parameter_encryption_statuses)
        for arg in signature.parameters.keys():
            if arg in parameter_encryption_statuses:
                additional_args.remove(arg)

        if len(additional_args) != 0:
            parameter_str = repr(additional_args[0])
            for arg in additional_args[1:-1]:
                parameter_str += f", {repr(arg)}"
            if len(additional_args) != 1:
                parameter_str += f" and {repr(additional_args[-1])}"

            raise ValueError(
                f"Encryption status{'es' if len(additional_args) > 1 else ''} "
                f"of {parameter_str} {'are' if len(additional_args) > 1 else 'is'} provided but "
                f"{'they are' if len(additional_args) > 1 else 'it is'} not a parameter "
                f"of function '{function.__name__}'"
            )

        self.function = function  # type: ignore
        self.parameter_encryption_statuses = {
            param: EncryptionStatus(status.lower())
            for param, status in parameter_encryption_statuses.items()
        }

        self.configuration = Configuration()
        self.artifacts = None

        self.inputset = []
        self.graph = None

    def __call__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Union[
        np.bool_,
        np.integer,
        np.floating,
        np.ndarray,
        Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray], ...],
    ]:
        if len(kwargs) != 0:
            raise RuntimeError(
                f"Calling function '{self.function.__name__}' with kwargs is not supported"
            )

        sample = args[0] if len(args) == 1 else args

        if self.graph is None:
            self._trace(sample)
            assert self.graph is not None

        self.inputset.append(sample)
        return self.graph(*args)

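    # Calling the compiler directly also accumulates samples, so `compile` can then be
    # called without an explicit inputset (illustrative; `compiler` and the sample
    # values below are example names, not part of this module):
    #
    #     compiler(3, 1)
    #     compiler(7, 2)
    #     circuit = compiler.compile()
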
    def _trace(self, sample: Union[Any, Tuple[Any, ...]]):
        """
        Trace the function with a sample input and fuse the resulting graph.

        Args:
            sample (Union[Any, Tuple[Any, ...]]):
                sample to use for tracing
        """

        if self.artifacts is not None:
            self.artifacts.add_source_code(self.function)
            for param, encryption_status in self.parameter_encryption_statuses.items():
                self.artifacts.add_parameter_encryption_status(param, encryption_status)

        # describe each parameter from its sample value and requested encryption status
        parameters = {
            param: Value.of(arg, is_encrypted=(status == EncryptionStatus.ENCRYPTED))
            for arg, (param, status) in zip(
                sample if len(self.parameter_encryption_statuses) > 1 else (sample,),
                self.parameter_encryption_statuses.items(),
            )
        }

        self.graph = Tracer.trace(self.function, parameters)
        if self.artifacts is not None:
            self.artifacts.add_graph("initial", self.graph)

        fuse(self.graph, self.artifacts)

    def _evaluate(
        self,
        action: str,
        inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]],
    ):
        """
        Trace, fuse, measure bounds, and update values in the resulting graph in one go.

        Args:
            action (str):
                action being performed (e.g., "Tracing", "Compiling")

            inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
                optional inputset to extend accumulated inputset before bounds measurement
        """

        if inputset is not None:
            previous_inputset_length = len(self.inputset)
            for index, sample in enumerate(iter(inputset)):
                self.inputset.append(sample)

                if not isinstance(sample, tuple):
                    sample = (sample,)

                # each sample must provide exactly one value per parameter;
                # otherwise, roll back the samples appended from this inputset and report
                if len(sample) != len(self.parameter_encryption_statuses):
                    self.inputset = self.inputset[:previous_inputset_length]

                    expected = (
                        "a single value"
                        if len(self.parameter_encryption_statuses) == 1
                        else f"a tuple of {len(self.parameter_encryption_statuses)} values"
                    )
                    actual = (
                        "a single value" if len(sample) == 1 else f"a tuple of {len(sample)} values"
                    )

                    raise ValueError(
                        f"Input #{index} of your inputset is not well formed "
                        f"(expected {expected} got {actual})"
                    )

        # trace lazily, using the first sample of the accumulated inputset
        if self.graph is None:
            try:
                first_sample = next(iter(self.inputset))
            except StopIteration as error:
                raise RuntimeError(
                    f"{action} function '{self.function.__name__}' "
                    f"without an inputset is not supported"
                ) from error

            self._trace(first_sample)
            assert self.graph is not None

        bounds = self.graph.measure_bounds(self.inputset)
        if self.artifacts is not None:
            self.artifacts.add_final_graph_bounds(bounds)

        self.graph.update_with_bounds(bounds)
        if self.artifacts is not None:
            self.artifacts.add_graph("final", self.graph)

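    # Expected inputset shape (illustrative values): each sample provides one value per
    # parameter, so single-parameter functions take plain values and multi-parameter
    # functions take tuples:
    #
    #     compiler.compile(inputset=range(8))          # single parameter
    #     compiler.compile(inputset=[(0, 1), (2, 3)])  # two parameters
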
    def trace(
        self,
        inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
        configuration: Optional[Configuration] = None,
        artifacts: Optional[DebugArtifacts] = None,
        **kwargs,
    ) -> Graph:
        """
        Trace the function using an inputset.

        Args:
            inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
                optional inputset to extend accumulated inputset before bounds measurement

            configuration (Optional[Configuration], default = None):
                configuration to use

            artifacts (Optional[DebugArtifacts], default = None):
                artifacts to store information about the process

            kwargs (Dict[str, Any]):
                configuration options to overwrite

        Returns:
            Graph:
                computation graph representing the function prior to MLIR conversion
        """

        old_configuration = deepcopy(self.configuration)
        old_artifacts = deepcopy(self.artifacts)

        if configuration is not None:
            self.configuration = configuration

        if len(kwargs) != 0:
            self.configuration = self.configuration.fork(**kwargs)

        self.artifacts = (
            artifacts
            if artifacts is not None
            else DebugArtifacts()
            if self.configuration.dump_artifacts_on_unexpected_failures
            else None
        )

        try:

            self._evaluate("Tracing", inputset)
            assert self.graph is not None

            if self.configuration.verbose or self.configuration.show_graph:
                graph = self.graph.format()
                longest_line = max([len(line) for line in graph.split("\n")])

                try:  # pragma: no cover

                    # this branch cannot be covered
                    # because `os.get_terminal_size()`
                    # raises an exception during tests

                    columns, _ = os.get_terminal_size()
                    if columns == 0:
                        columns = min(longest_line, 80)
                    else:
                        columns = min(longest_line, columns)
                except OSError:  # pragma: no cover
                    columns = min(longest_line, 80)

                print()

                print("Computation Graph")
                print("-" * columns)
                print(graph)
                print("-" * columns)

                print()

            return self.graph

        except Exception:  # pragma: no cover

            # this branch is reserved for unexpected issues and hence it shouldn't be tested
            # if it could be tested, we would have fixed the underlying issue

            # if the user desires so,
            # we need to export all the information we have about the compilation

            if self.configuration.dump_artifacts_on_unexpected_failures:
                assert self.artifacts is not None
                self.artifacts.export()

                traceback_path = self.artifacts.output_directory.joinpath("traceback.txt")
                with open(traceback_path, "w", encoding="utf-8") as f:
                    f.write(traceback.format_exc())

            raise

        finally:

            self.configuration = old_configuration
            self.artifacts = old_artifacts

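    # `trace` can be used to inspect the computation graph without compiling it
    # (illustrative; `compiler` is an example name, not part of this module):
    #
    #     graph = compiler.trace(inputset=range(8))
    #     print(graph.format())
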
    # pylint: disable=too-many-branches,too-many-statements

    def compile(
        self,
        inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None,
        configuration: Optional[Configuration] = None,
        artifacts: Optional[DebugArtifacts] = None,
        **kwargs,
    ) -> Circuit:
        """
        Compile the function using an inputset.

        Args:
            inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
                optional inputset to extend accumulated inputset before bounds measurement

            configuration (Optional[Configuration], default = None):
                configuration to use

            artifacts (Optional[DebugArtifacts], default = None):
                artifacts to store information about the process

            kwargs (Dict[str, Any]):
                configuration options to overwrite

        Returns:
            Circuit:
                compiled circuit
        """

        old_configuration = deepcopy(self.configuration)
        old_artifacts = deepcopy(self.artifacts)

        if configuration is not None:
            self.configuration = configuration

        if len(kwargs) != 0:
            self.configuration = self.configuration.fork(**kwargs)

        self.artifacts = (
            artifacts
            if artifacts is not None
            else DebugArtifacts()
            if self.configuration.dump_artifacts_on_unexpected_failures
            else None
        )

        try:

            self._evaluate("Compiling", inputset)
            assert self.graph is not None

            mlir = GraphConverter.convert(self.graph, virtual=self.configuration.virtual)
            if self.artifacts is not None:
                self.artifacts.add_mlir_to_compile(mlir)

            show_graph = self.configuration.verbose or self.configuration.show_graph
            show_mlir = self.configuration.verbose or self.configuration.show_mlir
            show_optimizer = self.configuration.verbose or self.configuration.show_optimizer

            columns = 0
            if show_graph or show_mlir or show_optimizer:

                graph = (
                    self.graph.format()
                    if self.configuration.verbose or self.configuration.show_graph
                    else ""
                )

                longest_graph_line = max([len(line) for line in graph.split("\n")])
                longest_mlir_line = max([len(line) for line in mlir.split("\n")])
                longest_line = max(longest_graph_line, longest_mlir_line)

                try:  # pragma: no cover

                    # this branch cannot be covered
                    # because `os.get_terminal_size()`
                    # raises an exception during tests

                    columns, _ = os.get_terminal_size()
                    if columns == 0:
                        columns = min(longest_line, 80)
                    else:
                        columns = min(longest_line, columns)
                except OSError:  # pragma: no cover
                    columns = min(longest_line, 80)

                if show_graph:
                    print()

                    print("Computation Graph")
                    print("-" * columns)
                    print(graph)
                    print("-" * columns)

                    print()

                if show_mlir:
                    print("\n" if not show_graph else "", end="")

                    print("MLIR")
                    print("-" * columns)
                    print(mlir)
                    print("-" * columns)

                    print()

                if show_optimizer:
                    print("\n" if not (show_graph or show_mlir) else "", end="")

                    print("Optimizer")
                    print("-" * columns)

            circuit = Circuit(self.graph, mlir, self.configuration)
            if not self.configuration.virtual:
                assert circuit.client.specs.client_parameters is not None
                if self.artifacts is not None:
                    self.artifacts.add_client_parameters(
                        circuit.client.specs.client_parameters.serialize()
                    )

            if show_optimizer:
                if self.configuration.virtual:
                    print("Virtual circuits don't have optimizer output.")

                print("-" * columns)
                print()

            return circuit

        except Exception:  # pragma: no cover

            # this branch is reserved for unexpected issues and hence it shouldn't be tested
            # if it could be tested, we would have fixed the underlying issue

            # if the user desires so,
            # we need to export all the information we have about the compilation

            if self.configuration.dump_artifacts_on_unexpected_failures:
                assert self.artifacts is not None
                self.artifacts.export()

                traceback_path = self.artifacts.output_directory.joinpath("traceback.txt")
                with open(traceback_path, "w", encoding="utf-8") as f:
                    f.write(traceback.format_exc())

            raise

        finally:

            self.configuration = old_configuration
            self.artifacts = old_artifacts

    # pylint: enable=too-many-branches,too-many-statements