mirror of
https://github.com/zama-ai/concrete.git
synced 2026-01-10 05:18:00 -05:00
refactor(frontends): unify circuits and modules
This commit is contained in:
committed by
Alexandre Péré
parent
52636e47c6
commit
d9b34f13d0
@@ -1310,6 +1310,15 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
clientParameters, inputId, circuitName);
|
||||
return encryption.getVariance();
|
||||
})
|
||||
.def("function_list",
|
||||
[](::concretelang::clientlib::ClientParameters &clientParameters) {
|
||||
std::vector<std::string> result;
|
||||
for (auto circuit :
|
||||
clientParameters.programInfo.asReader().getCircuits()) {
|
||||
result.push_back(circuit.getName());
|
||||
}
|
||||
return result;
|
||||
})
|
||||
.def("output_signs",
|
||||
[](::concretelang::clientlib::ClientParameters &clientParameters) {
|
||||
std::vector<bool> result;
|
||||
|
||||
@@ -37,13 +37,12 @@ class ClientParameters(WrapperCpp):
|
||||
)
|
||||
super().__init__(client_parameters)
|
||||
|
||||
def input_keyid_at(self, input_idx: int, circuit_name: str = "main") -> int:
|
||||
def input_keyid_at(self, input_idx: int, circuit_name: str) -> int:
|
||||
"""Get the keyid of a selected encrypted input in a given circuit.
|
||||
|
||||
Args:
|
||||
input_idx (int): index of the input in the circuit.
|
||||
circuit_name (str, optional): name of the circuit containing the desired input.
|
||||
Defaults to "main".
|
||||
circuit_name (str): name of the circuit containing the desired input.
|
||||
|
||||
Raises:
|
||||
TypeError: if arguments aren't of expected types
|
||||
@@ -59,13 +58,12 @@ class ClientParameters(WrapperCpp):
|
||||
)
|
||||
return self.cpp().input_keyid_at(input_idx, circuit_name)
|
||||
|
||||
def input_variance_at(self, input_idx: int, circuit_name: str = "main") -> float:
|
||||
def input_variance_at(self, input_idx: int, circuit_name: str) -> float:
|
||||
"""Get the variance of a selected encrypted input in a given circuit.
|
||||
|
||||
Args:
|
||||
input_idx (int): index of the input in the circuit.
|
||||
circuit_name (str, optional): name of the circuit containing the desired input.
|
||||
Defaults to "main".
|
||||
circuit_name (str): name of the circuit containing the desired input.
|
||||
|
||||
Raises:
|
||||
TypeError: if arguments aren't of expected types
|
||||
@@ -97,6 +95,14 @@ class ClientParameters(WrapperCpp):
|
||||
"""
|
||||
return self.cpp().output_signs()
|
||||
|
||||
def function_list(self) -> List[str]:
|
||||
"""Return the list of function names.
|
||||
|
||||
Returns:
|
||||
List[str]: list of the names of the functions.
|
||||
"""
|
||||
return self.cpp().function_list()
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
"""Serialize the ClientParameters.
|
||||
|
||||
|
||||
@@ -116,7 +116,7 @@ class ClientSupport(WrapperCpp):
|
||||
client_parameters: ClientParameters,
|
||||
keyset: KeySet,
|
||||
args: List[Union[int, np.ndarray]],
|
||||
circuit_name: str = "main",
|
||||
circuit_name: str,
|
||||
) -> PublicArguments:
|
||||
"""Prepare arguments for encrypted computation.
|
||||
|
||||
@@ -172,7 +172,7 @@ class ClientSupport(WrapperCpp):
|
||||
client_parameters: ClientParameters,
|
||||
keyset: KeySet,
|
||||
public_result: PublicResult,
|
||||
circuit_name: str = "main",
|
||||
circuit_name: str,
|
||||
) -> Union[int, np.ndarray]:
|
||||
"""Decrypt a public result using the keyset.
|
||||
|
||||
|
||||
@@ -236,7 +236,7 @@ class LibrarySupport(WrapperCpp):
|
||||
self,
|
||||
library_compilation_result: LibraryCompilationResult,
|
||||
simulation: bool,
|
||||
circuit_name: str = "main",
|
||||
circuit_name: str,
|
||||
) -> LibraryLambda:
|
||||
"""Load the server lambda for a given circuit from the library compilation result.
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ class SimulatedValueDecrypter(WrapperCpp):
|
||||
|
||||
@staticmethod
|
||||
# pylint: disable=arguments-differ
|
||||
def new(client_parameters: ClientParameters, circuit_name: str = "main"):
|
||||
def new(client_parameters: ClientParameters, circuit_name: str):
|
||||
"""
|
||||
Create a value decrypter.
|
||||
"""
|
||||
|
||||
@@ -41,7 +41,7 @@ class SimulatedValueExporter(WrapperCpp):
|
||||
@staticmethod
|
||||
# pylint: disable=arguments-differ
|
||||
def new(
|
||||
client_parameters: ClientParameters, circuitName: str = "main"
|
||||
client_parameters: ClientParameters, circuitName: str
|
||||
) -> "SimulatedValueExporter":
|
||||
"""
|
||||
Create a value exporter.
|
||||
|
||||
@@ -42,7 +42,7 @@ class ValueExporter(WrapperCpp):
|
||||
@staticmethod
|
||||
# pylint: disable=arguments-differ
|
||||
def new(
|
||||
keyset: KeySet, client_parameters: ClientParameters, circuit_name: str = "main"
|
||||
keyset: KeySet, client_parameters: ClientParameters, circuit_name: str
|
||||
) -> "ValueExporter":
|
||||
"""
|
||||
Create a value exporter.
|
||||
|
||||
@@ -24,7 +24,7 @@ def assert_result(result, expected_result):
|
||||
assert np.all(result == expected_result)
|
||||
|
||||
|
||||
def run(engine, args, compilation_result, keyset_cache, circuit_name="main"):
|
||||
def run(engine, args, compilation_result, keyset_cache, circuit_name):
|
||||
"""Execute engine on the given arguments.
|
||||
|
||||
Perform required loading, encryption, execution, and decryption."""
|
||||
@@ -219,7 +219,7 @@ def test_lib_compile_reload_and_run(mlir_input, args, expected_result, keyset_ca
|
||||
# Here don't save compilation result, reload
|
||||
engine.compile(mlir_input)
|
||||
compilation_result = engine.reload()
|
||||
result = run(engine, args, compilation_result, keyset_cache)
|
||||
result = run(engine, args, compilation_result, keyset_cache, "main")
|
||||
# Check result
|
||||
assert_result(result, expected_result)
|
||||
shutil.rmtree(artifact_dir)
|
||||
@@ -398,7 +398,7 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args, keyset_cache):
|
||||
|
||||
def test_crt_decomposition_feedback():
|
||||
mlir = """
|
||||
|
||||
|
||||
func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
|
||||
%tlu = arith.constant dense<60000> : tensor<65536xi64>
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<16>, tensor<65536xi64>) -> (!FHE.eint<16>)
|
||||
|
||||
@@ -5,7 +5,7 @@ Glue the compilation process together.
|
||||
from .artifacts import DebugArtifacts, FunctionDebugArtifacts, ModuleDebugArtifacts
|
||||
from .circuit import Circuit
|
||||
from .client import Client
|
||||
from .compiler import Compiler, EncryptionStatus
|
||||
from .compiler import Compiler
|
||||
from .composition import CompositionClause, CompositionPolicy, CompositionRule
|
||||
from .configuration import (
|
||||
DEFAULT_GLOBAL_P_ERROR,
|
||||
@@ -22,19 +22,10 @@ from .configuration import (
|
||||
)
|
||||
from .keys import Keys
|
||||
from .module import FheFunction, FheModule
|
||||
from .module_compiler import (
|
||||
AllComposable,
|
||||
AllInputs,
|
||||
AllOutputs,
|
||||
FunctionDef,
|
||||
Input,
|
||||
ModuleCompiler,
|
||||
NotComposable,
|
||||
Output,
|
||||
Wire,
|
||||
Wired,
|
||||
)
|
||||
from .module_compiler import FunctionDef, ModuleCompiler
|
||||
from .server import Server
|
||||
from .specs import ClientSpecs
|
||||
from .utils import inputset
|
||||
from .status import EncryptionStatus
|
||||
from .utils import get_terminal_size, inputset
|
||||
from .value import Value
|
||||
from .wiring import AllComposable, AllInputs, AllOutputs, Input, NotComposable, Output, Wire, Wired
|
||||
|
||||
@@ -10,6 +10,8 @@ from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
|
||||
|
||||
from ..representation import Graph
|
||||
from .configuration import Configuration
|
||||
from .utils import get_terminal_size
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from .module import ExecutionRt
|
||||
@@ -18,6 +20,213 @@ if TYPE_CHECKING: # pragma: no cover
|
||||
DEFAULT_OUTPUT_DIRECTORY: Path = Path(".artifacts")
|
||||
|
||||
|
||||
class DebugManager:
|
||||
"""
|
||||
A debug manager, allowing streamlined debugging.
|
||||
"""
|
||||
|
||||
configuration: Configuration
|
||||
begin_call: Callable
|
||||
|
||||
def __init__(self, config: Configuration):
|
||||
self.configuration = config
|
||||
is_first = [True]
|
||||
|
||||
def begin_call():
|
||||
if is_first[0]:
|
||||
print()
|
||||
is_first[0] = False
|
||||
|
||||
self.begin_call = begin_call
|
||||
|
||||
def debug_table(self, title: str, activate: bool = True):
|
||||
"""
|
||||
Return a context manager that prints a table around what is printed inside the scope.
|
||||
"""
|
||||
|
||||
# pylint: disable=missing-class-docstring
|
||||
class DebugTableCm:
|
||||
def __init__(self, title):
|
||||
self.title = title
|
||||
self.columns = get_terminal_size()
|
||||
|
||||
def __enter__(self):
|
||||
print(f"{self.title}")
|
||||
print("-" * self.columns)
|
||||
|
||||
def __exit__(self, _exc_type, _exc_value, _exc_tb):
|
||||
print("-" * self.columns)
|
||||
print()
|
||||
|
||||
class EmptyCm:
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, _exc_type, _exc_value, _exc_tb):
|
||||
pass
|
||||
|
||||
if activate:
|
||||
self.begin_call()
|
||||
return DebugTableCm(title)
|
||||
return EmptyCm()
|
||||
|
||||
def show_graph(self) -> bool:
|
||||
"""
|
||||
Tell if the configuration involves showing graph.
|
||||
"""
|
||||
|
||||
return (
|
||||
self.configuration.show_graph
|
||||
if self.configuration.show_graph is not None
|
||||
else self.configuration.verbose
|
||||
)
|
||||
|
||||
def show_bit_width_constraints(self) -> bool:
|
||||
"""
|
||||
Tell if the configuration involves showing bitwidth constraints.
|
||||
"""
|
||||
|
||||
return (
|
||||
self.configuration.show_bit_width_constraints
|
||||
if self.configuration.show_bit_width_constraints is not None
|
||||
else self.configuration.verbose
|
||||
)
|
||||
|
||||
def show_bit_width_assignments(self) -> bool:
|
||||
"""
|
||||
Tell if the configuration involves showing bitwidth assignments.
|
||||
"""
|
||||
|
||||
return (
|
||||
self.configuration.show_bit_width_assignments
|
||||
if self.configuration.show_bit_width_assignments is not None
|
||||
else self.configuration.verbose
|
||||
)
|
||||
|
||||
def show_assigned_graph(self) -> bool:
|
||||
"""
|
||||
Tell if the configuration involves showing assigned graph.
|
||||
"""
|
||||
|
||||
return (
|
||||
self.configuration.show_assigned_graph
|
||||
if self.configuration.show_assigned_graph is not None
|
||||
else self.configuration.verbose
|
||||
)
|
||||
|
||||
def show_mlir(self) -> bool:
|
||||
"""
|
||||
Tell if the configuration involves showing mlir.
|
||||
"""
|
||||
|
||||
return (
|
||||
self.configuration.show_mlir
|
||||
if self.configuration.show_mlir is not None
|
||||
else self.configuration.verbose
|
||||
)
|
||||
|
||||
def show_optimizer(self) -> bool:
|
||||
"""
|
||||
Tell if the configuration involves showing optimizer.
|
||||
"""
|
||||
|
||||
return (
|
||||
self.configuration.show_optimizer
|
||||
if self.configuration.show_optimizer is not None
|
||||
else self.configuration.verbose
|
||||
)
|
||||
|
||||
def show_statistics(self) -> bool:
|
||||
"""
|
||||
Tell if the configuration involves showing statistics.
|
||||
"""
|
||||
|
||||
return (
|
||||
self.configuration.show_statistics
|
||||
if self.configuration.show_statistics is not None
|
||||
else self.configuration.verbose
|
||||
)
|
||||
|
||||
def debug_computation_graph(self, name, function_graph):
|
||||
"""
|
||||
Print computation graph if configuration tells so.
|
||||
"""
|
||||
|
||||
if (
|
||||
self.show_graph()
|
||||
or self.show_bit_width_constraints()
|
||||
or self.show_bit_width_assignments()
|
||||
or self.show_assigned_graph()
|
||||
or self.show_mlir()
|
||||
or self.show_optimizer()
|
||||
or self.show_statistics()
|
||||
):
|
||||
if self.show_graph():
|
||||
with self.debug_table(f"Computation Graph for {name}"):
|
||||
print(function_graph.format())
|
||||
|
||||
def debug_bit_width_constaints(self, name, function_graph):
|
||||
"""
|
||||
Print bitwidth constraints if configuration tells so.
|
||||
"""
|
||||
|
||||
if self.show_bit_width_constraints():
|
||||
with self.debug_table(f"Bit-Width Constraints for {name}"):
|
||||
print(function_graph.format_bit_width_constraints())
|
||||
|
||||
def debug_bit_width_assignments(self, name, function_graph):
|
||||
"""
|
||||
Print bitwidth assignments if configuration tells so.
|
||||
"""
|
||||
|
||||
if self.show_bit_width_assignments():
|
||||
with self.debug_table(f"Bit-Width Assignments for {name}"):
|
||||
print(function_graph.format_bit_width_assignments())
|
||||
|
||||
def debug_assigned_graph(self, name, function_graph):
|
||||
"""
|
||||
Print assigned graphs if configuration tells so.
|
||||
"""
|
||||
|
||||
if self.show_assigned_graph():
|
||||
with self.debug_table(f"Bit-Width Assigned Computation Graph for {name}"):
|
||||
print(function_graph.format(show_assigned_bit_widths=True))
|
||||
|
||||
def debug_mlir(self, mlir_str):
|
||||
"""
|
||||
Print mlir if configuration tells so.
|
||||
"""
|
||||
|
||||
if self.show_mlir():
|
||||
with self.debug_table("MLIR"):
|
||||
print(mlir_str)
|
||||
|
||||
def debug_statistics(self, module):
|
||||
"""
|
||||
Print statistics if configuration tells so.
|
||||
"""
|
||||
|
||||
if self.show_statistics():
|
||||
|
||||
def pretty(d, indent=0): # pragma: no cover
|
||||
if indent > 0:
|
||||
print("{")
|
||||
|
||||
for key, value in d.items():
|
||||
if isinstance(value, dict) and len(value) == 0:
|
||||
continue
|
||||
print(" " * indent + str(key) + ": ", end="")
|
||||
if isinstance(value, dict):
|
||||
pretty(value, indent + 1)
|
||||
else:
|
||||
print(value)
|
||||
if indent > 0:
|
||||
print(" " * (indent - 1) + "}")
|
||||
|
||||
with self.debug_table("Statistics"):
|
||||
pretty(module.statistics)
|
||||
|
||||
|
||||
class FunctionDebugArtifacts:
|
||||
"""
|
||||
An object containing debug artifacts for a certain function in an fhe module.
|
||||
@@ -236,88 +445,18 @@ class DebugArtifacts:
|
||||
"""
|
||||
|
||||
module_artifacts: ModuleDebugArtifacts
|
||||
_client_parameters: Optional[bytes]
|
||||
|
||||
def __init__(self, output_directory: Union[str, Path] = DEFAULT_OUTPUT_DIRECTORY):
|
||||
self.module_artifacts = ModuleDebugArtifacts(["main"], output_directory)
|
||||
self._client_parameters = None
|
||||
|
||||
def add_source_code(self, function: Union[str, Callable]):
|
||||
"""
|
||||
Add source code of the function being compiled.
|
||||
|
||||
Args:
|
||||
function (Union[str, Callable]):
|
||||
either the source code of the function or the function itself
|
||||
"""
|
||||
self.module_artifacts.functions["main"].add_source_code(function)
|
||||
|
||||
def add_parameter_encryption_status(self, name: str, encryption_status: str):
|
||||
"""
|
||||
Add parameter encryption status of a parameter of the function being compiled.
|
||||
|
||||
Args:
|
||||
name (str):
|
||||
name of the parameter
|
||||
|
||||
encryption_status (str):
|
||||
encryption status of the parameter
|
||||
"""
|
||||
|
||||
self.module_artifacts.functions["main"].add_parameter_encryption_status(
|
||||
name, encryption_status
|
||||
)
|
||||
|
||||
def add_graph(self, name: str, graph: Graph):
|
||||
"""
|
||||
Add a representation of the function being compiled.
|
||||
|
||||
Args:
|
||||
name (str):
|
||||
name of the graph (e.g., initial, optimized, final)
|
||||
|
||||
graph (Graph):
|
||||
a representation of the function being compiled
|
||||
"""
|
||||
|
||||
self.module_artifacts.functions["main"].add_graph(name, graph)
|
||||
|
||||
def add_mlir_to_compile(self, mlir: str):
|
||||
"""
|
||||
Add textual representation of the resulting MLIR.
|
||||
|
||||
Args:
|
||||
mlir (str):
|
||||
textual representation of the resulting MLIR
|
||||
"""
|
||||
|
||||
self.module_artifacts.add_mlir_to_compile(mlir)
|
||||
|
||||
def add_client_parameters(self, client_parameters: bytes):
|
||||
"""
|
||||
Add client parameters used.
|
||||
|
||||
Args:
|
||||
client_parameters (bytes): client parameters
|
||||
"""
|
||||
|
||||
self._client_parameters = client_parameters
|
||||
self.module_artifacts = ModuleDebugArtifacts([], output_directory)
|
||||
|
||||
def export(self):
|
||||
"""
|
||||
Export the collected information to `self.output_directory`.
|
||||
"""
|
||||
|
||||
# This is a quick fix before we refactor compiler and module_compiler
|
||||
# to use the same abstraction.
|
||||
class _ModuleDebugArtifacts(ModuleDebugArtifacts):
|
||||
client_parameters = self._client_parameters
|
||||
|
||||
self.module_artifacts.__class__ = _ModuleDebugArtifacts
|
||||
self.module_artifacts.export()
|
||||
|
||||
@property
|
||||
def output_directory(self) -> Path:
|
||||
def output_directory(self) -> Path: # pragma: no cover
|
||||
"""
|
||||
Return the directory to export artifacts to.
|
||||
"""
|
||||
|
||||
@@ -5,25 +5,18 @@ Declaration of `Circuit` class.
|
||||
# pylint: disable=import-error,no-member,no-name-in-module
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from concrete.compiler import (
|
||||
CompilationContext,
|
||||
Parameter,
|
||||
SimulatedValueDecrypter,
|
||||
SimulatedValueExporter,
|
||||
)
|
||||
from concrete.compiler import CompilationContext, Parameter
|
||||
from mlir.ir import Module as MlirModule
|
||||
|
||||
from ..internal.utils import assert_that
|
||||
from ..representation import Graph
|
||||
from .client import Client
|
||||
from .composition import CompositionRule
|
||||
from .configuration import Configuration
|
||||
from .keys import Keys
|
||||
from .module import FheFunction, FheModule
|
||||
from .server import Server
|
||||
from .utils import validate_input_args
|
||||
from .value import Value
|
||||
|
||||
# pylint: enable=import-error,no-member,no-name-in-module
|
||||
@@ -34,40 +27,20 @@ class Circuit:
|
||||
Circuit class, to combine computation graph, mlir, client and server into a single object.
|
||||
"""
|
||||
|
||||
configuration: Configuration
|
||||
_module: FheModule
|
||||
_name: str
|
||||
|
||||
graph: Graph
|
||||
mlir_module: MlirModule
|
||||
compilation_context: CompilationContext
|
||||
composition_rules: Optional[List[CompositionRule]]
|
||||
def __init__(self, module: FheModule):
|
||||
assert module.function_count == 1
|
||||
self._name = next(iter(module.functions().keys()))
|
||||
self._module = module
|
||||
|
||||
client: Client
|
||||
server: Server
|
||||
simulator: Server
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
mlir: MlirModule,
|
||||
compilation_context: CompilationContext,
|
||||
configuration: Optional[Configuration] = None,
|
||||
composition_rules: Optional[Iterable[CompositionRule]] = None,
|
||||
):
|
||||
self.configuration = configuration if configuration is not None else Configuration()
|
||||
self.composition_rules = list(composition_rules) if composition_rules else []
|
||||
|
||||
self.graph = graph
|
||||
self.mlir_module = mlir
|
||||
self.compilation_context = compilation_context
|
||||
|
||||
if self.configuration.fhe_simulation:
|
||||
self.enable_fhe_simulation()
|
||||
|
||||
if self.configuration.fhe_execution:
|
||||
self.enable_fhe_execution()
|
||||
@property
|
||||
def _function(self) -> FheFunction:
|
||||
return getattr(self._module, self._name)
|
||||
|
||||
def __str__(self):
|
||||
return self.graph.format()
|
||||
return self._function.graph.format()
|
||||
|
||||
def draw(
|
||||
self,
|
||||
@@ -100,7 +73,7 @@ class Circuit:
|
||||
path to the drawing
|
||||
"""
|
||||
|
||||
return self.graph.draw(horizontal=horizontal, save_to=save_to, show=show)
|
||||
return self._function.graph.draw(horizontal=horizontal, save_to=save_to, show=show)
|
||||
|
||||
@property
|
||||
def mlir(self) -> str:
|
||||
@@ -109,42 +82,19 @@ class Circuit:
|
||||
Returns:
|
||||
str: textual representation of the MLIR module
|
||||
"""
|
||||
return str(self.mlir_module).strip()
|
||||
return str(self._module.mlir_module).strip()
|
||||
|
||||
def enable_fhe_simulation(self):
|
||||
"""
|
||||
Enable FHE simulation.
|
||||
"""
|
||||
|
||||
if not hasattr(self, "simulator"):
|
||||
self.simulator = Server.create(
|
||||
self.mlir_module,
|
||||
self.configuration,
|
||||
is_simulated=True,
|
||||
compilation_context=self.compilation_context,
|
||||
composition_rules=self.composition_rules,
|
||||
)
|
||||
self._module.simulation_runtime.init()
|
||||
|
||||
def enable_fhe_execution(self):
|
||||
"""
|
||||
Enable FHE execution.
|
||||
"""
|
||||
|
||||
if not hasattr(self, "server"):
|
||||
self.server = Server.create(
|
||||
self.mlir_module,
|
||||
self.configuration,
|
||||
compilation_context=self.compilation_context,
|
||||
composition_rules=self.composition_rules,
|
||||
)
|
||||
|
||||
keyset_cache_directory = None
|
||||
if self.configuration.use_insecure_key_cache:
|
||||
assert_that(self.configuration.enable_unsafe_features)
|
||||
assert_that(self.configuration.insecure_key_cache_location is not None)
|
||||
keyset_cache_directory = self.configuration.insecure_key_cache_location
|
||||
|
||||
self.client = Client(self.server.client_specs, keyset_cache_directory)
|
||||
self._module.execution_runtime.init() # pragma: no cover
|
||||
|
||||
def simulate(self, *args: Any) -> Any:
|
||||
"""
|
||||
@@ -158,58 +108,21 @@ class Circuit:
|
||||
Any:
|
||||
result of the simulation
|
||||
"""
|
||||
|
||||
if not hasattr(self, "simulator"): # pragma: no cover
|
||||
self.enable_fhe_simulation()
|
||||
|
||||
ordered_validated_args = validate_input_args(self.simulator.client_specs, *args)
|
||||
|
||||
exporter = SimulatedValueExporter.new(self.simulator.client_specs.client_parameters)
|
||||
exported = [
|
||||
(
|
||||
None
|
||||
if arg is None
|
||||
else Value(
|
||||
exporter.export_tensor(position, arg.flatten().tolist(), list(arg.shape))
|
||||
if isinstance(arg, np.ndarray) and arg.shape != ()
|
||||
else exporter.export_scalar(position, int(arg))
|
||||
)
|
||||
)
|
||||
for position, arg in enumerate(ordered_validated_args)
|
||||
]
|
||||
|
||||
results = self.simulator.run(*exported)
|
||||
if not isinstance(results, tuple):
|
||||
results = (results,)
|
||||
|
||||
decrypter = SimulatedValueDecrypter.new(self.simulator.client_specs.client_parameters)
|
||||
decrypted = tuple(
|
||||
decrypter.decrypt(position, result.inner) for position, result in enumerate(results)
|
||||
)
|
||||
|
||||
return decrypted if len(decrypted) != 1 else decrypted[0]
|
||||
return self._function.simulate(*args)
|
||||
|
||||
@property
|
||||
def keys(self) -> Keys:
|
||||
"""
|
||||
Get the keys of the circuit.
|
||||
"""
|
||||
|
||||
if not hasattr(self, "client"): # pragma: no cover
|
||||
self.enable_fhe_execution()
|
||||
|
||||
return self.client.keys
|
||||
return self._module.keys
|
||||
|
||||
@keys.setter
|
||||
def keys(self, new_keys: Keys):
|
||||
"""
|
||||
Set the keys of the circuit.
|
||||
"""
|
||||
|
||||
if not hasattr(self, "client"): # pragma: no cover
|
||||
self.enable_fhe_execution()
|
||||
|
||||
self.client.keys = new_keys
|
||||
self._module.keys = new_keys
|
||||
|
||||
def keygen(
|
||||
self, force: bool = False, seed: Optional[int] = None, encryption_seed: Optional[int] = None
|
||||
@@ -227,11 +140,7 @@ class Circuit:
|
||||
encryption_seed (Optional[int], default = None):
|
||||
seed for encryption randomness
|
||||
"""
|
||||
|
||||
if not hasattr(self, "client"): # pragma: no cover
|
||||
self.enable_fhe_execution()
|
||||
|
||||
self.client.keygen(force, seed, encryption_seed)
|
||||
self._module.keygen(force=force, seed=seed, encryption_seed=encryption_seed)
|
||||
|
||||
def encrypt(
|
||||
self,
|
||||
@@ -248,14 +157,7 @@ class Circuit:
|
||||
Optional[Union[Value, Tuple[Optional[Value], ...]]]:
|
||||
encrypted argument(s) for evaluation
|
||||
"""
|
||||
|
||||
if self.configuration.simulate_encrypt_run_decrypt:
|
||||
return args if len(args) != 1 else args[0] # type: ignore
|
||||
|
||||
if not hasattr(self, "client"): # pragma: no cover
|
||||
self.enable_fhe_execution()
|
||||
|
||||
return self.client.encrypt(*args)
|
||||
return self._function.encrypt(*args)
|
||||
|
||||
def run(
|
||||
self,
|
||||
@@ -273,14 +175,7 @@ class Circuit:
|
||||
result(s) of evaluation
|
||||
"""
|
||||
|
||||
if self.configuration.simulate_encrypt_run_decrypt:
|
||||
return self.simulate(*args)
|
||||
|
||||
if not hasattr(self, "server"): # pragma: no cover
|
||||
self.enable_fhe_execution()
|
||||
|
||||
self.keygen(force=False)
|
||||
return self.server.run(*args, evaluation_keys=self.client.evaluation_keys)
|
||||
return self._function.run(*args)
|
||||
|
||||
def decrypt(
|
||||
self,
|
||||
@@ -298,13 +193,7 @@ class Circuit:
|
||||
decrypted result(s) of evaluation
|
||||
"""
|
||||
|
||||
if self.configuration.simulate_encrypt_run_decrypt:
|
||||
return results if len(results) != 1 else results[0] # type: ignore
|
||||
|
||||
if not hasattr(self, "client"): # pragma: no cover
|
||||
self.enable_fhe_execution()
|
||||
|
||||
return self.client.decrypt(*results)
|
||||
return self._function.decrypt(*results)
|
||||
|
||||
def encrypt_run_decrypt(self, *args: Any) -> Any:
|
||||
"""
|
||||
@@ -319,101 +208,81 @@ class Circuit:
|
||||
clear result of homomorphic evaluation
|
||||
"""
|
||||
|
||||
return self.decrypt(self.run(self.encrypt(*args)))
|
||||
return self._function.encrypt_run_decrypt(*args)
|
||||
|
||||
def cleanup(self):
|
||||
"""
|
||||
Cleanup the temporary library output directory.
|
||||
"""
|
||||
|
||||
if hasattr(self, "server"): # pragma: no cover
|
||||
self.server.cleanup()
|
||||
self._module.cleanup()
|
||||
|
||||
# Properties
|
||||
|
||||
def _property(self, name: str) -> Any:
|
||||
"""
|
||||
Get a property of the circuit by name.
|
||||
|
||||
Args:
|
||||
name (str):
|
||||
name of the property
|
||||
|
||||
Returns:
|
||||
Any:
|
||||
statistic
|
||||
"""
|
||||
|
||||
if hasattr(self, "simulator"):
|
||||
return getattr(self.simulator, name) # pragma: no cover
|
||||
|
||||
if not hasattr(self, "server"):
|
||||
self.enable_fhe_execution() # pragma: no cover
|
||||
|
||||
return getattr(self.server, name)
|
||||
|
||||
@property
|
||||
def size_of_secret_keys(self) -> int:
|
||||
"""
|
||||
Get size of the secret keys of the circuit.
|
||||
"""
|
||||
return self._property("size_of_secret_keys") # pragma: no cover
|
||||
return self._module.size_of_secret_keys # pragma: no cover
|
||||
|
||||
@property
|
||||
def size_of_bootstrap_keys(self) -> int:
|
||||
"""
|
||||
Get size of the bootstrap keys of the circuit.
|
||||
"""
|
||||
return self._property("size_of_bootstrap_keys") # pragma: no cover
|
||||
return self._module.size_of_bootstrap_keys # pragma: no cover
|
||||
|
||||
@property
|
||||
def size_of_keyswitch_keys(self) -> int:
|
||||
"""
|
||||
Get size of the key switch keys of the circuit.
|
||||
"""
|
||||
return self._property("size_of_keyswitch_keys") # pragma: no cover
|
||||
return self._module.size_of_keyswitch_keys # pragma: no cover
|
||||
|
||||
@property
|
||||
def size_of_inputs(self) -> int:
|
||||
"""
|
||||
Get size of the inputs of the circuit.
|
||||
"""
|
||||
return self._property("size_of_inputs")() # pragma: no cover
|
||||
return self._function.size_of_inputs # pragma: no cover
|
||||
|
||||
@property
|
||||
def size_of_outputs(self) -> int:
|
||||
"""
|
||||
Get size of the outputs of the circuit.
|
||||
"""
|
||||
return self._property("size_of_outputs")() # pragma: no cover
|
||||
return self._function.size_of_outputs # pragma: no cover
|
||||
|
||||
@property
|
||||
def p_error(self) -> int:
|
||||
"""
|
||||
Get probability of error for each simple TLU (on a scalar).
|
||||
"""
|
||||
return self._property("p_error") # pragma: no cover
|
||||
return self._module.p_error # pragma: no cover
|
||||
|
||||
@property
|
||||
def global_p_error(self) -> int:
|
||||
"""
|
||||
Get the probability of having at least one simple TLU error during the entire execution.
|
||||
"""
|
||||
return self._property("global_p_error") # pragma: no cover
|
||||
return self._module.p_error # pragma: no cover
|
||||
|
||||
@property
|
||||
def complexity(self) -> float:
|
||||
"""
|
||||
Get complexity of the circuit.
|
||||
"""
|
||||
return self._property("complexity") # pragma: no cover
|
||||
return self._module.complexity # pragma: no cover
|
||||
|
||||
@property
|
||||
def memory_usage_per_location(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the memory usage of operations in the circuit per location.
|
||||
"""
|
||||
return self._property("memory_usage_per_location")() # pragma: no cover
|
||||
return self._function.execution_runtime.val.server.memory_usage_per_location(
|
||||
self._name
|
||||
) # pragma: no cover
|
||||
|
||||
# Programmable Bootstrap Statistics
|
||||
|
||||
@@ -422,30 +291,28 @@ class Circuit:
|
||||
"""
|
||||
Get the number of programmable bootstraps in the circuit.
|
||||
"""
|
||||
return self._property("programmable_bootstrap_count")() # pragma: no cover
|
||||
return self._function.programmable_bootstrap_count # pragma: no cover
|
||||
|
||||
@property
|
||||
def programmable_bootstrap_count_per_parameter(self) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of programmable bootstraps per bit width in the circuit.
|
||||
"""
|
||||
return self._property("programmable_bootstrap_count_per_parameter")() # pragma: no cover
|
||||
return self._function.programmable_bootstrap_count_per_parameter # pragma: no cover
|
||||
|
||||
@property
|
||||
def programmable_bootstrap_count_per_tag(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of programmable bootstraps per tag in the circuit.
|
||||
"""
|
||||
return self._property("programmable_bootstrap_count_per_tag")() # pragma: no cover
|
||||
return self._function.programmable_bootstrap_count_per_tag # pragma: no cover
|
||||
|
||||
@property
|
||||
def programmable_bootstrap_count_per_tag_per_parameter(self) -> Dict[str, Dict[int, int]]:
|
||||
"""
|
||||
Get the number of programmable bootstraps per tag per bit width in the circuit.
|
||||
"""
|
||||
return self._property(
|
||||
"programmable_bootstrap_count_per_tag_per_parameter"
|
||||
)() # pragma: no cover
|
||||
return self._function.programmable_bootstrap_count_per_tag_per_parameter # pragma: no cover
|
||||
|
||||
# Key Switch Statistics
|
||||
|
||||
@@ -454,28 +321,28 @@ class Circuit:
|
||||
"""
|
||||
Get the number of key switches in the circuit.
|
||||
"""
|
||||
return self._property("key_switch_count")() # pragma: no cover
|
||||
return self._function.key_switch_count # pragma: no cover
|
||||
|
||||
@property
|
||||
def key_switch_count_per_parameter(self) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of key switches per parameter in the circuit.
|
||||
"""
|
||||
return self._property("key_switch_count_per_parameter")() # pragma: no cover
|
||||
return self._function.key_switch_count_per_parameter # pragma: no cover
|
||||
|
||||
@property
|
||||
def key_switch_count_per_tag(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of key switches per tag in the circuit.
|
||||
"""
|
||||
return self._property("key_switch_count_per_tag")() # pragma: no cover
|
||||
return self._function.key_switch_count_per_tag # pragma: no cover
|
||||
|
||||
@property
|
||||
def key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of key switches per tag per parameter in the circuit.
|
||||
"""
|
||||
return self._property("key_switch_count_per_tag_per_parameter")() # pragma: no cover
|
||||
return self._function.key_switch_count_per_tag_per_parameter # pragma: no cover
|
||||
|
||||
# Packing Key Switch Statistics
|
||||
|
||||
@@ -484,30 +351,28 @@ class Circuit:
|
||||
"""
|
||||
Get the number of packing key switches in the circuit.
|
||||
"""
|
||||
return self._property("packing_key_switch_count")() # pragma: no cover
|
||||
return self._function.packing_key_switch_count # pragma: no cover
|
||||
|
||||
@property
|
||||
def packing_key_switch_count_per_parameter(self) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of packing key switches per parameter in the circuit.
|
||||
"""
|
||||
return self._property("packing_key_switch_count_per_parameter")() # pragma: no cover
|
||||
return self._function.packing_key_switch_count_per_parameter # pragma: no cover
|
||||
|
||||
@property
|
||||
def packing_key_switch_count_per_tag(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of packing key switches per tag in the circuit.
|
||||
"""
|
||||
return self._property("packing_key_switch_count_per_tag")() # pragma: no cover
|
||||
return self._function.packing_key_switch_count_per_tag # pragma: no cover
|
||||
|
||||
@property
|
||||
def packing_key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of packing key switches per tag per parameter in the circuit.
|
||||
"""
|
||||
return self._property(
|
||||
"packing_key_switch_count_per_tag_per_parameter"
|
||||
)() # pragma: no cover
|
||||
return self._function.packing_key_switch_count_per_tag_per_parameter # pragma: no cover
|
||||
|
||||
# Clear Addition Statistics
|
||||
|
||||
@@ -516,28 +381,28 @@ class Circuit:
|
||||
"""
|
||||
Get the number of clear additions in the circuit.
|
||||
"""
|
||||
return self._property("clear_addition_count")() # pragma: no cover
|
||||
return self._function.clear_addition_count # pragma: no cover
|
||||
|
||||
@property
|
||||
def clear_addition_count_per_parameter(self) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of clear additions per parameter in the circuit.
|
||||
"""
|
||||
return self._property("clear_addition_count_per_parameter")() # pragma: no cover
|
||||
return self._function.clear_addition_count_per_parameter # pragma: no cover
|
||||
|
||||
@property
|
||||
def clear_addition_count_per_tag(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of clear additions per tag in the circuit.
|
||||
"""
|
||||
return self._property("clear_addition_count_per_tag")() # pragma: no cover
|
||||
return self._function.clear_addition_count_per_tag # pragma: no cover
|
||||
|
||||
@property
|
||||
def clear_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of clear additions per tag per parameter in the circuit.
|
||||
"""
|
||||
return self._property("clear_addition_count_per_tag_per_parameter")() # pragma: no cover
|
||||
return self._function.clear_addition_count_per_tag_per_parameter # pragma: no cover
|
||||
|
||||
# Encrypted Addition Statistics
|
||||
|
||||
@@ -546,30 +411,28 @@ class Circuit:
|
||||
"""
|
||||
Get the number of encrypted additions in the circuit.
|
||||
"""
|
||||
return self._property("encrypted_addition_count")() # pragma: no cover
|
||||
return self._function.encrypted_addition_count # pragma: no cover
|
||||
|
||||
@property
|
||||
def encrypted_addition_count_per_parameter(self) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of encrypted additions per parameter in the circuit.
|
||||
"""
|
||||
return self._property("encrypted_addition_count_per_parameter")() # pragma: no cover
|
||||
return self._function.encrypted_addition_count_per_parameter # pragma: no cover
|
||||
|
||||
@property
|
||||
def encrypted_addition_count_per_tag(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of encrypted additions per tag in the circuit.
|
||||
"""
|
||||
return self._property("encrypted_addition_count_per_tag")() # pragma: no cover
|
||||
return self._function.encrypted_addition_count_per_tag # pragma: no cover
|
||||
|
||||
@property
|
||||
def encrypted_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of encrypted additions per tag per parameter in the circuit.
|
||||
"""
|
||||
return self._property(
|
||||
"encrypted_addition_count_per_tag_per_parameter"
|
||||
)() # pragma: no cover
|
||||
return self._function.encrypted_addition_count_per_tag_per_parameter # pragma: no cover
|
||||
|
||||
# Clear Multiplication Statistics
|
||||
|
||||
@@ -578,30 +441,28 @@ class Circuit:
|
||||
"""
|
||||
Get the number of clear multiplications in the circuit.
|
||||
"""
|
||||
return self._property("clear_multiplication_count")() # pragma: no cover
|
||||
return self._function.clear_multiplication_count # pragma: no cover
|
||||
|
||||
@property
|
||||
def clear_multiplication_count_per_parameter(self) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of clear multiplications per parameter in the circuit.
|
||||
"""
|
||||
return self._property("clear_multiplication_count_per_parameter")() # pragma: no cover
|
||||
return self._function.clear_multiplication_count_per_parameter # pragma: no cover
|
||||
|
||||
@property
|
||||
def clear_multiplication_count_per_tag(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of clear multiplications per tag in the circuit.
|
||||
"""
|
||||
return self._property("clear_multiplication_count_per_tag")() # pragma: no cover
|
||||
return self._function.clear_multiplication_count_per_tag # pragma: no cover
|
||||
|
||||
@property
|
||||
def clear_multiplication_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of clear multiplications per tag per parameter in the circuit.
|
||||
"""
|
||||
return self._property(
|
||||
"clear_multiplication_count_per_tag_per_parameter"
|
||||
)() # pragma: no cover
|
||||
return self._function.clear_multiplication_count_per_tag_per_parameter # pragma: no cover
|
||||
|
||||
# Encrypted Negation Statistics
|
||||
|
||||
@@ -610,36 +471,85 @@ class Circuit:
|
||||
"""
|
||||
Get the number of encrypted negations in the circuit.
|
||||
"""
|
||||
return self._property("encrypted_negation_count")() # pragma: no cover
|
||||
return self._function.encrypted_negation_count # pragma: no cover
|
||||
|
||||
@property
|
||||
def encrypted_negation_count_per_parameter(self) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of encrypted negations per parameter in the circuit.
|
||||
"""
|
||||
return self._property("encrypted_negation_count_per_parameter")() # pragma: no cover
|
||||
return self._function.encrypted_negation_count_per_parameter # pragma: no cover
|
||||
|
||||
@property
|
||||
def encrypted_negation_count_per_tag(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of encrypted negations per tag in the circuit.
|
||||
"""
|
||||
return self._property("encrypted_negation_count_per_tag")() # pragma: no cover
|
||||
return self._function.encrypted_negation_count_per_tag # pragma: no cover
|
||||
|
||||
@property
|
||||
def encrypted_negation_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of encrypted negations per tag per parameter in the circuit.
|
||||
"""
|
||||
return self._property(
|
||||
"encrypted_negation_count_per_tag_per_parameter"
|
||||
)() # pragma: no cover
|
||||
return self._function.encrypted_negation_count_per_tag_per_parameter # pragma: no cover
|
||||
|
||||
# All Statistics
|
||||
|
||||
@property
|
||||
def statistics(self) -> Dict:
|
||||
def statistics(self) -> Dict: # pragma: no cover
|
||||
"""
|
||||
Get all statistics of the circuit.
|
||||
"""
|
||||
return self._property("statistics") # pragma: no cover
|
||||
mod_stats = self._module.statistics
|
||||
func_stats = mod_stats.pop("functions")[self._name]
|
||||
return mod_stats | func_stats
|
||||
|
||||
@property
|
||||
def configuration(self) -> Configuration:
|
||||
"""
|
||||
Return the circuit configuration.
|
||||
"""
|
||||
return self._module.configuration
|
||||
|
||||
@property
|
||||
def graph(self) -> Graph:
|
||||
"""
|
||||
Return the circuit graph.
|
||||
"""
|
||||
return self._function.graph
|
||||
|
||||
@property
|
||||
def mlir_module(self) -> MlirModule:
|
||||
"""
|
||||
Return the circuit mlir module.
|
||||
"""
|
||||
return self._module.mlir_module
|
||||
|
||||
@property
|
||||
def compilation_context(self) -> CompilationContext:
|
||||
"""
|
||||
Return the circuit compilation context.
|
||||
"""
|
||||
return self._module.compilation_context
|
||||
|
||||
@property
|
||||
def client(self) -> Client:
|
||||
"""
|
||||
Return the circuit client.
|
||||
"""
|
||||
return self._module.client
|
||||
|
||||
@property
|
||||
def server(self) -> Server:
|
||||
"""
|
||||
Return the circuit server.
|
||||
"""
|
||||
return self._module.server
|
||||
|
||||
@property
|
||||
def simulator(self) -> Server:
|
||||
"""
|
||||
Return the circuit simulator.
|
||||
"""
|
||||
return self._module.simulator
|
||||
|
||||
@@ -120,7 +120,7 @@ class Client:
|
||||
def encrypt(
|
||||
self,
|
||||
*args: Optional[Union[int, np.ndarray, List]],
|
||||
function_name: str = "main",
|
||||
function_name: Optional[str] = None,
|
||||
) -> Optional[Union[Value, Tuple[Optional[Value], ...]]]:
|
||||
"""
|
||||
Encrypt argument(s) to for evaluation.
|
||||
@@ -136,6 +136,15 @@ class Client:
|
||||
encrypted argument(s) for evaluation
|
||||
"""
|
||||
|
||||
if function_name is None:
|
||||
functions = self.specs.client_parameters.function_list()
|
||||
if len(functions) == 1:
|
||||
function_name = functions[0]
|
||||
else: # pragma: no cover
|
||||
msg = "The client contains more than one functions. \
|
||||
Provide a `function_name` keyword argument to disambiguate."
|
||||
raise TypeError(msg)
|
||||
|
||||
ordered_sanitized_args = validate_input_args(self.specs, *args, function_name=function_name)
|
||||
|
||||
self.keygen(force=False)
|
||||
@@ -160,7 +169,7 @@ class Client:
|
||||
def decrypt(
|
||||
self,
|
||||
*results: Union[Value, Tuple[Value, ...]],
|
||||
function_name: str = "main",
|
||||
function_name: Optional[str] = None,
|
||||
) -> Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]:
|
||||
"""
|
||||
Decrypt result(s) of evaluation.
|
||||
@@ -176,6 +185,15 @@ class Client:
|
||||
decrypted result(s) of evaluation
|
||||
"""
|
||||
|
||||
if function_name is None:
|
||||
functions = self.specs.client_parameters.function_list()
|
||||
if len(functions) == 1:
|
||||
function_name = functions[0]
|
||||
else: # pragma: no cover
|
||||
msg = "The client contains more than one functions. \
|
||||
Provide a `function_name` keyword argument to disambiguate."
|
||||
raise TypeError(msg)
|
||||
|
||||
flattened_results: List[Value] = []
|
||||
for result in results:
|
||||
if isinstance(result, tuple): # pragma: no cover
|
||||
|
||||
@@ -4,60 +4,31 @@ Declaration of `Compiler` class.
|
||||
|
||||
# pylint: disable=import-error,no-name-in-module
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from enum import Enum, unique
|
||||
from itertools import product, repeat
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from concrete.compiler import CompilationContext
|
||||
|
||||
from ..extensions import AutoRounder, AutoTruncator
|
||||
from ..mlir import GraphConverter
|
||||
from ..representation import Graph
|
||||
from ..tracing import Tracer
|
||||
from ..values import ValueDescription
|
||||
from .artifacts import DebugArtifacts
|
||||
from .artifacts import DebugArtifacts, FunctionDebugArtifacts, ModuleDebugArtifacts
|
||||
from .circuit import Circuit
|
||||
from .composition import CompositionClause, CompositionRule
|
||||
from .composition import CompositionPolicy
|
||||
from .configuration import Configuration
|
||||
from .utils import fuse, get_terminal_size
|
||||
from .module_compiler import FunctionDef, ModuleCompiler
|
||||
from .status import EncryptionStatus
|
||||
from .wiring import AllComposable, NotComposable, TracedOutput
|
||||
|
||||
# pylint: enable=import-error,no-name-in-module
|
||||
|
||||
|
||||
@unique
|
||||
class EncryptionStatus(str, Enum):
|
||||
"""
|
||||
EncryptionStatus enum, to represent encryption status of parameters.
|
||||
"""
|
||||
|
||||
CLEAR = "clear"
|
||||
ENCRYPTED = "encrypted"
|
||||
|
||||
|
||||
class Compiler:
|
||||
"""
|
||||
Compiler class, to glue the compilation pipeline.
|
||||
"""
|
||||
|
||||
function: Callable
|
||||
parameter_encryption_statuses: Dict[str, EncryptionStatus]
|
||||
|
||||
configuration: Configuration
|
||||
artifacts: Optional[DebugArtifacts]
|
||||
location: str
|
||||
|
||||
inputset: List[Any]
|
||||
graph: Optional[Graph]
|
||||
|
||||
compilation_context: CompilationContext
|
||||
|
||||
_is_direct: bool
|
||||
_parameter_values: Dict[str, ValueDescription]
|
||||
_module_compiler: ModuleCompiler
|
||||
_function_name: str
|
||||
|
||||
@staticmethod
|
||||
def assemble(
|
||||
@@ -97,81 +68,38 @@ class Compiler:
|
||||
name: "encrypted" if value.is_encrypted else "clear"
|
||||
for name, value in parameter_values.items()
|
||||
},
|
||||
composition=(
|
||||
AllComposable()
|
||||
if (configuration.composable if configuration is not None else False)
|
||||
else NotComposable()
|
||||
),
|
||||
)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
compiler._is_direct = True
|
||||
compiler._parameter_values = parameter_values
|
||||
compiler._func_def._is_direct = True
|
||||
compiler._func_def._parameter_values = parameter_values
|
||||
# pylint: enable=protected-access
|
||||
|
||||
return compiler.compile(None, configuration, artifacts, **kwargs)
|
||||
return compiler.compile([], configuration, artifacts, **kwargs)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
function: Callable,
|
||||
parameter_encryption_statuses: Dict[str, Union[str, EncryptionStatus]],
|
||||
composition: Optional[Union[NotComposable, AllComposable]] = None,
|
||||
):
|
||||
signature = inspect.signature(function)
|
||||
|
||||
missing_args = list(signature.parameters)
|
||||
for arg in parameter_encryption_statuses.keys():
|
||||
if arg in signature.parameters:
|
||||
missing_args.remove(arg)
|
||||
|
||||
if len(missing_args) != 0:
|
||||
parameter_str = repr(missing_args[0])
|
||||
for arg in missing_args[1:-1]:
|
||||
parameter_str += f", {repr(arg)}"
|
||||
if len(missing_args) != 1:
|
||||
parameter_str += f" and {repr(missing_args[-1])}"
|
||||
|
||||
message = (
|
||||
f"Encryption status{'es' if len(missing_args) > 1 else ''} "
|
||||
f"of parameter{'s' if len(missing_args) > 1 else ''} "
|
||||
f"{parameter_str} of function '{function.__name__}' "
|
||||
f"{'are' if len(missing_args) > 1 else 'is'} not provided"
|
||||
)
|
||||
raise ValueError(message)
|
||||
|
||||
additional_args = list(parameter_encryption_statuses)
|
||||
for arg in signature.parameters.keys():
|
||||
if arg in parameter_encryption_statuses:
|
||||
additional_args.remove(arg)
|
||||
|
||||
if len(additional_args) != 0:
|
||||
parameter_str = repr(additional_args[0])
|
||||
for arg in additional_args[1:-1]:
|
||||
parameter_str += f", {repr(arg)}"
|
||||
if len(additional_args) != 1:
|
||||
parameter_str += f" and {repr(additional_args[-1])}"
|
||||
|
||||
message = (
|
||||
f"Encryption status{'es' if len(additional_args) > 1 else ''} "
|
||||
f"of {parameter_str} {'are' if len(additional_args) > 1 else 'is'} provided but "
|
||||
f"{'they are' if len(additional_args) > 1 else 'it is'} not a parameter "
|
||||
f"of function '{function.__name__}'"
|
||||
)
|
||||
raise ValueError(message)
|
||||
|
||||
self.function = function # type: ignore
|
||||
self.parameter_encryption_statuses = {
|
||||
param: EncryptionStatus(status.lower())
|
||||
for param, status in parameter_encryption_statuses.items()
|
||||
}
|
||||
|
||||
self.configuration = Configuration()
|
||||
self.artifacts = None
|
||||
|
||||
self.inputset = []
|
||||
self.graph = None
|
||||
|
||||
self.compilation_context = CompilationContext.new()
|
||||
|
||||
self._is_direct = False
|
||||
self._parameter_values = {}
|
||||
self.location = (
|
||||
f"{self.function.__code__.co_filename}:{self.function.__code__.co_firstlineno}"
|
||||
if composition is None:
|
||||
composition = NotComposable()
|
||||
assert isinstance(composition, CompositionPolicy)
|
||||
func = FunctionDef(
|
||||
function=function, parameter_encryption_statuses=parameter_encryption_statuses
|
||||
)
|
||||
self._module_compiler = ModuleCompiler([func], composition)
|
||||
self._function_name = function.__name__
|
||||
|
||||
@property
|
||||
def _func_def(self) -> FunctionDef:
|
||||
return getattr(self._module_compiler, self._function_name)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -182,138 +110,10 @@ class Compiler:
|
||||
np.integer,
|
||||
np.floating,
|
||||
np.ndarray,
|
||||
Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray], ...],
|
||||
TracedOutput,
|
||||
Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray, TracedOutput], ...],
|
||||
]:
|
||||
if len(kwargs) != 0:
|
||||
message = f"Calling function '{self.function.__name__}' with kwargs is not supported"
|
||||
raise RuntimeError(message)
|
||||
|
||||
sample = args[0] if len(args) == 1 else args
|
||||
|
||||
if self.graph is None:
|
||||
self._trace(sample)
|
||||
assert self.graph is not None
|
||||
|
||||
self.inputset.append(sample)
|
||||
return self.graph(*args)
|
||||
|
||||
def _trace(self, sample: Union[Any, Tuple[Any, ...]]):
|
||||
"""
|
||||
Trace the function and fuse the resulting graph with a sample input.
|
||||
|
||||
Args:
|
||||
sample (Union[Any, Tuple[Any, ...]]):
|
||||
sample to use for tracing
|
||||
"""
|
||||
|
||||
if self.artifacts is not None:
|
||||
self.artifacts.add_source_code(self.function)
|
||||
for param, encryption_status in self.parameter_encryption_statuses.items():
|
||||
self.artifacts.add_parameter_encryption_status(param, encryption_status)
|
||||
|
||||
parameters = {
|
||||
param: ValueDescription.of(arg, is_encrypted=(status == EncryptionStatus.ENCRYPTED))
|
||||
for arg, (param, status) in zip(
|
||||
(
|
||||
sample
|
||||
if len(self.parameter_encryption_statuses) > 1 or isinstance(sample, tuple)
|
||||
else (sample,)
|
||||
),
|
||||
self.parameter_encryption_statuses.items(),
|
||||
)
|
||||
}
|
||||
|
||||
self.graph = Tracer.trace(self.function, parameters, location=self.location)
|
||||
if self.artifacts is not None:
|
||||
self.artifacts.add_graph("initial", self.graph)
|
||||
fuse(
|
||||
self.graph,
|
||||
(self.artifacts.module_artifacts.functions["main"] if self.artifacts else None),
|
||||
)
|
||||
|
||||
def _evaluate(
|
||||
self,
|
||||
action: str,
|
||||
inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]],
|
||||
):
|
||||
"""
|
||||
Trace, fuse, measure bounds, and update values in the resulting graph in one go.
|
||||
|
||||
Args:
|
||||
action (str):
|
||||
action being performed (e.g., "trace", "compile")
|
||||
|
||||
inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
|
||||
optional inputset to extend accumulated inputset before bounds measurement
|
||||
"""
|
||||
|
||||
if self._is_direct:
|
||||
self.graph = Tracer.trace(
|
||||
self.function, self._parameter_values, is_direct=True, location=self.location
|
||||
)
|
||||
if self.artifacts is not None:
|
||||
self.artifacts.add_graph("initial", self.graph) # pragma: no cover
|
||||
|
||||
fuse(
|
||||
self.graph,
|
||||
(self.artifacts.module_artifacts.functions["main"] if self.artifacts else None),
|
||||
)
|
||||
|
||||
if self.artifacts is not None:
|
||||
self.artifacts.add_graph("final", self.graph) # pragma: no cover
|
||||
|
||||
return
|
||||
|
||||
if inputset is not None:
|
||||
previous_inputset_length = len(self.inputset)
|
||||
for index, sample in enumerate(iter(inputset)):
|
||||
self.inputset.append(sample)
|
||||
|
||||
if not isinstance(sample, tuple):
|
||||
sample = (sample,)
|
||||
|
||||
if len(sample) != len(self.parameter_encryption_statuses):
|
||||
self.inputset = self.inputset[:previous_inputset_length]
|
||||
|
||||
expected = (
|
||||
"a single value"
|
||||
if len(self.parameter_encryption_statuses) == 1
|
||||
else f"a tuple of {len(self.parameter_encryption_statuses)} values"
|
||||
)
|
||||
actual = (
|
||||
"a single value" if len(sample) == 1 else f"a tuple of {len(sample)} values"
|
||||
)
|
||||
|
||||
message = (
|
||||
f"Input #{index} of your inputset is not well formed "
|
||||
f"(expected {expected} got {actual})"
|
||||
)
|
||||
raise ValueError(message)
|
||||
|
||||
if self.configuration.auto_adjust_rounders:
|
||||
AutoRounder.adjust(self.function, self.inputset)
|
||||
|
||||
if self.configuration.auto_adjust_truncators:
|
||||
AutoTruncator.adjust(self.function, self.inputset)
|
||||
|
||||
if self.graph is None:
|
||||
try:
|
||||
first_sample = next(iter(self.inputset))
|
||||
except StopIteration as error:
|
||||
message = (
|
||||
f"{action} function '{self.function.__name__}' "
|
||||
f"without an inputset is not supported"
|
||||
)
|
||||
raise RuntimeError(message) from error
|
||||
|
||||
self._trace(first_sample)
|
||||
assert self.graph is not None
|
||||
|
||||
bounds = self.graph.measure_bounds(self.inputset)
|
||||
self.graph.update_with_bounds(bounds)
|
||||
|
||||
if self.artifacts is not None:
|
||||
self.artifacts.add_graph("final", self.graph)
|
||||
return self._func_def(*args, **kwargs)
|
||||
|
||||
def trace(
|
||||
self,
|
||||
@@ -342,78 +142,22 @@ class Compiler:
|
||||
Graph:
|
||||
computation graph representing the function prior to MLIR conversion
|
||||
"""
|
||||
|
||||
old_configuration = deepcopy(self.configuration)
|
||||
old_artifacts = deepcopy(self.artifacts)
|
||||
|
||||
if configuration is not None:
|
||||
self.configuration = configuration
|
||||
|
||||
if len(kwargs) != 0:
|
||||
self.configuration = self.configuration.fork(**kwargs)
|
||||
|
||||
self.artifacts = (
|
||||
artifacts
|
||||
art = (
|
||||
artifacts.module_artifacts.functions.get(self._function_name, FunctionDebugArtifacts())
|
||||
if artifacts is not None
|
||||
else (
|
||||
DebugArtifacts()
|
||||
if self.configuration.dump_artifacts_on_unexpected_failures
|
||||
else None
|
||||
)
|
||||
else FunctionDebugArtifacts()
|
||||
)
|
||||
conf = (
|
||||
configuration
|
||||
if configuration is not None
|
||||
else self._module_compiler.default_configuration
|
||||
)
|
||||
if len(kwargs) != 0:
|
||||
conf = conf.fork(**kwargs)
|
||||
|
||||
try:
|
||||
self._evaluate("Tracing", inputset)
|
||||
assert self.graph is not None
|
||||
|
||||
if self.configuration.verbose or self.configuration.show_graph:
|
||||
graph = self.graph.format()
|
||||
longest_line = max(len(line) for line in graph.split("\n"))
|
||||
|
||||
try: # pragma: no cover
|
||||
# this branch cannot be covered
|
||||
# because `os.get_terminal_size()`
|
||||
# raises an exception during tests
|
||||
|
||||
columns, _ = os.get_terminal_size()
|
||||
if columns == 0: # noqa: SIM108
|
||||
columns = min(longest_line, 80)
|
||||
else:
|
||||
columns = min(longest_line, columns)
|
||||
except OSError: # pragma: no cover
|
||||
columns = min(longest_line, 80)
|
||||
|
||||
print()
|
||||
|
||||
print("Computation Graph")
|
||||
print("-" * columns)
|
||||
print(graph)
|
||||
print("-" * columns)
|
||||
|
||||
print()
|
||||
|
||||
return self.graph
|
||||
|
||||
except Exception: # pragma: no cover
|
||||
# this branch is reserved for unexpected issues and hence it shouldn't be tested
|
||||
# if it could be tested, we would have fixed the underlying issue
|
||||
|
||||
# if the user desires so,
|
||||
# we need to export all the information we have about the compilation
|
||||
|
||||
if self.configuration.dump_artifacts_on_unexpected_failures:
|
||||
assert self.artifacts is not None
|
||||
self.artifacts.export()
|
||||
|
||||
traceback_path = self.artifacts.output_directory.joinpath("traceback.txt")
|
||||
with open(traceback_path, "w", encoding="utf-8") as f:
|
||||
f.write(traceback.format_exc())
|
||||
|
||||
raise
|
||||
|
||||
finally:
|
||||
self.configuration = old_configuration
|
||||
self.artifacts = old_artifacts
|
||||
self._func_def.evaluate("Tracing", inputset, conf, art)
|
||||
assert self._func_def.graph is not None
|
||||
return self._func_def.graph
|
||||
|
||||
# pylint: disable=too-many-branches,too-many-statements
|
||||
|
||||
@@ -445,240 +189,21 @@ class Compiler:
|
||||
compiled circuit
|
||||
"""
|
||||
|
||||
old_configuration = deepcopy(self.configuration)
|
||||
old_artifacts = deepcopy(self.artifacts)
|
||||
|
||||
if configuration is not None:
|
||||
self.configuration = configuration
|
||||
|
||||
if len(kwargs) != 0:
|
||||
self.configuration = self.configuration.fork(**kwargs)
|
||||
|
||||
self.artifacts = (
|
||||
artifacts
|
||||
if artifacts is not None
|
||||
else (
|
||||
DebugArtifacts()
|
||||
if self.configuration.dump_artifacts_on_unexpected_failures
|
||||
else None
|
||||
)
|
||||
art = artifacts.module_artifacts if artifacts is not None else ModuleDebugArtifacts()
|
||||
conf = (
|
||||
configuration
|
||||
if configuration is not None
|
||||
else self._module_compiler.default_configuration
|
||||
)
|
||||
if len(kwargs) != 0:
|
||||
conf = conf.fork(**kwargs)
|
||||
|
||||
try:
|
||||
self._evaluate("Compiling", inputset)
|
||||
assert self.graph is not None
|
||||
|
||||
show_graph = (
|
||||
self.configuration.show_graph
|
||||
if self.configuration.show_graph is not None
|
||||
else self.configuration.verbose
|
||||
)
|
||||
show_bit_width_constraints = (
|
||||
self.configuration.show_bit_width_constraints
|
||||
if self.configuration.show_bit_width_constraints is not None
|
||||
else self.configuration.verbose
|
||||
)
|
||||
show_bit_width_assignments = (
|
||||
self.configuration.show_bit_width_assignments
|
||||
if self.configuration.show_bit_width_assignments is not None
|
||||
else self.configuration.verbose
|
||||
)
|
||||
show_assigned_graph = (
|
||||
self.configuration.show_assigned_graph
|
||||
if self.configuration.show_assigned_graph is not None
|
||||
else self.configuration.verbose
|
||||
)
|
||||
show_mlir = (
|
||||
self.configuration.show_mlir
|
||||
if self.configuration.show_mlir is not None
|
||||
else self.configuration.verbose
|
||||
)
|
||||
show_optimizer = (
|
||||
self.configuration.show_optimizer
|
||||
if self.configuration.show_optimizer is not None
|
||||
else self.configuration.verbose
|
||||
)
|
||||
show_statistics = (
|
||||
self.configuration.show_statistics
|
||||
if self.configuration.show_statistics is not None
|
||||
else self.configuration.verbose
|
||||
)
|
||||
|
||||
columns = get_terminal_size()
|
||||
is_first = True
|
||||
|
||||
if (
|
||||
show_graph
|
||||
or show_bit_width_constraints
|
||||
or show_bit_width_assignments
|
||||
or show_assigned_graph
|
||||
or show_mlir
|
||||
or show_optimizer
|
||||
or show_statistics
|
||||
):
|
||||
if show_graph:
|
||||
if is_first: # pragma: no cover
|
||||
print()
|
||||
is_first = False
|
||||
|
||||
print("Computation Graph")
|
||||
print("-" * columns)
|
||||
print(self.graph.format())
|
||||
print("-" * columns)
|
||||
|
||||
print()
|
||||
|
||||
# We generate the composition rules if needed:
|
||||
composition_rules = []
|
||||
if self.configuration.composable:
|
||||
compo_froms = map(
|
||||
CompositionClause.create,
|
||||
zip(repeat(self.graph.name), range(len(self.graph.output_nodes))),
|
||||
)
|
||||
compo_tos = map(
|
||||
CompositionClause.create,
|
||||
zip(repeat(self.graph.name), range(len(self.graph.input_nodes))),
|
||||
)
|
||||
composition_rules = list(
|
||||
map(CompositionRule.create, product(compo_froms, compo_tos))
|
||||
)
|
||||
|
||||
# in-memory MLIR module
|
||||
mlir_context = self.compilation_context.mlir_context()
|
||||
mlir_module = GraphConverter(self.configuration, composition_rules).convert(
|
||||
self.graph, mlir_context
|
||||
)
|
||||
# textual representation of the MLIR module
|
||||
mlir_str = str(mlir_module).strip()
|
||||
if self.artifacts is not None:
|
||||
self.artifacts.add_mlir_to_compile(mlir_str)
|
||||
|
||||
if show_bit_width_constraints:
|
||||
if is_first: # pragma: no cover
|
||||
print()
|
||||
is_first = False
|
||||
|
||||
print("Bit-Width Constraints")
|
||||
print("-" * columns)
|
||||
print(self.graph.format_bit_width_constraints())
|
||||
print("-" * columns)
|
||||
|
||||
print()
|
||||
|
||||
if show_bit_width_assignments:
|
||||
if is_first: # pragma: no cover
|
||||
print()
|
||||
is_first = False
|
||||
|
||||
print("Bit-Width Assignments")
|
||||
print("-" * columns)
|
||||
print(self.graph.format_bit_width_assignments())
|
||||
print("-" * columns)
|
||||
|
||||
print()
|
||||
|
||||
if show_assigned_graph:
|
||||
if is_first: # pragma: no cover
|
||||
print()
|
||||
is_first = False
|
||||
|
||||
print("Bit-Width Assigned Computation Graph")
|
||||
print("-" * columns)
|
||||
print(self.graph.format(show_assigned_bit_widths=True))
|
||||
print("-" * columns)
|
||||
|
||||
print()
|
||||
|
||||
if show_mlir:
|
||||
if is_first: # pragma: no cover
|
||||
print()
|
||||
is_first = False
|
||||
|
||||
print("MLIR")
|
||||
print("-" * columns)
|
||||
print(mlir_str)
|
||||
print("-" * columns)
|
||||
|
||||
print()
|
||||
|
||||
if show_optimizer:
|
||||
if is_first: # pragma: no cover
|
||||
print()
|
||||
is_first = False
|
||||
|
||||
print("Optimizer")
|
||||
print("-" * columns)
|
||||
|
||||
circuit = Circuit(
|
||||
self.graph,
|
||||
mlir_module,
|
||||
self.compilation_context,
|
||||
self.configuration,
|
||||
composition_rules,
|
||||
)
|
||||
|
||||
if hasattr(circuit, "client"):
|
||||
client_parameters = circuit.client.specs.client_parameters
|
||||
if self.artifacts is not None:
|
||||
self.artifacts.add_client_parameters(client_parameters.serialize())
|
||||
|
||||
if show_optimizer:
|
||||
print("-" * columns)
|
||||
print()
|
||||
|
||||
if show_statistics:
|
||||
if is_first: # pragma: no cover
|
||||
print()
|
||||
|
||||
print("Statistics")
|
||||
print("-" * columns)
|
||||
|
||||
def pretty(d, indent=0): # pragma: no cover
|
||||
if indent > 0:
|
||||
print("{")
|
||||
|
||||
for key, value in d.items():
|
||||
if isinstance(value, dict) and len(value) == 0:
|
||||
continue
|
||||
|
||||
print(" " * indent + str(key) + ": ", end="")
|
||||
|
||||
if isinstance(value, dict):
|
||||
pretty(value, indent + 1)
|
||||
else:
|
||||
print(value)
|
||||
|
||||
if indent > 0:
|
||||
print(" " * (indent - 1) + "}")
|
||||
|
||||
pretty(circuit.statistics)
|
||||
|
||||
print("-" * columns)
|
||||
|
||||
print()
|
||||
|
||||
except Exception: # pragma: no cover
|
||||
# this branch is reserved for unexpected issues and hence it shouldn't be tested
|
||||
# if it could be tested, we would have fixed the underlying issue
|
||||
|
||||
# if the user desires so,
|
||||
# we need to export all the information we have about the compilation
|
||||
|
||||
if self.configuration.dump_artifacts_on_unexpected_failures:
|
||||
assert self.artifacts is not None
|
||||
self.artifacts.export()
|
||||
|
||||
traceback_path = self.artifacts.output_directory.joinpath("traceback.txt")
|
||||
with open(traceback_path, "w", encoding="utf-8") as f:
|
||||
f.write(traceback.format_exc())
|
||||
|
||||
raise
|
||||
|
||||
finally:
|
||||
self.configuration = old_configuration
|
||||
self.artifacts = old_artifacts
|
||||
|
||||
return circuit
|
||||
if conf.composable:
|
||||
self._module_compiler.composition = AllComposable()
|
||||
fhe_module = self._module_compiler.compile(
|
||||
{self._function_name: inputset}, configuration=conf, module_artifacts=art
|
||||
)
|
||||
return Circuit(fhe_module)
|
||||
|
||||
# pylint: enable=too-many-branches,too-many-statements
|
||||
|
||||
@@ -686,5 +211,6 @@ class Compiler:
|
||||
"""
|
||||
Reset the compiler so that another compilation with another inputset can be performed.
|
||||
"""
|
||||
fresh_compiler = Compiler(self.function, self.parameter_encryption_statuses)
|
||||
fdef = self._module_compiler.functions[self._function_name]
|
||||
fresh_compiler = Compiler(fdef.function, fdef.parameter_encryption_statuses)
|
||||
self.__dict__.update(fresh_compiler.__dict__)
|
||||
|
||||
@@ -12,9 +12,11 @@ from ..tracing.typing import ScalarAnnotation
|
||||
from ..values import ValueDescription
|
||||
from .artifacts import DebugArtifacts
|
||||
from .circuit import Circuit
|
||||
from .compiler import Compiler, EncryptionStatus
|
||||
from .compiler import Compiler
|
||||
from .configuration import Configuration
|
||||
from .module_compiler import AllComposable, CompositionPolicy, FunctionDef, ModuleCompiler
|
||||
from .module_compiler import CompositionPolicy, FunctionDef, ModuleCompiler
|
||||
from .status import EncryptionStatus
|
||||
from .wiring import AllComposable
|
||||
|
||||
|
||||
def circuit(
|
||||
@@ -151,7 +153,9 @@ class Compilable:
|
||||
compiled circuit
|
||||
"""
|
||||
|
||||
return self.compiler.compile(inputset, configuration, artifacts, **kwargs)
|
||||
return self.compiler.compile(
|
||||
inputset if inputset is not None else [], configuration, artifacts, **kwargs
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
|
||||
@@ -166,12 +166,10 @@ class FheFunction:
|
||||
decrypted = tuple(
|
||||
decrypter.decrypt(position, result.inner) for position, result in enumerate(results)
|
||||
)
|
||||
|
||||
return decrypted if len(decrypted) != 1 else decrypted[0]
|
||||
|
||||
def encrypt(
|
||||
self,
|
||||
*args: Optional[Union[int, np.ndarray, List]],
|
||||
self, *args: Optional[Union[int, np.ndarray, List]]
|
||||
) -> Optional[Union[Value, Tuple[Optional[Value], ...]]]:
|
||||
"""
|
||||
Encrypt argument(s) to for evaluation.
|
||||
@@ -693,6 +691,7 @@ class FheModule:
|
||||
"""
|
||||
Get size of the key switch keys of the module.
|
||||
"""
|
||||
|
||||
return self.execution_runtime.val.server.size_of_keyswitch_keys # pragma: no cover
|
||||
|
||||
@property
|
||||
@@ -758,12 +757,26 @@ class FheModule:
|
||||
return self.execution_runtime.val.server
|
||||
|
||||
@property
|
||||
def client(self) -> Optional[Client]:
|
||||
def client(self) -> Client:
|
||||
"""
|
||||
Returns the execution client object tied to the module.
|
||||
"""
|
||||
return self.execution_runtime.val.client
|
||||
|
||||
@property
|
||||
def simulator(self) -> Server:
|
||||
"""
|
||||
Returns the simulation server object tied to the module.
|
||||
"""
|
||||
return self.simulation_runtime.val.server
|
||||
|
||||
@property
|
||||
def function_count(self) -> int:
|
||||
"""
|
||||
Returns the number of functions in the module.
|
||||
"""
|
||||
return len(self.graphs)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item not in list(self.graphs.keys()):
|
||||
error = f"No attribute {item}"
|
||||
|
||||
@@ -6,23 +6,8 @@ Declaration of `MultiCompiler` class.
|
||||
|
||||
import inspect
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from itertools import chain, product, repeat
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Protocol,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
runtime_checkable,
|
||||
)
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from concrete.compiler import CompilationContext
|
||||
@@ -32,12 +17,13 @@ from ..mlir import GraphConverter
|
||||
from ..representation import Graph
|
||||
from ..tracing import Tracer
|
||||
from ..values import ValueDescription
|
||||
from .artifacts import FunctionDebugArtifacts, ModuleDebugArtifacts
|
||||
from .compiler import EncryptionStatus
|
||||
from .composition import CompositionClause, CompositionPolicy, CompositionRule
|
||||
from .artifacts import DebugManager, FunctionDebugArtifacts, ModuleDebugArtifacts
|
||||
from .composition import CompositionPolicy
|
||||
from .configuration import Configuration
|
||||
from .module import FheModule
|
||||
from .utils import fuse, get_terminal_size
|
||||
from .status import EncryptionStatus
|
||||
from .utils import fuse
|
||||
from .wiring import Input, Output, TracedOutput, Wire, Wired, WireTracingContextManager
|
||||
|
||||
DEFAULT_OUTPUT_DIRECTORY: Path = Path(".artifacts")
|
||||
|
||||
@@ -49,13 +35,14 @@ class FunctionDef:
|
||||
An object representing the definition of a function as used in an fhe module.
|
||||
"""
|
||||
|
||||
name: str
|
||||
function: Callable
|
||||
parameter_encryption_statuses: Dict[str, EncryptionStatus]
|
||||
inputset: List[Any]
|
||||
graph: Optional[Graph]
|
||||
_parameter_values: Dict[str, ValueDescription]
|
||||
location: str
|
||||
|
||||
_is_direct: bool
|
||||
_parameter_values: Dict[str, ValueDescription]
|
||||
_trace_wires: Optional[Set["Wire"]]
|
||||
|
||||
def __init__(
|
||||
@@ -112,13 +99,18 @@ class FunctionDef:
|
||||
}
|
||||
self.inputset = []
|
||||
self.graph = None
|
||||
self.name = function.__name__
|
||||
self._is_direct = False
|
||||
self._parameter_values = {}
|
||||
self.location = (
|
||||
f"{self.function.__code__.co_filename}:{self.function.__code__.co_firstlineno}"
|
||||
)
|
||||
self._trace_wires = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return the name of the function."""
|
||||
return self.function.__name__
|
||||
|
||||
def trace(
|
||||
self,
|
||||
sample: Union[Any, Tuple[Any, ...]],
|
||||
@@ -151,7 +143,7 @@ class FunctionDef:
|
||||
)
|
||||
}
|
||||
|
||||
self.graph = Tracer.trace(self.function, parameters, name=self.name, location=self.location)
|
||||
self.graph = Tracer.trace(self.function, parameters, location=self.location)
|
||||
if artifacts is not None:
|
||||
artifacts.add_graph("initial", self.graph)
|
||||
|
||||
@@ -181,6 +173,21 @@ class FunctionDef:
|
||||
artifact object to store informations in
|
||||
"""
|
||||
|
||||
if self._is_direct:
|
||||
self.graph = Tracer.trace(
|
||||
self.function,
|
||||
self._parameter_values,
|
||||
is_direct=True,
|
||||
location=self.location,
|
||||
)
|
||||
artifacts.add_graph("initial", self.graph) # pragma: no cover
|
||||
fuse(
|
||||
self.graph,
|
||||
artifacts,
|
||||
)
|
||||
artifacts.add_graph("final", self.graph) # pragma: no cover
|
||||
return
|
||||
|
||||
if inputset is not None:
|
||||
previous_inputset_length = len(self.inputset)
|
||||
for index, sample in enumerate(iter(inputset)):
|
||||
@@ -320,424 +327,6 @@ class FunctionDef:
|
||||
return output[0] if len(output) == 1 else output
|
||||
|
||||
|
||||
class NotComposable:
|
||||
"""
|
||||
Composition policy that does not allow the forwarding of any output to any input.
|
||||
"""
|
||||
|
||||
def get_rules_iter(self, _) -> Iterable[CompositionRule]:
|
||||
"""
|
||||
Return an iterator over composition rules.
|
||||
"""
|
||||
return [] # pragma: no cover
|
||||
|
||||
|
||||
class AllComposable:
|
||||
"""
|
||||
Composition policy that allows to forward any output of the module to any of its input.
|
||||
"""
|
||||
|
||||
def get_rules_iter(self, funcs: List[Graph]) -> Iterable[CompositionRule]:
|
||||
"""
|
||||
Return an iterator over composition rules.
|
||||
"""
|
||||
outputs = []
|
||||
for f in funcs:
|
||||
for pos, node in f.output_nodes.items():
|
||||
if node.output.is_encrypted:
|
||||
outputs.append(CompositionClause.create((f.name, pos)))
|
||||
inputs = []
|
||||
for f in funcs:
|
||||
for pos, node in f.input_nodes.items():
|
||||
if node.output.is_encrypted:
|
||||
inputs.append(CompositionClause.create((f.name, pos)))
|
||||
|
||||
return map(CompositionRule.create, product(outputs, inputs))
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class WireOutput(Protocol):
|
||||
"""
|
||||
A protocol for wire outputs.
|
||||
"""
|
||||
|
||||
def get_outputs_iter(self) -> Iterable[CompositionClause]:
|
||||
"""
|
||||
Return an iterator over the possible outputs of the wire output.
|
||||
"""
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class WireInput(Protocol):
|
||||
"""
|
||||
A protocol for wire inputs.
|
||||
"""
|
||||
|
||||
def get_inputs_iter(self) -> Iterable[CompositionClause]:
|
||||
"""
|
||||
Return an iterator over the possible inputs of the wire input.
|
||||
"""
|
||||
|
||||
|
||||
class Output(NamedTuple):
|
||||
"""
|
||||
The output of a given function of a module.
|
||||
"""
|
||||
|
||||
func: FunctionDef
|
||||
pos: int
|
||||
|
||||
def get_outputs_iter(self) -> Iterable[CompositionClause]:
|
||||
"""
|
||||
Return an iterator over the possible outputs of the wire output.
|
||||
"""
|
||||
return [CompositionClause(self.func.name, self.pos)]
|
||||
|
||||
|
||||
class AllOutputs(NamedTuple):
|
||||
"""
|
||||
All the encrypted outputs of a given function of a module.
|
||||
"""
|
||||
|
||||
func: FunctionDef
|
||||
|
||||
def get_outputs_iter(self) -> Iterable[CompositionClause]:
|
||||
"""
|
||||
Return an iterator over the possible outputs of the wire output.
|
||||
"""
|
||||
assert self.func.graph # pragma: no cover
|
||||
# No need to filter since only encrypted outputs are valid.
|
||||
return map( # pragma: no cover
|
||||
CompositionClause.create,
|
||||
zip(repeat(self.func.name), range(self.func.graph.outputs_count)),
|
||||
)
|
||||
|
||||
|
||||
class Input(NamedTuple):
|
||||
"""
|
||||
The input of a given function of a module.
|
||||
"""
|
||||
|
||||
func: FunctionDef
|
||||
pos: int
|
||||
|
||||
def get_inputs_iter(self) -> Iterable[CompositionClause]:
|
||||
"""
|
||||
Return an iterator over the possible inputs of the wire input.
|
||||
"""
|
||||
return [CompositionClause(self.func.name, self.pos)]
|
||||
|
||||
|
||||
class AllInputs(NamedTuple):
|
||||
"""
|
||||
All the encrypted inputs of a given function of a module.
|
||||
"""
|
||||
|
||||
func: FunctionDef
|
||||
|
||||
def get_inputs_iter(self) -> Iterable[CompositionClause]:
|
||||
"""
|
||||
Return an iterator over the possible inputs of the wire input.
|
||||
"""
|
||||
assert self.func.graph # pragma: no cover
|
||||
output = []
|
||||
for i in range(self.func.graph.inputs_count):
|
||||
if self.func.graph.input_nodes[i].output.is_encrypted:
|
||||
output.append(CompositionClause.create((self.func.name, i)))
|
||||
return output
|
||||
|
||||
|
||||
class Wire(NamedTuple):
|
||||
"""
|
||||
A forwarding rule between an output and an input.
|
||||
"""
|
||||
|
||||
output: WireOutput
|
||||
input: WireInput
|
||||
|
||||
def get_rules_iter(self, _) -> Iterable[CompositionRule]:
|
||||
"""
|
||||
Return an iterator over composition rules.
|
||||
"""
|
||||
return map(
|
||||
CompositionRule.create,
|
||||
product(self.output.get_outputs_iter(), self.input.get_inputs_iter()),
|
||||
)
|
||||
|
||||
|
||||
class Wired:
|
||||
"""
|
||||
Composition policy which allows the forwarding of certain outputs to certain inputs.
|
||||
"""
|
||||
|
||||
wires: Set[Wire]
|
||||
|
||||
def __init__(self, wires: Optional[Set[Wire]] = None):
|
||||
self.wires = wires if wires else set()
|
||||
|
||||
def get_rules_iter(self, funcs: List[Graph]) -> Iterable[CompositionRule]:
|
||||
"""
|
||||
Return an iterator over composition rules.
|
||||
"""
|
||||
funcsd = {f.name: f for f in funcs}
|
||||
rules = list(chain(*[w.get_rules_iter(funcs) for w in self.wires]))
|
||||
|
||||
# We check that the given rules are legit (they concern only encrypted values)
|
||||
for rule in rules:
|
||||
if (
|
||||
not funcsd[rule.from_.func].output_nodes[rule.from_.pos].output.is_encrypted
|
||||
): # pragma: no cover
|
||||
message = f"Invalid composition rule encountered: \
|
||||
Output {rule.from_.pos} of {rule.from_.func} is not encrypted"
|
||||
raise RuntimeError(message)
|
||||
if not funcsd[rule.to.func].input_nodes[rule.to.pos].output.is_encrypted:
|
||||
message = f"Invalid composition rule encountered: \
|
||||
Input {rule.from_.pos} of {rule.from_.func} is not encrypted"
|
||||
raise RuntimeError(message)
|
||||
|
||||
return rules
|
||||
|
||||
|
||||
class DebugManager:
|
||||
"""
|
||||
A debug manager, allowing streamlined debugging.
|
||||
"""
|
||||
|
||||
configuration: Configuration
|
||||
begin_call: Callable
|
||||
|
||||
def __init__(self, config: Configuration):
|
||||
self.configuration = config
|
||||
is_first = [True]
|
||||
|
||||
def begin_call():
|
||||
if is_first[0]:
|
||||
print()
|
||||
is_first[0] = False
|
||||
|
||||
self.begin_call = begin_call
|
||||
|
||||
def debug_table(self, title: str, activate: bool = True):
|
||||
"""
|
||||
Return a context manager that prints a table around what is printed inside the scope.
|
||||
"""
|
||||
|
||||
# pylint: disable=missing-class-docstring
|
||||
class DebugTableCm:
|
||||
def __init__(self, title):
|
||||
self.title = title
|
||||
self.columns = get_terminal_size()
|
||||
|
||||
def __enter__(self):
|
||||
print(f"{self.title}")
|
||||
print("-" * self.columns)
|
||||
|
||||
def __exit__(self, _exc_type, _exc_value, _exc_tb):
|
||||
print("-" * self.columns)
|
||||
print()
|
||||
|
||||
class EmptyCm:
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, _exc_type, _exc_value, _exc_tb):
|
||||
pass
|
||||
|
||||
if activate:
|
||||
self.begin_call()
|
||||
return DebugTableCm(title)
|
||||
return EmptyCm()
|
||||
|
||||
def show_graph(self) -> bool:
|
||||
"""
|
||||
Tell if the configuration involves showing graph.
|
||||
"""
|
||||
|
||||
return (
|
||||
self.configuration.show_graph
|
||||
if self.configuration.show_graph is not None
|
||||
else self.configuration.verbose
|
||||
)
|
||||
|
||||
def show_bit_width_constraints(self) -> bool:
|
||||
"""
|
||||
Tell if the configuration involves showing bitwidth constraints.
|
||||
"""
|
||||
|
||||
return (
|
||||
self.configuration.show_bit_width_constraints
|
||||
if self.configuration.show_bit_width_constraints is not None
|
||||
else self.configuration.verbose
|
||||
)
|
||||
|
||||
def show_bit_width_assignments(self) -> bool:
|
||||
"""
|
||||
Tell if the configuration involves showing bitwidth assignments.
|
||||
"""
|
||||
|
||||
return (
|
||||
self.configuration.show_bit_width_assignments
|
||||
if self.configuration.show_bit_width_assignments is not None
|
||||
else self.configuration.verbose
|
||||
)
|
||||
|
||||
def show_assigned_graph(self) -> bool:
|
||||
"""
|
||||
Tell if the configuration involves showing assigned graph.
|
||||
"""
|
||||
|
||||
return (
|
||||
self.configuration.show_assigned_graph
|
||||
if self.configuration.show_assigned_graph is not None
|
||||
else self.configuration.verbose
|
||||
)
|
||||
|
||||
def show_mlir(self) -> bool:
|
||||
"""
|
||||
Tell if the configuration involves showing mlir.
|
||||
"""
|
||||
|
||||
return (
|
||||
self.configuration.show_mlir
|
||||
if self.configuration.show_mlir is not None
|
||||
else self.configuration.verbose
|
||||
)
|
||||
|
||||
def show_optimizer(self) -> bool:
|
||||
"""
|
||||
Tell if the configuration involves showing optimizer.
|
||||
"""
|
||||
|
||||
return (
|
||||
self.configuration.show_optimizer
|
||||
if self.configuration.show_optimizer is not None
|
||||
else self.configuration.verbose
|
||||
)
|
||||
|
||||
def show_statistics(self) -> bool:
|
||||
"""
|
||||
Tell if the configuration involves showing statistics.
|
||||
"""
|
||||
|
||||
return (
|
||||
self.configuration.show_statistics
|
||||
if self.configuration.show_statistics is not None
|
||||
else self.configuration.verbose
|
||||
)
|
||||
|
||||
def debug_computation_graph(self, name, function_graph):
|
||||
"""
|
||||
Print computation graph if configuration tells so.
|
||||
"""
|
||||
|
||||
if (
|
||||
self.show_graph()
|
||||
or self.show_bit_width_constraints()
|
||||
or self.show_bit_width_assignments()
|
||||
or self.show_assigned_graph()
|
||||
or self.show_mlir()
|
||||
or self.show_optimizer()
|
||||
or self.show_statistics()
|
||||
):
|
||||
if self.show_graph():
|
||||
with self.debug_table(f"Computation Graph for {name}"):
|
||||
print(function_graph.format())
|
||||
|
||||
def debug_bit_width_constaints(self, name, function_graph):
|
||||
"""
|
||||
Print bitwidth constraints if configuration tells so.
|
||||
"""
|
||||
|
||||
if self.show_bit_width_constraints():
|
||||
with self.debug_table(f"Bit-Width Constraints for {name}"):
|
||||
print(function_graph.format_bit_width_constraints())
|
||||
|
||||
def debug_bit_width_assignments(self, name, function_graph):
|
||||
"""
|
||||
Print bitwidth assignments if configuration tells so.
|
||||
"""
|
||||
|
||||
if self.show_bit_width_assignments():
|
||||
with self.debug_table(f"Bit-Width Assignments for {name}"):
|
||||
print(function_graph.format_bit_width_assignments())
|
||||
|
||||
def debug_assigned_graph(self, name, function_graph):
|
||||
"""
|
||||
Print assigned graphs if configuration tells so.
|
||||
"""
|
||||
|
||||
if self.show_assigned_graph():
|
||||
with self.debug_table(f"Bit-Width Assigned Computation Graph for {name}"):
|
||||
print(function_graph.format(show_assigned_bit_widths=True))
|
||||
|
||||
def debug_mlir(self, mlir_str):
|
||||
"""
|
||||
Print mlir if configuration tells so.
|
||||
"""
|
||||
|
||||
if self.show_mlir():
|
||||
with self.debug_table("MLIR"):
|
||||
print(mlir_str)
|
||||
|
||||
def debug_statistics(self, module):
|
||||
"""
|
||||
Print statistics if configuration tells so.
|
||||
"""
|
||||
|
||||
if self.show_statistics():
|
||||
|
||||
def pretty(d, indent=0): # pragma: no cover
|
||||
if indent > 0:
|
||||
print("{")
|
||||
|
||||
for key, value in d.items():
|
||||
if isinstance(value, dict) and len(value) == 0:
|
||||
continue
|
||||
print(" " * indent + str(key) + ": ", end="")
|
||||
if isinstance(value, dict):
|
||||
pretty(value, indent + 1)
|
||||
else:
|
||||
print(value)
|
||||
if indent > 0:
|
||||
print(" " * (indent - 1) + "}")
|
||||
|
||||
with self.debug_table("Statistics"):
|
||||
pretty(module.statistics)
|
||||
|
||||
|
||||
class TracedOutput(NamedTuple):
|
||||
"""
|
||||
A wrapper type used to trace wiring.
|
||||
|
||||
Allows to tag an output value coming from an other module function, and binds it with
|
||||
information about its origin.
|
||||
"""
|
||||
|
||||
output_info: Output
|
||||
returned_value: Any
|
||||
|
||||
|
||||
class WireTracingContextManager:
|
||||
"""
|
||||
A context manager returned by the `wire_pipeline` method.
|
||||
|
||||
Activates wire tracing and yields an inputset that can be iterated on for tracing.
|
||||
"""
|
||||
|
||||
def __init__(self, module, inputset):
|
||||
self.module = module
|
||||
self.inputset = inputset
|
||||
|
||||
def __enter__(self):
|
||||
for func in self.module.functions.values():
|
||||
func._trace_wires = self.module.composition.wires
|
||||
return self.inputset
|
||||
|
||||
def __exit__(self, _exc_type, _exc_value, _exc_tb):
|
||||
for func in self.module.functions.values():
|
||||
func._trace_wires = None
|
||||
|
||||
|
||||
class ModuleCompiler:
|
||||
"""
|
||||
Compiler class for multiple functions, to glue the compilation pipeline.
|
||||
@@ -766,7 +355,9 @@ class ModuleCompiler:
|
||||
|
||||
def compile(
|
||||
self,
|
||||
inputsets: Optional[Dict[str, Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]] = None,
|
||||
inputsets: Optional[
|
||||
Dict[str, Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]]
|
||||
] = None,
|
||||
configuration: Optional[Configuration] = None,
|
||||
module_artifacts: Optional[ModuleDebugArtifacts] = None,
|
||||
**kwargs,
|
||||
@@ -793,7 +384,6 @@ class ModuleCompiler:
|
||||
"""
|
||||
|
||||
configuration = configuration if configuration is not None else self.default_configuration
|
||||
configuration = deepcopy(configuration)
|
||||
if len(kwargs) != 0:
|
||||
configuration = configuration.fork(**kwargs)
|
||||
|
||||
|
||||
@@ -401,7 +401,7 @@ class Server:
|
||||
self,
|
||||
*args: Optional[Union[Value, Tuple[Optional[Value], ...]]],
|
||||
evaluation_keys: Optional[EvaluationKeys] = None,
|
||||
function_name: str = "main",
|
||||
function_name: Optional[str] = None,
|
||||
) -> Union[Value, Tuple[Value, ...]]:
|
||||
"""
|
||||
Evaluate.
|
||||
@@ -421,6 +421,15 @@ class Server:
|
||||
result(s) of evaluation
|
||||
"""
|
||||
|
||||
if function_name is None:
|
||||
functions = self.client_specs.client_parameters.function_list()
|
||||
if len(functions) == 1:
|
||||
function_name = functions[0]
|
||||
else: # pragma: no cover
|
||||
msg = "The client contains more than one functions. \
|
||||
Provide a `function_name` keyword argument to disambiguate."
|
||||
raise TypeError(msg)
|
||||
|
||||
if evaluation_keys is None and not self.is_simulated:
|
||||
message = "Expected evaluation keys to be provided when not in simulation mode"
|
||||
raise RuntimeError(message)
|
||||
@@ -527,19 +536,19 @@ class Server:
|
||||
"""
|
||||
return self._compilation_feedback.complexity
|
||||
|
||||
def memory_usage_per_location(self, function: str = "main") -> Dict[str, int]:
|
||||
def memory_usage_per_location(self, function: str) -> Dict[str, int]:
|
||||
"""
|
||||
Get the memory usage of operations per location.
|
||||
"""
|
||||
return self._compilation_feedback.circuit(function).memory_usage_per_location
|
||||
|
||||
def size_of_inputs(self, function: str = "main") -> int:
|
||||
def size_of_inputs(self, function: str) -> int:
|
||||
"""
|
||||
Get size of the inputs of the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.circuit(function).total_inputs_size
|
||||
|
||||
def size_of_outputs(self, function: str = "main") -> int:
|
||||
def size_of_outputs(self, function: str) -> int:
|
||||
"""
|
||||
Get size of the outputs of the compiled program.
|
||||
"""
|
||||
@@ -547,7 +556,7 @@ class Server:
|
||||
|
||||
# Programmable Bootstrap Statistics
|
||||
|
||||
def programmable_bootstrap_count(self, function: str = "main") -> int:
|
||||
def programmable_bootstrap_count(self, function: str) -> int:
|
||||
"""
|
||||
Get the number of programmable bootstraps in the compiled program.
|
||||
"""
|
||||
@@ -555,9 +564,7 @@ class Server:
|
||||
operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS},
|
||||
)
|
||||
|
||||
def programmable_bootstrap_count_per_parameter(
|
||||
self, function: str = "main"
|
||||
) -> Dict[Parameter, int]:
|
||||
def programmable_bootstrap_count_per_parameter(self, function: str) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of programmable bootstraps per parameter in the compiled program.
|
||||
"""
|
||||
@@ -567,7 +574,7 @@ class Server:
|
||||
client_parameters=self.client_specs.client_parameters,
|
||||
)
|
||||
|
||||
def programmable_bootstrap_count_per_tag(self, function: str = "main") -> Dict[str, int]:
|
||||
def programmable_bootstrap_count_per_tag(self, function: str) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of programmable bootstraps per tag in the compiled program.
|
||||
"""
|
||||
@@ -576,7 +583,7 @@ class Server:
|
||||
)
|
||||
|
||||
def programmable_bootstrap_count_per_tag_per_parameter(
|
||||
self, function: str = "main"
|
||||
self, function: str
|
||||
) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of programmable bootstraps per tag per parameter in the compiled program.
|
||||
@@ -589,7 +596,7 @@ class Server:
|
||||
|
||||
# Key Switch Statistics
|
||||
|
||||
def key_switch_count(self, function: str = "main") -> int:
|
||||
def key_switch_count(self, function: str) -> int:
|
||||
"""
|
||||
Get the number of key switches in the compiled program.
|
||||
"""
|
||||
@@ -597,7 +604,7 @@ class Server:
|
||||
operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS},
|
||||
)
|
||||
|
||||
def key_switch_count_per_parameter(self, function: str = "main") -> Dict[Parameter, int]:
|
||||
def key_switch_count_per_parameter(self, function: str) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of key switches per parameter in the compiled program.
|
||||
"""
|
||||
@@ -607,7 +614,7 @@ class Server:
|
||||
client_parameters=self.client_specs.client_parameters,
|
||||
)
|
||||
|
||||
def key_switch_count_per_tag(self, function: str = "main") -> Dict[str, int]:
|
||||
def key_switch_count_per_tag(self, function: str) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of key switches per tag in the compiled program.
|
||||
"""
|
||||
@@ -616,7 +623,7 @@ class Server:
|
||||
)
|
||||
|
||||
def key_switch_count_per_tag_per_parameter(
|
||||
self, function: str = "main"
|
||||
self, function: str
|
||||
) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of key switches per tag per parameter in the compiled program.
|
||||
@@ -629,7 +636,7 @@ class Server:
|
||||
|
||||
# Packing Key Switch Statistics
|
||||
|
||||
def packing_key_switch_count(self, function: str = "main") -> int:
|
||||
def packing_key_switch_count(self, function: str) -> int:
|
||||
"""
|
||||
Get the number of packing key switches in the compiled program.
|
||||
"""
|
||||
@@ -637,9 +644,7 @@ class Server:
|
||||
operations={PrimitiveOperation.WOP_PBS}
|
||||
)
|
||||
|
||||
def packing_key_switch_count_per_parameter(
|
||||
self, function: str = "main"
|
||||
) -> Dict[Parameter, int]:
|
||||
def packing_key_switch_count_per_parameter(self, function: str) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of packing key switches per parameter in the compiled program.
|
||||
"""
|
||||
@@ -649,7 +654,7 @@ class Server:
|
||||
client_parameters=self.client_specs.client_parameters,
|
||||
)
|
||||
|
||||
def packing_key_switch_count_per_tag(self, function: str = "main") -> Dict[str, int]:
|
||||
def packing_key_switch_count_per_tag(self, function: str) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of packing key switches per tag in the compiled program.
|
||||
"""
|
||||
@@ -658,7 +663,7 @@ class Server:
|
||||
)
|
||||
|
||||
def packing_key_switch_count_per_tag_per_parameter(
|
||||
self, function: str = "main"
|
||||
self, function: str
|
||||
) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of packing key switches per tag per parameter in the compiled program.
|
||||
@@ -671,7 +676,7 @@ class Server:
|
||||
|
||||
# Clear Addition Statistics
|
||||
|
||||
def clear_addition_count(self, function: str = "main") -> int:
|
||||
def clear_addition_count(self, function: str) -> int:
|
||||
"""
|
||||
Get the number of clear additions in the compiled program.
|
||||
"""
|
||||
@@ -679,7 +684,7 @@ class Server:
|
||||
operations={PrimitiveOperation.CLEAR_ADDITION}
|
||||
)
|
||||
|
||||
def clear_addition_count_per_parameter(self, function: str = "main") -> Dict[Parameter, int]:
|
||||
def clear_addition_count_per_parameter(self, function: str) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of clear additions per parameter in the compiled program.
|
||||
"""
|
||||
@@ -689,7 +694,7 @@ class Server:
|
||||
client_parameters=self.client_specs.client_parameters,
|
||||
)
|
||||
|
||||
def clear_addition_count_per_tag(self, function: str = "main") -> Dict[str, int]:
|
||||
def clear_addition_count_per_tag(self, function: str) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of clear additions per tag in the compiled program.
|
||||
"""
|
||||
@@ -698,7 +703,7 @@ class Server:
|
||||
)
|
||||
|
||||
def clear_addition_count_per_tag_per_parameter(
|
||||
self, function: str = "main"
|
||||
self, function: str
|
||||
) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of clear additions per tag per parameter in the compiled program.
|
||||
@@ -711,7 +716,7 @@ class Server:
|
||||
|
||||
# Encrypted Addition Statistics
|
||||
|
||||
def encrypted_addition_count(self, function: str = "main") -> int:
|
||||
def encrypted_addition_count(self, function: str) -> int:
|
||||
"""
|
||||
Get the number of encrypted additions in the compiled program.
|
||||
"""
|
||||
@@ -719,9 +724,7 @@ class Server:
|
||||
operations={PrimitiveOperation.ENCRYPTED_ADDITION}
|
||||
)
|
||||
|
||||
def encrypted_addition_count_per_parameter(
|
||||
self, function: str = "main"
|
||||
) -> Dict[Parameter, int]:
|
||||
def encrypted_addition_count_per_parameter(self, function: str) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of encrypted additions per parameter in the compiled program.
|
||||
"""
|
||||
@@ -731,7 +734,7 @@ class Server:
|
||||
client_parameters=self.client_specs.client_parameters,
|
||||
)
|
||||
|
||||
def encrypted_addition_count_per_tag(self, function: str = "main") -> Dict[str, int]:
|
||||
def encrypted_addition_count_per_tag(self, function: str) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of encrypted additions per tag in the compiled program.
|
||||
"""
|
||||
@@ -740,7 +743,7 @@ class Server:
|
||||
)
|
||||
|
||||
def encrypted_addition_count_per_tag_per_parameter(
|
||||
self, function: str = "main"
|
||||
self, function: str
|
||||
) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of encrypted additions per tag per parameter in the compiled program.
|
||||
@@ -753,7 +756,7 @@ class Server:
|
||||
|
||||
# Clear Multiplication Statistics
|
||||
|
||||
def clear_multiplication_count(self, function: str = "main") -> int:
|
||||
def clear_multiplication_count(self, function: str) -> int:
|
||||
"""
|
||||
Get the number of clear multiplications in the compiled program.
|
||||
"""
|
||||
@@ -761,9 +764,7 @@ class Server:
|
||||
operations={PrimitiveOperation.CLEAR_MULTIPLICATION},
|
||||
)
|
||||
|
||||
def clear_multiplication_count_per_parameter(
|
||||
self, function: str = "main"
|
||||
) -> Dict[Parameter, int]:
|
||||
def clear_multiplication_count_per_parameter(self, function: str) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of clear multiplications per parameter in the compiled program.
|
||||
"""
|
||||
@@ -773,7 +774,7 @@ class Server:
|
||||
client_parameters=self.client_specs.client_parameters,
|
||||
)
|
||||
|
||||
def clear_multiplication_count_per_tag(self, function: str = "main") -> Dict[str, int]:
|
||||
def clear_multiplication_count_per_tag(self, function: str) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of clear multiplications per tag in the compiled program.
|
||||
"""
|
||||
@@ -782,7 +783,7 @@ class Server:
|
||||
)
|
||||
|
||||
def clear_multiplication_count_per_tag_per_parameter(
|
||||
self, function: str = "main"
|
||||
self, function: str
|
||||
) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of clear multiplications per tag per parameter in the compiled program.
|
||||
@@ -795,7 +796,7 @@ class Server:
|
||||
|
||||
# Encrypted Negation Statistics
|
||||
|
||||
def encrypted_negation_count(self, function: str = "main") -> int:
|
||||
def encrypted_negation_count(self, function: str) -> int:
|
||||
"""
|
||||
Get the number of encrypted negations in the compiled program.
|
||||
"""
|
||||
@@ -803,9 +804,7 @@ class Server:
|
||||
operations={PrimitiveOperation.ENCRYPTED_NEGATION}
|
||||
)
|
||||
|
||||
def encrypted_negation_count_per_parameter(
|
||||
self, function: str = "main"
|
||||
) -> Dict[Parameter, int]:
|
||||
def encrypted_negation_count_per_parameter(self, function: str) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of encrypted negations per parameter in the compiled program.
|
||||
"""
|
||||
@@ -815,7 +814,7 @@ class Server:
|
||||
client_parameters=self.client_specs.client_parameters,
|
||||
)
|
||||
|
||||
def encrypted_negation_count_per_tag(self, function: str = "main") -> Dict[str, int]:
|
||||
def encrypted_negation_count_per_tag(self, function: str) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of encrypted negations per tag in the compiled program.
|
||||
"""
|
||||
@@ -824,7 +823,7 @@ class Server:
|
||||
)
|
||||
|
||||
def encrypted_negation_count_per_tag_per_parameter(
|
||||
self, function: str = "main"
|
||||
self, function: str
|
||||
) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of encrypted negations per tag per parameter in the compiled program.
|
||||
@@ -834,52 +833,3 @@ class Server:
|
||||
key_types={KeyType.SECRET},
|
||||
client_parameters=self.client_specs.client_parameters,
|
||||
)
|
||||
|
||||
# All Statistics
|
||||
|
||||
@property
|
||||
def statistics(self) -> Dict:
|
||||
"""
|
||||
Get all statistics of the compiled program.
|
||||
"""
|
||||
attributes = [
|
||||
"size_of_inputs",
|
||||
"size_of_outputs",
|
||||
"programmable_bootstrap_count",
|
||||
"programmable_bootstrap_count_per_parameter",
|
||||
"programmable_bootstrap_count_per_tag",
|
||||
"programmable_bootstrap_count_per_tag_per_parameter",
|
||||
"key_switch_count",
|
||||
"key_switch_count_per_parameter",
|
||||
"key_switch_count_per_tag",
|
||||
"key_switch_count_per_tag_per_parameter",
|
||||
"packing_key_switch_count",
|
||||
"packing_key_switch_count_per_parameter",
|
||||
"packing_key_switch_count_per_tag",
|
||||
"packing_key_switch_count_per_tag_per_parameter",
|
||||
"clear_addition_count",
|
||||
"clear_addition_count_per_parameter",
|
||||
"clear_addition_count_per_tag",
|
||||
"clear_addition_count_per_tag_per_parameter",
|
||||
"encrypted_addition_count",
|
||||
"encrypted_addition_count_per_parameter",
|
||||
"encrypted_addition_count_per_tag",
|
||||
"encrypted_addition_count_per_tag_per_parameter",
|
||||
"clear_multiplication_count",
|
||||
"clear_multiplication_count_per_parameter",
|
||||
"clear_multiplication_count_per_tag",
|
||||
"clear_multiplication_count_per_tag_per_parameter",
|
||||
"encrypted_negation_count",
|
||||
"encrypted_negation_count_per_parameter",
|
||||
"encrypted_negation_count_per_tag",
|
||||
"encrypted_negation_count_per_tag_per_parameter",
|
||||
"memory_usage_per_location",
|
||||
]
|
||||
output = {attribute: getattr(self, attribute)() for attribute in attributes}
|
||||
output["size_of_secret_keys"] = self.size_of_secret_keys
|
||||
output["size_of_bootstrap_keys"] = self.size_of_bootstrap_keys
|
||||
output["size_of_keyswitch_keys"] = self.size_of_keyswitch_keys
|
||||
output["p_error"] = self.p_error
|
||||
output["global_p_error"] = self.global_p_error
|
||||
output["complexity"] = self.complexity
|
||||
return output
|
||||
|
||||
17
frontends/concrete-python/concrete/fhe/compilation/status.py
Normal file
17
frontends/concrete-python/concrete/fhe/compilation/status.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Declaration of `EncryptionStatus` class.
|
||||
"""
|
||||
|
||||
# pylint: disable=import-error,no-name-in-module
|
||||
|
||||
from enum import Enum, unique
|
||||
|
||||
|
||||
@unique
|
||||
class EncryptionStatus(str, Enum):
|
||||
"""
|
||||
EncryptionStatus enum, to represent encryption status of parameters.
|
||||
"""
|
||||
|
||||
CLEAR = "clear"
|
||||
ENCRYPTED = "encrypted"
|
||||
@@ -7,6 +7,7 @@ import os
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
@@ -27,9 +28,11 @@ from ..dtypes import Float, Integer, SignedInteger, UnsignedInteger
|
||||
from ..representation import Graph, Node, Operation
|
||||
from ..tracing import ScalarAnnotation
|
||||
from ..values import ValueDescription
|
||||
from .artifacts import FunctionDebugArtifacts
|
||||
from .specs import ClientSpecs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .artifacts import FunctionDebugArtifacts # pragma: no cover
|
||||
|
||||
# ruff: noqa: ERA001
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -118,7 +121,7 @@ def inputset(
|
||||
def validate_input_args(
|
||||
client_specs: ClientSpecs,
|
||||
*args: Optional[Union[int, np.ndarray, List]],
|
||||
function_name: str = "main",
|
||||
function_name: str,
|
||||
) -> List[Optional[Union[int, np.ndarray]]]:
|
||||
"""Validate input arguments.
|
||||
|
||||
@@ -214,7 +217,7 @@ def validate_input_args(
|
||||
return ordered_sanitized_args
|
||||
|
||||
|
||||
def fuse(graph: Graph, artifacts: Optional[FunctionDebugArtifacts] = None):
|
||||
def fuse(graph: Graph, artifacts: Optional["FunctionDebugArtifacts"] = None):
|
||||
"""
|
||||
Fuse appropriate subgraphs in a graph to a single Operation.Generic node.
|
||||
|
||||
@@ -817,7 +820,7 @@ def convert_subgraph_to_subgraph_node(
|
||||
original_tag = terminal_node.tag
|
||||
original_created_at = terminal_node.created_at
|
||||
|
||||
subgraph = Graph(nx_subgraph, {0: subgraph_variable_input_node}, {0: terminal_node})
|
||||
subgraph = Graph(nx_subgraph, {0: subgraph_variable_input_node}, {0: terminal_node}, graph.name)
|
||||
subgraph_node = Node.generic(
|
||||
"subgraph",
|
||||
deepcopy(subgraph_variable_input_node.inputs),
|
||||
|
||||
235
frontends/concrete-python/concrete/fhe/compilation/wiring.py
Normal file
235
frontends/concrete-python/concrete/fhe/compilation/wiring.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
Declaration of wiring related class.
|
||||
"""
|
||||
|
||||
# pylint: disable=import-error,no-name-in-module
|
||||
|
||||
from itertools import chain, product, repeat
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Iterable,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Protocol,
|
||||
Set,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from ..representation import Graph
|
||||
from .composition import CompositionClause, CompositionRule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .module_compiler import FunctionDef # pragma: no cover
|
||||
|
||||
|
||||
class NotComposable:
|
||||
"""
|
||||
Composition policy that does not allow the forwarding of any output to any input.
|
||||
"""
|
||||
|
||||
def get_rules_iter(self, _funcs: List["FunctionDef"]) -> Iterable[CompositionRule]:
|
||||
"""
|
||||
Return an iterator over composition rules.
|
||||
"""
|
||||
return [] # pragma: no cover
|
||||
|
||||
|
||||
class AllComposable:
|
||||
"""
|
||||
Composition policy that allows to forward any output of the module to any of its input.
|
||||
"""
|
||||
|
||||
def get_rules_iter(self, funcs: List[Graph]) -> Iterable[CompositionRule]:
|
||||
"""
|
||||
Return an iterator over composition rules.
|
||||
"""
|
||||
outputs = []
|
||||
for f in funcs:
|
||||
for pos, node in f.output_nodes.items():
|
||||
if node.output.is_encrypted:
|
||||
outputs.append(CompositionClause.create((f.name, pos)))
|
||||
inputs = []
|
||||
for f in funcs:
|
||||
for pos, node in f.input_nodes.items():
|
||||
if node.output.is_encrypted:
|
||||
inputs.append(CompositionClause.create((f.name, pos)))
|
||||
|
||||
return map(CompositionRule.create, product(outputs, inputs))
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class WireOutput(Protocol):
|
||||
"""
|
||||
A protocol for wire outputs.
|
||||
"""
|
||||
|
||||
def get_outputs_iter(self) -> Iterable[CompositionClause]:
|
||||
"""
|
||||
Return an iterator over the possible outputs of the wire output.
|
||||
"""
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class WireInput(Protocol):
|
||||
"""
|
||||
A protocol for wire inputs.
|
||||
"""
|
||||
|
||||
def get_inputs_iter(self) -> Iterable[CompositionClause]:
|
||||
"""
|
||||
Return an iterator over the possible inputs of the wire input.
|
||||
"""
|
||||
|
||||
|
||||
class Output(NamedTuple):
|
||||
"""
|
||||
The output of a given function of a module.
|
||||
"""
|
||||
|
||||
func: "FunctionDef"
|
||||
pos: int
|
||||
|
||||
def get_outputs_iter(self) -> Iterable[CompositionClause]:
|
||||
"""
|
||||
Return an iterator over the possible outputs of the wire output.
|
||||
"""
|
||||
return [CompositionClause(self.func.name, self.pos)]
|
||||
|
||||
|
||||
class AllOutputs(NamedTuple):
|
||||
"""
|
||||
All the encrypted outputs of a given function of a module.
|
||||
"""
|
||||
|
||||
func: "FunctionDef"
|
||||
|
||||
def get_outputs_iter(self) -> Iterable[CompositionClause]:
|
||||
"""
|
||||
Return an iterator over the possible outputs of the wire output.
|
||||
"""
|
||||
assert self.func.graph # pragma: no cover
|
||||
# No need to filter since only encrypted outputs are valid.
|
||||
return map( # pragma: no cover
|
||||
CompositionClause.create,
|
||||
zip(repeat(self.func.name), range(self.func.graph.outputs_count)),
|
||||
)
|
||||
|
||||
|
||||
class Input(NamedTuple):
|
||||
"""
|
||||
The input of a given function of a module.
|
||||
"""
|
||||
|
||||
func: "FunctionDef"
|
||||
pos: int
|
||||
|
||||
def get_inputs_iter(self) -> Iterable[CompositionClause]:
|
||||
"""
|
||||
Return an iterator over the possible inputs of the wire input.
|
||||
"""
|
||||
return [CompositionClause(self.func.name, self.pos)]
|
||||
|
||||
|
||||
class AllInputs(NamedTuple):
|
||||
"""
|
||||
All the encrypted inputs of a given function of a module.
|
||||
"""
|
||||
|
||||
func: "FunctionDef"
|
||||
|
||||
def get_inputs_iter(self) -> Iterable[CompositionClause]:
|
||||
"""
|
||||
Return an iterator over the possible inputs of the wire input.
|
||||
"""
|
||||
assert self.func.graph # pragma: no cover
|
||||
output = []
|
||||
for i in range(self.func.graph.inputs_count):
|
||||
if self.func.graph.input_nodes[i].output.is_encrypted:
|
||||
output.append(CompositionClause.create((self.func.name, i)))
|
||||
return output
|
||||
|
||||
|
||||
class Wire(NamedTuple):
|
||||
"""
|
||||
A forwarding rule between an output and an input.
|
||||
"""
|
||||
|
||||
output: WireOutput
|
||||
input: WireInput
|
||||
|
||||
def get_rules_iter(self, _) -> Iterable[CompositionRule]:
|
||||
"""
|
||||
Return an iterator over composition rules.
|
||||
"""
|
||||
return map(
|
||||
CompositionRule.create,
|
||||
product(self.output.get_outputs_iter(), self.input.get_inputs_iter()),
|
||||
)
|
||||
|
||||
|
||||
class Wired:
|
||||
"""
|
||||
Composition policy which allows the forwarding of certain outputs to certain inputs.
|
||||
"""
|
||||
|
||||
wires: Set[Wire]
|
||||
|
||||
def __init__(self, wires: Optional[Set[Wire]] = None):
|
||||
self.wires = wires if wires else set()
|
||||
|
||||
def get_rules_iter(self, funcs: List[Graph]) -> Iterable[CompositionRule]:
|
||||
"""
|
||||
Return an iterator over composition rules.
|
||||
"""
|
||||
funcsd = {f.name: f for f in funcs}
|
||||
rules = list(chain(*[w.get_rules_iter(funcs) for w in self.wires]))
|
||||
|
||||
# We check that the given rules are legit (they concern only encrypted values)
|
||||
for rule in rules:
|
||||
if (
|
||||
not funcsd[rule.from_.func].output_nodes[rule.from_.pos].output.is_encrypted
|
||||
): # pragma: no cover
|
||||
message = f"Invalid composition rule encountered: \
|
||||
Output {rule.from_.pos} of {rule.from_.func} is not encrypted"
|
||||
raise RuntimeError(message)
|
||||
if not funcsd[rule.to.func].input_nodes[rule.to.pos].output.is_encrypted:
|
||||
message = f"Invalid composition rule encountered: \
|
||||
Input {rule.from_.pos} of {rule.from_.func} is not encrypted"
|
||||
raise RuntimeError(message)
|
||||
|
||||
return rules
|
||||
|
||||
|
||||
class TracedOutput(NamedTuple):
|
||||
"""
|
||||
A wrapper type used to trace wiring.
|
||||
|
||||
Allows to tag an output value coming from an other module function, and binds it with
|
||||
information about its origin.
|
||||
"""
|
||||
|
||||
output_info: Output
|
||||
returned_value: Any
|
||||
|
||||
|
||||
class WireTracingContextManager:
|
||||
"""
|
||||
A context manager returned by the `wire_pipeline` method.
|
||||
|
||||
Activates wire tracing and yields an inputset that can be iterated on for tracing.
|
||||
"""
|
||||
|
||||
def __init__(self, module, inputset):
|
||||
self.module = module
|
||||
self.inputset = inputset
|
||||
|
||||
def __enter__(self):
|
||||
for func in self.module.functions.values():
|
||||
func._trace_wires = self.module.composition.wires
|
||||
return self.inputset
|
||||
|
||||
def __exit__(self, _exc_type, _exc_value, _exc_tb):
|
||||
for func in self.module.functions.values():
|
||||
func._trace_wires = None
|
||||
@@ -145,7 +145,6 @@ class Converter:
|
||||
self,
|
||||
graph: Graph,
|
||||
mlir_context: MlirContext,
|
||||
name: str = "main",
|
||||
) -> MlirModule:
|
||||
"""
|
||||
Convert a computation graph to MLIR.
|
||||
@@ -157,15 +156,13 @@ class Converter:
|
||||
mlir_context (MlirContext):
|
||||
MLIR Context to use for module generation
|
||||
|
||||
name (str):
|
||||
name of the function to convert
|
||||
|
||||
Return:
|
||||
MlirModule:
|
||||
In-memory MLIR module corresponding to the graph
|
||||
"""
|
||||
|
||||
return self.convert_many({name: graph}, mlir_context)
|
||||
return self.convert_many({graph.name: graph}, mlir_context)
|
||||
|
||||
@staticmethod
|
||||
def stdout_with_ansi_support() -> bool:
|
||||
|
||||
@@ -49,8 +49,8 @@ class Graph:
|
||||
graph: nx.MultiDiGraph,
|
||||
input_nodes: Dict[int, Node],
|
||||
output_nodes: Dict[int, Node],
|
||||
name: str,
|
||||
is_direct: bool = False,
|
||||
name: str = "main",
|
||||
location: str = "",
|
||||
):
|
||||
self.graph = graph
|
||||
@@ -1057,4 +1057,4 @@ class MultiGraphProcessor(GraphProcessor):
|
||||
"""
|
||||
Process a single graph.
|
||||
"""
|
||||
return self.apply_many({"main": graph}) # pragma: no cover
|
||||
return self.apply_many({graph.name: graph}) # pragma: no cover
|
||||
|
||||
@@ -187,7 +187,7 @@ def new_bridge(
|
||||
circuit: "fhe.Circuit",
|
||||
input_types: Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]],
|
||||
output_types: Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]],
|
||||
func_name: str = "main",
|
||||
func_name,
|
||||
) -> Bridge:
|
||||
"""Create a TFHErs bridge from a circuit.
|
||||
|
||||
@@ -199,7 +199,7 @@ def new_bridge(
|
||||
output_types (Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]]): lists
|
||||
should map every output to a type, while a single element is general for all outputs.
|
||||
None means a non-tfhers type
|
||||
func_name (str, optional): name of the function to use. Defaults to "main".
|
||||
func_name (str): name of the function to use.
|
||||
|
||||
Returns:
|
||||
Bridge: TFHErs bridge
|
||||
|
||||
@@ -38,7 +38,6 @@ class Tracer:
|
||||
function: Callable,
|
||||
parameters: Dict[str, ValueDescription],
|
||||
is_direct: bool = False,
|
||||
name: str = "main",
|
||||
location: str = "",
|
||||
) -> Graph:
|
||||
"""
|
||||
@@ -55,9 +54,6 @@ class Tracer:
|
||||
is_direct (bool, default = False):
|
||||
whether the tracing is done on actual parameters or placeholders
|
||||
|
||||
name (str, default = "main"):
|
||||
the name of the function being traced
|
||||
|
||||
Returns:
|
||||
Graph:
|
||||
computation graph corresponding to `function`
|
||||
@@ -169,7 +165,9 @@ class Tracer:
|
||||
output_idx: tracer.computation for output_idx, tracer in enumerate(output_tracers)
|
||||
}
|
||||
|
||||
return Graph(graph, input_nodes, output_nodes, is_direct, name, location=location)
|
||||
return Graph(
|
||||
graph, input_nodes, output_nodes, function.__name__, is_direct, location=location
|
||||
)
|
||||
|
||||
# pylint: enable=too-many-statements
|
||||
|
||||
|
||||
@@ -35,13 +35,13 @@ def test_artifacts_export(helpers):
|
||||
assert (tmpdir / "environment.txt").exists()
|
||||
assert (tmpdir / "requirements.txt").exists()
|
||||
|
||||
assert (tmpdir / "main.txt").exists()
|
||||
assert (tmpdir / "main.parameters.txt").exists()
|
||||
assert (tmpdir / "f.txt").exists()
|
||||
assert (tmpdir / "f.parameters.txt").exists()
|
||||
|
||||
assert (tmpdir / "main.1.initial.graph.txt").exists()
|
||||
assert (tmpdir / "main.2.after-fusing.graph.txt").exists()
|
||||
assert (tmpdir / "main.3.after-fusing.graph.txt").exists()
|
||||
assert (tmpdir / "main.4.final.graph.txt").exists()
|
||||
assert (tmpdir / "f.1.initial.graph.txt").exists()
|
||||
assert (tmpdir / "f.2.after-fusing.graph.txt").exists()
|
||||
assert (tmpdir / "f.3.after-fusing.graph.txt").exists()
|
||||
assert (tmpdir / "f.4.final.graph.txt").exists()
|
||||
|
||||
assert (tmpdir / "mlir.txt").exists()
|
||||
assert (tmpdir / "client_parameters.json").exists()
|
||||
@@ -51,13 +51,13 @@ def test_artifacts_export(helpers):
|
||||
assert (tmpdir / "environment.txt").exists()
|
||||
assert (tmpdir / "requirements.txt").exists()
|
||||
|
||||
assert (tmpdir / "main.txt").exists()
|
||||
assert (tmpdir / "main.parameters.txt").exists()
|
||||
assert (tmpdir / "f.txt").exists()
|
||||
assert (tmpdir / "f.parameters.txt").exists()
|
||||
|
||||
assert (tmpdir / "main.1.initial.graph.txt").exists()
|
||||
assert (tmpdir / "main.2.after-fusing.graph.txt").exists()
|
||||
assert (tmpdir / "main.3.after-fusing.graph.txt").exists()
|
||||
assert (tmpdir / "main.4.final.graph.txt").exists()
|
||||
assert (tmpdir / "f.1.initial.graph.txt").exists()
|
||||
assert (tmpdir / "f.2.after-fusing.graph.txt").exists()
|
||||
assert (tmpdir / "f.3.after-fusing.graph.txt").exists()
|
||||
assert (tmpdir / "f.4.final.graph.txt").exists()
|
||||
|
||||
assert (tmpdir / "mlir.txt").exists()
|
||||
assert (tmpdir / "client_parameters.json").exists()
|
||||
|
||||
@@ -7,6 +7,8 @@ from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from concrete.compiler import CompilationContext
|
||||
from mlir.ir import Module as MlirModule
|
||||
|
||||
from concrete import fhe
|
||||
from concrete.fhe import Client, ClientSpecs, EvaluationKeys, LookupTable, Server, Value
|
||||
@@ -89,6 +91,8 @@ def test_circuit_feedback(helpers):
|
||||
assert isinstance(circuit.size_of_outputs, int)
|
||||
assert isinstance(circuit.p_error, float)
|
||||
assert isinstance(circuit.global_p_error, float)
|
||||
assert isinstance(circuit.mlir_module, MlirModule)
|
||||
assert isinstance(circuit.compilation_context, CompilationContext)
|
||||
|
||||
assert isinstance(circuit.memory_usage_per_location, dict)
|
||||
assert all(
|
||||
@@ -836,5 +840,4 @@ def test_simulate_encrypt_run_decrypt(helpers):
|
||||
assert isinstance(encrypted_x, int)
|
||||
assert isinstance(encrypted_y, int)
|
||||
assert hasattr(circuit, "simulator")
|
||||
assert not hasattr(circuit, "server")
|
||||
assert isinstance(encrypted_result, int)
|
||||
|
||||
@@ -450,7 +450,7 @@ def test_compiler_reset(helpers):
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> {
|
||||
func.func @f(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> {
|
||||
%0 = "FHE.add_eint"(%arg0, %arg1) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
|
||||
return %0 : !FHE.eint<4>
|
||||
}
|
||||
@@ -468,7 +468,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<11>, %arg1: !FHE.eint<11>) -> !FHE.eint<11> {
|
||||
func.func @f(%arg0: !FHE.eint<11>, %arg1: !FHE.eint<11>) -> !FHE.eint<11> {
|
||||
%0 = "FHE.add_eint"(%arg0, %arg1) : (!FHE.eint<11>, !FHE.eint<11>) -> !FHE.eint<11>
|
||||
return %0 : !FHE.eint<11>
|
||||
}
|
||||
@@ -486,7 +486,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<3x2x!FHE.eint<3>>, %arg1: tensor<2x!FHE.eint<3>>) -> tensor<3x2x!FHE.eint<3>> {
|
||||
func.func @f(%arg0: tensor<3x2x!FHE.eint<3>>, %arg1: tensor<2x!FHE.eint<3>>) -> tensor<3x2x!FHE.eint<3>> {
|
||||
%0 = "FHELinalg.add_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<3>>, tensor<2x!FHE.eint<3>>) -> tensor<3x2x!FHE.eint<3>>
|
||||
return %0 : tensor<3x2x!FHE.eint<3>>
|
||||
}
|
||||
|
||||
@@ -222,39 +222,39 @@ def test_configuration_bad_fork(kwargs, expected_error, expected_message):
|
||||
"""
|
||||
|
||||
%0:
|
||||
main.%0 >= 3
|
||||
<lambda>.%0 >= 3
|
||||
%1:
|
||||
main.%1 >= 7
|
||||
<lambda>.%1 >= 7
|
||||
%2:
|
||||
main.%2 >= 2
|
||||
<lambda>.%2 >= 2
|
||||
%3:
|
||||
main.%3 >= 5
|
||||
<lambda>.%3 >= 5
|
||||
%4:
|
||||
main.%4 >= 8
|
||||
main.%3 == main.%1
|
||||
main.%1 == main.%4
|
||||
<lambda>.%4 >= 8
|
||||
<lambda>.%3 == <lambda>.%1
|
||||
<lambda>.%1 == <lambda>.%4
|
||||
|
||||
""",
|
||||
(
|
||||
"""
|
||||
|
||||
main.%0 = 3
|
||||
main.%1 = 8
|
||||
main.%2 = 2
|
||||
main.%3 = 8
|
||||
main.%4 = 8
|
||||
main.max = 8
|
||||
<lambda>.%0 = 3
|
||||
<lambda>.%1 = 8
|
||||
<lambda>.%2 = 2
|
||||
<lambda>.%3 = 8
|
||||
<lambda>.%4 = 8
|
||||
<lambda>.max = 8
|
||||
|
||||
"""
|
||||
if USE_MULTI_PRECISION
|
||||
else """
|
||||
|
||||
main.%0 = 8
|
||||
main.%1 = 8
|
||||
main.%2 = 8
|
||||
main.%3 = 8
|
||||
main.%4 = 8
|
||||
main.max = 8
|
||||
<lambda>.%0 = 8
|
||||
<lambda>.%1 = 8
|
||||
<lambda>.%2 = 8
|
||||
<lambda>.%3 = 8
|
||||
<lambda>.%4 = 8
|
||||
<lambda>.max = 8
|
||||
|
||||
"""
|
||||
),
|
||||
@@ -275,7 +275,7 @@ def test_configuration_show_bit_width_constraints_and_assignment(
|
||||
Test compiling with configuration where show_bit_width_(constraints/assignments)=True.
|
||||
"""
|
||||
|
||||
monkeypatch.setattr("concrete.fhe.compilation.compiler.get_terminal_size", lambda: 80)
|
||||
monkeypatch.setattr("concrete.fhe.compilation.artifacts.get_terminal_size", lambda: 80)
|
||||
|
||||
configuration = helpers.configuration()
|
||||
compiler = fhe.Compiler(function, encryption_status)
|
||||
@@ -289,12 +289,12 @@ def test_configuration_show_bit_width_constraints_and_assignment(
|
||||
captured.out.strip(),
|
||||
f"""
|
||||
|
||||
Bit-Width Constraints
|
||||
Bit-Width Constraints for <lambda>
|
||||
--------------------------------------------------------------------------------
|
||||
{expected_bit_width_constraints.lstrip(os.linesep).rstrip()}
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
Bit-Width Assignments
|
||||
Bit-Width Assignments for <lambda>
|
||||
--------------------------------------------------------------------------------
|
||||
{expected_bit_width_assignment.lstrip(os.linesep).rstrip()}
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
@@ -28,42 +28,12 @@ def test_compiler_call_and_compile(helpers):
|
||||
helpers.check_execution(circuit, function, sample)
|
||||
|
||||
|
||||
def test_compiler_verbose_trace(helpers, capsys, monkeypatch):
|
||||
"""
|
||||
Test `trace` method of `compiler` decorator with verbose flag.
|
||||
"""
|
||||
|
||||
monkeypatch.setattr("concrete.fhe.compilation.compiler.get_terminal_size", lambda: 80)
|
||||
|
||||
configuration = helpers.configuration()
|
||||
artifacts = fhe.DebugArtifacts()
|
||||
|
||||
@fhe.compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
inputset = range(10)
|
||||
graph = function.trace(inputset, configuration, artifacts, show_graph=True)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out.strip() == (
|
||||
f"""
|
||||
|
||||
Computation Graph
|
||||
------------------------------------------------------------------
|
||||
{graph.format()}
|
||||
------------------------------------------------------------------
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def test_compiler_verbose_compile(helpers, capsys, monkeypatch):
|
||||
"""
|
||||
Test `compile` method of `compiler` decorator with verbose flag.
|
||||
"""
|
||||
|
||||
monkeypatch.setattr("concrete.fhe.compilation.compiler.get_terminal_size", lambda: 80)
|
||||
monkeypatch.setattr("concrete.fhe.compilation.artifacts.get_terminal_size", lambda: 80)
|
||||
|
||||
configuration = helpers.configuration()
|
||||
artifacts = fhe.DebugArtifacts()
|
||||
@@ -76,44 +46,46 @@ def test_compiler_verbose_compile(helpers, capsys, monkeypatch):
|
||||
circuit = function.compile(inputset, configuration, artifacts, verbose=True)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert captured.out.strip().startswith(
|
||||
f"""
|
||||
|
||||
Computation Graph
|
||||
|
||||
Computation Graph for function
|
||||
--------------------------------------------------------------------------------
|
||||
{circuit.graph.format()}
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
Bit-Width Constraints
|
||||
--------------------------------------------------------------------------------
|
||||
%0:
|
||||
main.%0 >= 4
|
||||
%1:
|
||||
main.%1 >= 6
|
||||
%2:
|
||||
main.%2 >= 6
|
||||
main.%0 == main.%1
|
||||
main.%1 == main.%2
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
Bit-Width Assignments
|
||||
--------------------------------------------------------------------------------
|
||||
main.%0 = 6
|
||||
main.%1 = 6
|
||||
main.%2 = 6
|
||||
main.max = 6
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
Bit-Width Assigned Computation Graph
|
||||
--------------------------------------------------------------------------------
|
||||
{circuit.graph.format(show_assigned_bit_widths=True)}
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
MLIR
|
||||
--------------------------------------------------------------------------------
|
||||
{artifacts.mlir_to_compile}
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
Bit-Width Constraints for function
|
||||
--------------------------------------------------------------------------------
|
||||
%0:
|
||||
function.%0 >= 4
|
||||
%1:
|
||||
function.%1 >= 6
|
||||
%2:
|
||||
function.%2 >= 6
|
||||
function.%0 == function.%1
|
||||
function.%1 == function.%2
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
Bit-Width Assignments for function
|
||||
--------------------------------------------------------------------------------
|
||||
function.%0 = 6
|
||||
function.%1 = 6
|
||||
function.%2 = 6
|
||||
function.max = 6
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
Bit-Width Assigned Computation Graph for function
|
||||
--------------------------------------------------------------------------------
|
||||
{circuit.graph.format(show_assigned_bit_widths=True)}
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
Optimizer
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
@@ -325,7 +297,7 @@ def test_compiler_reset(helpers):
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> {
|
||||
func.func @compiler(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> {
|
||||
%0 = "FHE.add_eint"(%arg0, %arg1) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
|
||||
return %0 : !FHE.eint<4>
|
||||
}
|
||||
@@ -343,7 +315,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<11>, %arg1: !FHE.eint<11>) -> !FHE.eint<11> {
|
||||
func.func @compiler(%arg0: !FHE.eint<11>, %arg1: !FHE.eint<11>) -> !FHE.eint<11> {
|
||||
%0 = "FHE.add_eint"(%arg0, %arg1) : (!FHE.eint<11>, !FHE.eint<11>) -> !FHE.eint<11>
|
||||
return %0 : !FHE.eint<11>
|
||||
}
|
||||
@@ -361,7 +333,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<3x2x!FHE.eint<3>>, %arg1: tensor<2x!FHE.eint<3>>) -> tensor<3x2x!FHE.eint<3>> {
|
||||
func.func @compiler(%arg0: tensor<3x2x!FHE.eint<3>>, %arg1: tensor<2x!FHE.eint<3>>) -> tensor<3x2x!FHE.eint<3>> {
|
||||
%0 = "FHELinalg.add_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<3>>, tensor<2x!FHE.eint<3>>) -> tensor<3x2x!FHE.eint<3>>
|
||||
return %0 : tensor<3x2x!FHE.eint<3>>
|
||||
}
|
||||
|
||||
@@ -263,7 +263,7 @@ def test_round_bit_pattern_no_overflow_protection(helpers):
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.esint<7>) -> !FHE.eint<11> {
|
||||
func.func @function(%arg0: !FHE.esint<7>) -> !FHE.eint<11> {
|
||||
%0 = "FHE.round"(%arg0) : (!FHE.esint<7>) -> !FHE.esint<5>
|
||||
%c2_i3 = arith.constant 2 : i3
|
||||
%cst = arith.constant dense<[0, 16, 64, 144, 256, 400, 576, 784, 1024, 1296, 1600, 1936, 2304, 2704, 3136, 3600, 4096, 3600, 3136, 2704, 2304, 1936, 1600, 1296, 1024, 784, 576, 400, 256, 144, 64, 16]> : tensor<32xi64>
|
||||
@@ -277,7 +277,7 @@ module {
|
||||
else """
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.esint<11>) -> !FHE.eint<11> {
|
||||
func.func @function(%arg0: !FHE.esint<11>) -> !FHE.eint<11> {
|
||||
%c16_i12 = arith.constant 16 : i12
|
||||
%0 = "FHE.mul_eint_int"(%arg0, %c16_i12) : (!FHE.esint<11>, i12) -> !FHE.esint<11>
|
||||
%1 = "FHE.round"(%0) : (!FHE.esint<11>) -> !FHE.esint<5>
|
||||
@@ -312,7 +312,7 @@ def test_round_bit_pattern_identity(helpers):
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.esint<6>) -> !FHE.esint<7> {
|
||||
func.func @function(%arg0: !FHE.esint<6>) -> !FHE.esint<7> {
|
||||
%0 = "FHE.round"(%arg0) : (!FHE.esint<6>) -> !FHE.esint<4>
|
||||
%cst = arith.constant dense<[0, 4, 8, 12, 16, 20, 24, 28, -32, -28, -24, -20, -16, -12, -8, -4]> : tensor<16xi64>
|
||||
%1 = "FHE.apply_lookup_table"(%0, %cst) : (!FHE.esint<4>, tensor<16xi64>) -> !FHE.esint<7>
|
||||
@@ -326,7 +326,7 @@ module {
|
||||
else """
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.esint<7>) -> !FHE.esint<7> {
|
||||
func.func @function(%arg0: !FHE.esint<7>) -> !FHE.esint<7> {
|
||||
%c2_i8 = arith.constant 2 : i8
|
||||
%0 = "FHE.mul_eint_int"(%arg0, %c2_i8) : (!FHE.esint<7>, i8) -> !FHE.esint<7>
|
||||
%1 = "FHE.round"(%0) : (!FHE.esint<7>) -> !FHE.esint<4>
|
||||
|
||||
@@ -382,7 +382,7 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen(
|
||||
assert (dtype.decode(concrete_encoded_result) == function(*sample)).all()
|
||||
|
||||
###### TFHErs Encryption ######################################################
|
||||
tfhers_bridge = tfhers.new_bridge(circuit, dtype, dtype, func_name="main")
|
||||
tfhers_bridge = tfhers.new_bridge(circuit, dtype, dtype, func_name="<lambda>")
|
||||
|
||||
# serialize key
|
||||
_, key_path = tempfile.mkstemp()
|
||||
@@ -577,7 +577,7 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_concrete_keygen(
|
||||
assert (dtype.decode(concrete_encoded_result) == function(*sample)).all()
|
||||
|
||||
###### TFHErs Encryption ######################################################
|
||||
tfhers_bridge = tfhers.new_bridge(circuit, dtype, dtype, func_name="main")
|
||||
tfhers_bridge = tfhers.new_bridge(circuit, dtype, dtype, func_name="<lambda>")
|
||||
|
||||
# serialize key
|
||||
_, key_path = tempfile.mkstemp()
|
||||
|
||||
@@ -250,7 +250,7 @@ def test_truncate_bit_pattern_identity(helpers, pytestconfig):
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.esint<7>) -> !FHE.esint<7> {
|
||||
func.func @function(%arg0: !FHE.esint<7>) -> !FHE.esint<7> {
|
||||
%0 = "FHE.lsb"(%arg0) : (!FHE.esint<7>) -> !FHE.esint<7>
|
||||
%1 = "FHE.sub_eint"(%arg0, %0) : (!FHE.esint<7>, !FHE.esint<7>) -> !FHE.esint<7>
|
||||
%2 = "FHE.reinterpret_precision"(%1) : (!FHE.esint<7>) -> !FHE.esint<6>
|
||||
@@ -267,7 +267,7 @@ module {
|
||||
else """
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.esint<7>) -> !FHE.esint<7> {
|
||||
func.func @function(%arg0: !FHE.esint<7>) -> !FHE.esint<7> {
|
||||
%0 = "FHE.lsb"(%arg0) : (!FHE.esint<7>) -> !FHE.esint<7>
|
||||
%1 = "FHE.sub_eint"(%arg0, %0) : (!FHE.esint<7>, !FHE.esint<7>) -> !FHE.esint<7>
|
||||
%2 = "FHE.reinterpret_precision"(%1) : (!FHE.esint<7>) -> !FHE.esint<6>
|
||||
|
||||
@@ -1143,7 +1143,7 @@ def test_converter_bad_convert(
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<6> {
|
||||
func.func @"<lambda>"(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<6> {
|
||||
%0 = "FHE.mul_eint"(%arg0, %arg1) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<6>
|
||||
return %0 : !FHE.eint<6>
|
||||
}
|
||||
@@ -1161,7 +1161,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<3x2x!FHE.eint<4>>, %arg1: tensor<3x2x!FHE.eint<4>>) -> tensor<3x2x!FHE.eint<6>> {
|
||||
func.func @"<lambda>"(%arg0: tensor<3x2x!FHE.eint<4>>, %arg1: tensor<3x2x!FHE.eint<4>>) -> tensor<3x2x!FHE.eint<6>> {
|
||||
%0 = "FHELinalg.mul_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<4>>, tensor<3x2x!FHE.eint<4>>) -> tensor<3x2x!FHE.eint<6>>
|
||||
return %0 : tensor<3x2x!FHE.eint<6>>
|
||||
}
|
||||
@@ -1179,7 +1179,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<2x!FHE.eint<4>>, %arg1: tensor<2x!FHE.eint<4>>) -> !FHE.eint<7> {
|
||||
func.func @"<lambda>"(%arg0: tensor<2x!FHE.eint<4>>, %arg1: tensor<2x!FHE.eint<4>>) -> !FHE.eint<7> {
|
||||
%0 = "FHELinalg.dot_eint_eint"(%arg0, %arg1) : (tensor<2x!FHE.eint<4>>, tensor<2x!FHE.eint<4>>) -> !FHE.eint<7>
|
||||
return %0 : !FHE.eint<7>
|
||||
}
|
||||
@@ -1197,7 +1197,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<3x2x!FHE.eint<4>>, %arg1: tensor<2x4x!FHE.eint<4>>) -> tensor<3x4x!FHE.eint<7>> {
|
||||
func.func @"<lambda>"(%arg0: tensor<3x2x!FHE.eint<4>>, %arg1: tensor<2x4x!FHE.eint<4>>) -> tensor<3x4x!FHE.eint<7>> {
|
||||
%0 = "FHELinalg.matmul_eint_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<4>>, tensor<2x4x!FHE.eint<4>>) -> tensor<3x4x!FHE.eint<7>>
|
||||
return %0 : tensor<3x4x!FHE.eint<7>>
|
||||
}
|
||||
@@ -1215,7 +1215,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<10x!FHE.eint<4>>, %arg1: tensor<10x!FHE.eint<4>>) -> !FHE.eint<1> {
|
||||
func.func @"<lambda>"(%arg0: tensor<10x!FHE.eint<4>>, %arg1: tensor<10x!FHE.eint<4>>) -> !FHE.eint<1> {
|
||||
%0 = "FHELinalg.to_signed"(%arg0) : (tensor<10x!FHE.eint<4>>) -> tensor<10x!FHE.esint<4>>
|
||||
%1 = "FHELinalg.to_signed"(%arg1) : (tensor<10x!FHE.eint<4>>) -> tensor<10x!FHE.esint<4>>
|
||||
%2 = "FHELinalg.sub_eint"(%0, %1) : (tensor<10x!FHE.esint<4>>, tensor<10x!FHE.esint<4>>) -> tensor<10x!FHE.esint<4>>
|
||||
@@ -1242,7 +1242,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<2x!FHE.eint<2>>, %arg1: !FHE.eint<6>) -> tensor<2x!FHE.eint<6>> {
|
||||
func.func @"<lambda>"(%arg0: tensor<2x!FHE.eint<2>>, %arg1: !FHE.eint<6>) -> tensor<2x!FHE.eint<6>> {
|
||||
%c2_i3 = arith.constant 2 : i3
|
||||
%cst = arith.constant dense<[0, 1, 4, 9]> : tensor<4xi64>
|
||||
%0 = "FHELinalg.apply_lookup_table"(%arg0, %cst) : (tensor<2x!FHE.eint<2>>, tensor<4xi64>) -> tensor<2x!FHE.eint<6>>
|
||||
@@ -1263,7 +1263,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<7>) -> (!FHE.eint<8>, !FHE.eint<7>) {
|
||||
func.func @"<lambda>"(%arg0: !FHE.eint<7>) -> (!FHE.eint<8>, !FHE.eint<7>) {
|
||||
%c2_i3 = arith.constant 2 : i3
|
||||
%c8_i8 = arith.constant 8 : i8
|
||||
%0 = "FHE.mul_eint_int"(%arg0, %c8_i8) : (!FHE.eint<7>, i8) -> !FHE.eint<7>
|
||||
@@ -1289,7 +1289,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<7>) -> (!FHE.eint<8>, !FHE.eint<7>) {
|
||||
func.func @"<lambda>"(%arg0: !FHE.eint<7>) -> (!FHE.eint<8>, !FHE.eint<7>) {
|
||||
%c2_i3 = arith.constant 2 : i3
|
||||
%c12_i8 = arith.constant 12 : i8
|
||||
%0 = "FHE.sub_eint_int"(%arg0, %c12_i8) : (!FHE.eint<7>, i8) -> !FHE.eint<7>
|
||||
@@ -1317,7 +1317,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<2x!FHE.eint<4>>) -> tensor<2x!FHE.eint<3>> {
|
||||
func.func @"<lambda>"(%arg0: tensor<2x!FHE.eint<4>>) -> tensor<2x!FHE.eint<3>> {
|
||||
%c12_i5 = arith.constant 12 : i5
|
||||
%from_elements = tensor.from_elements %c12_i5 : tensor<1xi5>
|
||||
%0 = "FHELinalg.sub_eint_int"(%arg0, %from_elements) : (tensor<2x!FHE.eint<4>>, tensor<1xi5>) -> tensor<2x!FHE.eint<4>>
|
||||
@@ -1345,7 +1345,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<7>) -> (!FHE.eint<8>, !FHE.eint<7>) {
|
||||
func.func @"<lambda>"(%arg0: !FHE.eint<7>) -> (!FHE.eint<8>, !FHE.eint<7>) {
|
||||
%c2_i3 = arith.constant 2 : i3
|
||||
%c4_i8 = arith.constant 4 : i8
|
||||
%0 = "FHE.mul_eint_int"(%arg0, %c4_i8) : (!FHE.eint<7>, i8) -> !FHE.eint<7>
|
||||
@@ -1371,7 +1371,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<2x!FHE.eint<4>>) -> tensor<2x!FHE.eint<3>> {
|
||||
func.func @"<lambda>"(%arg0: tensor<2x!FHE.eint<4>>) -> tensor<2x!FHE.eint<3>> {
|
||||
%c12_i5 = arith.constant 12 : i5
|
||||
%from_elements = tensor.from_elements %c12_i5 : tensor<1xi5>
|
||||
%0 = "FHELinalg.sub_eint_int"(%arg0, %from_elements) : (tensor<2x!FHE.eint<4>>, tensor<1xi5>) -> tensor<2x!FHE.eint<4>>
|
||||
@@ -1399,7 +1399,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<2x!FHE.eint<5>>) -> tensor<2x!FHE.eint<3>> {
|
||||
func.func @"<lambda>"(%arg0: tensor<2x!FHE.eint<5>>) -> tensor<2x!FHE.eint<3>> {
|
||||
%cst = arith.constant dense<[0, 1]> : tensor<2xindex>
|
||||
%cst_0 = arith.constant dense<[[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15], [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10]]> : tensor<2x32xi64>
|
||||
%0 = "FHELinalg.apply_mapped_lookup_table"(%arg0, %cst_0, %cst) : (tensor<2x!FHE.eint<5>>, tensor<2x32xi64>, tensor<2xindex>) -> tensor<2x!FHE.eint<3>>
|
||||
@@ -1420,7 +1420,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<8>) -> (!FHE.eint<5>, !FHE.eint<8>) {
|
||||
func.func @"<lambda>"(%arg0: !FHE.eint<8>) -> (!FHE.eint<5>, !FHE.eint<8>) {
|
||||
%c2_i3 = arith.constant 2 : i3
|
||||
%c4_i9 = arith.constant 4 : i9
|
||||
%0 = "FHE.mul_eint_int"(%arg0, %c4_i9) : (!FHE.eint<8>, i9) -> !FHE.eint<8>
|
||||
@@ -1444,7 +1444,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.esint<7>) -> (!FHE.eint<8>, !FHE.eint<7>) {
|
||||
func.func @"<lambda>"(%arg0: !FHE.esint<7>) -> (!FHE.eint<8>, !FHE.eint<7>) {
|
||||
%c2_i3 = arith.constant 2 : i3
|
||||
%c4_i8 = arith.constant 4 : i8
|
||||
%0 = "FHE.mul_eint_int"(%arg0, %c4_i8) : (!FHE.esint<7>, i8) -> !FHE.esint<7>
|
||||
@@ -1471,7 +1471,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.esint<7>) -> (!FHE.eint<8>, !FHE.eint<7>) {
|
||||
func.func @"<lambda>"(%arg0: !FHE.esint<7>) -> (!FHE.eint<8>, !FHE.eint<7>) {
|
||||
%c2_i3 = arith.constant 2 : i3
|
||||
%c13_i8 = arith.constant 13 : i8
|
||||
%0 = "FHE.add_eint_int"(%arg0, %c13_i8) : (!FHE.esint<7>, i8) -> !FHE.esint<7>
|
||||
@@ -1500,7 +1500,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.esint<8>) -> (!FHE.esint<5>, !FHE.eint<8>) {
|
||||
func.func @"<lambda>"(%arg0: !FHE.esint<8>) -> (!FHE.esint<5>, !FHE.eint<8>) {
|
||||
%c2_i3 = arith.constant 2 : i3
|
||||
%c4_i9 = arith.constant 4 : i9
|
||||
%0 = "FHE.mul_eint_int"(%arg0, %c4_i9) : (!FHE.esint<8>, i9) -> !FHE.esint<8>
|
||||
@@ -1527,7 +1527,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
|
||||
func.func @"<lambda>"(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
|
||||
%c3_i3 = arith.constant 3 : i3
|
||||
%c2_i7 = arith.constant 2 : i7
|
||||
%0 = "FHE.mul_eint_int"(%arg0, %c2_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>
|
||||
@@ -1552,7 +1552,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
|
||||
func.func @"<lambda>"(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
|
||||
%c3_i3 = arith.constant 3 : i3
|
||||
%cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, 12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 15, 16, 16, 16, 17, 17, 17, 18, 18, 18, 19, 19, 19, 20, 20, 20, 21]> : tensor<64xi64>
|
||||
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6>
|
||||
@@ -1574,7 +1574,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
|
||||
func.func @"<lambda>"(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
|
||||
%c3_i3 = arith.constant 3 : i3
|
||||
%c2_i7 = arith.constant 2 : i7
|
||||
%0 = "FHE.mul_eint_int"(%arg0, %c2_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>
|
||||
@@ -1599,7 +1599,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
|
||||
func.func @"<lambda>"(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
|
||||
%c3_i3 = arith.constant 3 : i3
|
||||
%c2_i7 = arith.constant 2 : i7
|
||||
%0 = "FHE.mul_eint_int"(%arg0, %c2_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>
|
||||
@@ -1624,7 +1624,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
|
||||
func.func @"<lambda>"(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
|
||||
%c3_i3 = arith.constant 3 : i3
|
||||
%cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, 12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 15, 16, 16, 16, 17, 17, 17, 18, 18, 18, 19, 19, 19, 20, 20, 20, 21]> : tensor<64xi64>
|
||||
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6>
|
||||
@@ -1646,7 +1646,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
|
||||
func.func @"<lambda>"(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
|
||||
%c3_i3 = arith.constant 3 : i3
|
||||
%cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, 12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 15, 16, 16, 16, 17, 17, 17, 18, 18, 18, 19, 19, 19, 20, 20, 20, 21]> : tensor<64xi64>
|
||||
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6>
|
||||
@@ -1669,7 +1669,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<8x!FHE.eint<4>>, %arg1: i4) -> !FHE.eint<4> {
|
||||
func.func @"<lambda>"(%arg0: tensor<8x!FHE.eint<4>>, %arg1: i4) -> !FHE.eint<4> {
|
||||
%0 = arith.index_cast %arg1 : i4 to index
|
||||
%extracted = tensor.extract %arg0[%0] : tensor<8x!FHE.eint<4>>
|
||||
return %extracted : !FHE.eint<4>
|
||||
@@ -1690,7 +1690,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<8x!FHE.eint<4>>, %arg1: i5) -> !FHE.eint<4> {
|
||||
func.func @"<lambda>"(%arg0: tensor<8x!FHE.eint<4>>, %arg1: i5) -> !FHE.eint<4> {
|
||||
%c8_i5 = arith.constant 8 : i5
|
||||
%c0_i5 = arith.constant 0 : i5
|
||||
%0 = arith.cmpi slt, %arg1, %c0_i5 : i5
|
||||
@@ -1720,7 +1720,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<8x!FHE.eint<4>>, %arg1: i4) -> !FHE.eint<4> {
|
||||
func.func @"<lambda>"(%arg0: tensor<8x!FHE.eint<4>>, %arg1: i4) -> !FHE.eint<4> {
|
||||
%c8_i5 = arith.constant 8 : i5
|
||||
%0 = arith.extsi %arg1 : i4 to i5
|
||||
%1 = arith.cmpi sge, %0, %c8_i5 : i5
|
||||
@@ -1752,7 +1752,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<8x!FHE.eint<4>>, %arg1: i5) -> !FHE.eint<4> {
|
||||
func.func @"<lambda>"(%arg0: tensor<8x!FHE.eint<4>>, %arg1: i5) -> !FHE.eint<4> {
|
||||
%c8_i5 = arith.constant 8 : i5
|
||||
%c0_i5 = arith.constant 0 : i5
|
||||
%0 = arith.cmpi slt, %arg1, %c0_i5 : i5
|
||||
@@ -1790,7 +1790,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<8x!FHE.eint<4>>, %arg1: tensor<3x2xi4>) -> tensor<3x2x!FHE.eint<4>> {
|
||||
func.func @"<lambda>"(%arg0: tensor<8x!FHE.eint<4>>, %arg1: tensor<3x2xi4>) -> tensor<3x2x!FHE.eint<4>> {
|
||||
%0 = arith.index_cast %arg1 : tensor<3x2xi4> to tensor<3x2xindex>
|
||||
%1 = "FHELinalg.fancy_index"(%arg0, %0) : (tensor<8x!FHE.eint<4>>, tensor<3x2xindex>) -> tensor<3x2x!FHE.eint<4>>
|
||||
return %1 : tensor<3x2x!FHE.eint<4>>
|
||||
@@ -1811,7 +1811,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<8x!FHE.eint<4>>, %arg1: tensor<3x2xi5>) -> tensor<3x2x!FHE.eint<4>> {
|
||||
func.func @"<lambda>"(%arg0: tensor<8x!FHE.eint<4>>, %arg1: tensor<3x2xi5>) -> tensor<3x2x!FHE.eint<4>> {
|
||||
%c8_i5 = arith.constant 8 : i5
|
||||
%generated = tensor.generate {
|
||||
^bb0(%arg2: index, %arg3: index):
|
||||
@@ -1846,7 +1846,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<8x!FHE.eint<4>>, %arg1: tensor<3x2xi4>) -> tensor<3x2x!FHE.eint<4>> {
|
||||
func.func @"<lambda>"(%arg0: tensor<8x!FHE.eint<4>>, %arg1: tensor<3x2xi4>) -> tensor<3x2x!FHE.eint<4>> {
|
||||
%c8_i5 = arith.constant 8 : i5
|
||||
%0 = arith.extsi %arg1 : tensor<3x2xi4> to tensor<3x2xi5>
|
||||
%c0 = arith.constant 0 : index
|
||||
@@ -1889,7 +1889,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<8x!FHE.eint<4>>, %arg1: tensor<3x2xi5>) -> tensor<3x2x!FHE.eint<4>> {
|
||||
func.func @"<lambda>"(%arg0: tensor<8x!FHE.eint<4>>, %arg1: tensor<3x2xi5>) -> tensor<3x2x!FHE.eint<4>> {
|
||||
%c8_i5 = arith.constant 8 : i5
|
||||
%generated = tensor.generate {
|
||||
^bb0(%arg2: index, %arg3: index):
|
||||
@@ -1930,7 +1930,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<102x70x104x!FHE.eint<6>>) -> tensor<102x70x104x!FHE.eint<5>> {
|
||||
func.func @"<lambda>"(%arg0: tensor<102x70x104x!FHE.eint<6>>) -> tensor<102x70x104x!FHE.eint<5>> {
|
||||
%c2_i3 = arith.constant 2 : i3
|
||||
%cst = arith.constant dense<[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24, 25, 25, 26, 26, 27, 27, 28, 28, 29, 29, 30, 30, 31, 31]> : tensor<64xi64>
|
||||
%0 = "FHELinalg.apply_lookup_table"(%arg0, %cst) : (tensor<102x70x104x!FHE.eint<6>>, tensor<64xi64>) -> tensor<102x70x104x!FHE.eint<5>>
|
||||
@@ -1949,7 +1949,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<3x!FHE.eint<8>>) -> !FHE.eint<8> {
|
||||
func.func @"<lambda>"(%arg0: tensor<3x!FHE.eint<8>>) -> !FHE.eint<8> {
|
||||
%cst = arith.constant dense<[1, 0, 2]> : tensor<3xi3>
|
||||
%0 = "FHELinalg.dot_eint_int"(%arg0, %cst) : (tensor<3x!FHE.eint<8>>, tensor<3xi3>) -> !FHE.eint<8>
|
||||
return %0 : !FHE.eint<8>
|
||||
@@ -1967,7 +1967,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<2x2x!FHE.eint<8>>) -> tensor<2x2x!FHE.eint<8>> {
|
||||
func.func @"<lambda>"(%arg0: tensor<2x2x!FHE.eint<8>>) -> tensor<2x2x!FHE.eint<8>> {
|
||||
%cst = arith.constant dense<[[1, 0], [2, 3]]> : tensor<2x2xi3>
|
||||
%0 = "FHELinalg.matmul_eint_int"(%arg0, %cst) : (tensor<2x2x!FHE.eint<8>>, tensor<2x2xi3>) -> tensor<2x2x!FHE.eint<8>>
|
||||
return %0 : tensor<2x2x!FHE.eint<8>>
|
||||
@@ -2016,7 +2016,7 @@ def test_converter_convert_multi_precision(
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<6>, %arg1: !FHE.eint<6>) -> !FHE.eint<6> {
|
||||
func.func @"<lambda>"(%arg0: !FHE.eint<6>, %arg1: !FHE.eint<6>) -> !FHE.eint<6> {
|
||||
%0 = "FHE.mul_eint"(%arg0, %arg1) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>
|
||||
return %0 : !FHE.eint<6>
|
||||
}
|
||||
@@ -2033,7 +2033,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<3x2x!FHE.eint<6>>, %arg1: tensor<3x2x!FHE.eint<6>>) -> tensor<3x2x!FHE.eint<6>> {
|
||||
func.func @"<lambda>"(%arg0: tensor<3x2x!FHE.eint<6>>, %arg1: tensor<3x2x!FHE.eint<6>>) -> tensor<3x2x!FHE.eint<6>> {
|
||||
%0 = "FHELinalg.mul_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<6>>, tensor<3x2x!FHE.eint<6>>) -> tensor<3x2x!FHE.eint<6>>
|
||||
return %0 : tensor<3x2x!FHE.eint<6>>
|
||||
}
|
||||
@@ -2050,7 +2050,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<2x!FHE.eint<7>>, %arg1: tensor<2x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
func.func @"<lambda>"(%arg0: tensor<2x!FHE.eint<7>>, %arg1: tensor<2x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
%0 = "FHELinalg.dot_eint_eint"(%arg0, %arg1) : (tensor<2x!FHE.eint<7>>, tensor<2x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
return %0 : !FHE.eint<7>
|
||||
}
|
||||
@@ -2067,7 +2067,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<3x2x!FHE.eint<7>>, %arg1: tensor<2x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>> {
|
||||
func.func @"<lambda>"(%arg0: tensor<3x2x!FHE.eint<7>>, %arg1: tensor<2x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>> {
|
||||
%0 = "FHELinalg.matmul_eint_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<7>>, tensor<2x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>>
|
||||
return %0 : tensor<3x4x!FHE.eint<7>>
|
||||
}
|
||||
@@ -2084,7 +2084,7 @@ module {
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: tensor<2x!FHE.eint<6>>, %arg1: !FHE.eint<6>) -> tensor<2x!FHE.eint<6>> {
|
||||
func.func @"<lambda>"(%arg0: tensor<2x!FHE.eint<6>>, %arg1: !FHE.eint<6>) -> tensor<2x!FHE.eint<6>> {
|
||||
%c2_i3 = arith.constant 2 : i3
|
||||
%c16_i7 = arith.constant 16 : i7
|
||||
%from_elements = tensor.from_elements %c16_i7 : tensor<1xi7>
|
||||
@@ -2134,7 +2134,7 @@ def test_converter_convert_single_precision(function, parameters, expected_mlir,
|
||||
"""
|
||||
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<9>, %arg1: !FHE.eint<9>) -> !FHE.eint<9> {
|
||||
func.func @"<lambda>"(%arg0: !FHE.eint<9>, %arg1: !FHE.eint<9>) -> !FHE.eint<9> {
|
||||
%0 = "FHE.mul_eint"(%arg0, %arg1) : (!FHE.eint<9>, !FHE.eint<9>) -> !FHE.eint<9>
|
||||
return %0 : !FHE.eint<9>
|
||||
}
|
||||
@@ -2197,7 +2197,7 @@ def test_converter_process_multi_precision(function, parameters, expected_graph,
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
graph = compiler.trace(inputset, configuration)
|
||||
|
||||
GraphConverter(configuration).process({"main": graph})
|
||||
GraphConverter(configuration).process({"<lambda>": graph})
|
||||
for node in graph.query_nodes():
|
||||
if "original_bit_width" in node.properties:
|
||||
del node.properties["original_bit_width"]
|
||||
@@ -2239,7 +2239,7 @@ def test_converter_process_single_precision(function, parameters, expected_graph
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
graph = compiler.trace(inputset, configuration)
|
||||
|
||||
GraphConverter(configuration).process({"main": graph})
|
||||
GraphConverter(configuration).process({"<lambda>": graph})
|
||||
for node in graph.query_nodes():
|
||||
if "original_bit_width" in node.properties:
|
||||
del node.properties["original_bit_width"]
|
||||
|
||||
Reference in New Issue
Block a user