refactor(frontends): unify circuits and modules

This commit is contained in:
Alexandre Péré
2024-09-09 11:58:31 +02:00
committed by Alexandre Péré
parent 52636e47c6
commit d9b34f13d0
33 changed files with 932 additions and 1551 deletions

View File

@@ -1310,6 +1310,15 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
clientParameters, inputId, circuitName);
return encryption.getVariance();
})
.def("function_list",
[](::concretelang::clientlib::ClientParameters &clientParameters) {
std::vector<std::string> result;
for (auto circuit :
clientParameters.programInfo.asReader().getCircuits()) {
result.push_back(circuit.getName());
}
return result;
})
.def("output_signs",
[](::concretelang::clientlib::ClientParameters &clientParameters) {
std::vector<bool> result;

View File

@@ -37,13 +37,12 @@ class ClientParameters(WrapperCpp):
)
super().__init__(client_parameters)
def input_keyid_at(self, input_idx: int, circuit_name: str = "main") -> int:
def input_keyid_at(self, input_idx: int, circuit_name: str) -> int:
"""Get the keyid of a selected encrypted input in a given circuit.
Args:
input_idx (int): index of the input in the circuit.
circuit_name (str, optional): name of the circuit containing the desired input.
Defaults to "main".
circuit_name (str): name of the circuit containing the desired input.
Raises:
TypeError: if arguments aren't of expected types
@@ -59,13 +58,12 @@ class ClientParameters(WrapperCpp):
)
return self.cpp().input_keyid_at(input_idx, circuit_name)
def input_variance_at(self, input_idx: int, circuit_name: str = "main") -> float:
def input_variance_at(self, input_idx: int, circuit_name: str) -> float:
"""Get the variance of a selected encrypted input in a given circuit.
Args:
input_idx (int): index of the input in the circuit.
circuit_name (str, optional): name of the circuit containing the desired input.
Defaults to "main".
circuit_name (str): name of the circuit containing the desired input.
Raises:
TypeError: if arguments aren't of expected types
@@ -97,6 +95,14 @@ class ClientParameters(WrapperCpp):
"""
return self.cpp().output_signs()
def function_list(self) -> List[str]:
"""Return the list of function names.
Returns:
List[str]: list of the names of the functions.
"""
return self.cpp().function_list()
def serialize(self) -> bytes:
"""Serialize the ClientParameters.

View File

@@ -116,7 +116,7 @@ class ClientSupport(WrapperCpp):
client_parameters: ClientParameters,
keyset: KeySet,
args: List[Union[int, np.ndarray]],
circuit_name: str = "main",
circuit_name: str,
) -> PublicArguments:
"""Prepare arguments for encrypted computation.
@@ -172,7 +172,7 @@ class ClientSupport(WrapperCpp):
client_parameters: ClientParameters,
keyset: KeySet,
public_result: PublicResult,
circuit_name: str = "main",
circuit_name: str,
) -> Union[int, np.ndarray]:
"""Decrypt a public result using the keyset.

View File

@@ -236,7 +236,7 @@ class LibrarySupport(WrapperCpp):
self,
library_compilation_result: LibraryCompilationResult,
simulation: bool,
circuit_name: str = "main",
circuit_name: str,
) -> LibraryLambda:
"""Load the server lambda for a given circuit from the library compilation result.

View File

@@ -41,7 +41,7 @@ class SimulatedValueDecrypter(WrapperCpp):
@staticmethod
# pylint: disable=arguments-differ
def new(client_parameters: ClientParameters, circuit_name: str = "main"):
def new(client_parameters: ClientParameters, circuit_name: str):
"""
Create a value decrypter.
"""

View File

@@ -41,7 +41,7 @@ class SimulatedValueExporter(WrapperCpp):
@staticmethod
# pylint: disable=arguments-differ
def new(
client_parameters: ClientParameters, circuitName: str = "main"
client_parameters: ClientParameters, circuitName: str
) -> "SimulatedValueExporter":
"""
Create a value exporter.

View File

@@ -42,7 +42,7 @@ class ValueExporter(WrapperCpp):
@staticmethod
# pylint: disable=arguments-differ
def new(
keyset: KeySet, client_parameters: ClientParameters, circuit_name: str = "main"
keyset: KeySet, client_parameters: ClientParameters, circuit_name: str
) -> "ValueExporter":
"""
Create a value exporter.

View File

@@ -24,7 +24,7 @@ def assert_result(result, expected_result):
assert np.all(result == expected_result)
def run(engine, args, compilation_result, keyset_cache, circuit_name="main"):
def run(engine, args, compilation_result, keyset_cache, circuit_name):
"""Execute engine on the given arguments.
Perform required loading, encryption, execution, and decryption."""
@@ -219,7 +219,7 @@ def test_lib_compile_reload_and_run(mlir_input, args, expected_result, keyset_ca
# Here don't save compilation result, reload
engine.compile(mlir_input)
compilation_result = engine.reload()
result = run(engine, args, compilation_result, keyset_cache)
result = run(engine, args, compilation_result, keyset_cache, "main")
# Check result
assert_result(result, expected_result)
shutil.rmtree(artifact_dir)
@@ -398,7 +398,7 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args, keyset_cache):
def test_crt_decomposition_feedback():
mlir = """
func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
%tlu = arith.constant dense<60000> : tensor<65536xi64>
%1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<16>, tensor<65536xi64>) -> (!FHE.eint<16>)

View File

@@ -5,7 +5,7 @@ Glue the compilation process together.
from .artifacts import DebugArtifacts, FunctionDebugArtifacts, ModuleDebugArtifacts
from .circuit import Circuit
from .client import Client
from .compiler import Compiler, EncryptionStatus
from .compiler import Compiler
from .composition import CompositionClause, CompositionPolicy, CompositionRule
from .configuration import (
DEFAULT_GLOBAL_P_ERROR,
@@ -22,19 +22,10 @@ from .configuration import (
)
from .keys import Keys
from .module import FheFunction, FheModule
from .module_compiler import (
AllComposable,
AllInputs,
AllOutputs,
FunctionDef,
Input,
ModuleCompiler,
NotComposable,
Output,
Wire,
Wired,
)
from .module_compiler import FunctionDef, ModuleCompiler
from .server import Server
from .specs import ClientSpecs
from .utils import inputset
from .status import EncryptionStatus
from .utils import get_terminal_size, inputset
from .value import Value
from .wiring import AllComposable, AllInputs, AllOutputs, Input, NotComposable, Output, Wire, Wired

View File

@@ -10,6 +10,8 @@ from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
from ..representation import Graph
from .configuration import Configuration
from .utils import get_terminal_size
if TYPE_CHECKING: # pragma: no cover
from .module import ExecutionRt
@@ -18,6 +20,213 @@ if TYPE_CHECKING: # pragma: no cover
DEFAULT_OUTPUT_DIRECTORY: Path = Path(".artifacts")
class DebugManager:
"""
A debug manager, allowing streamlined debugging.
"""
configuration: Configuration
begin_call: Callable
def __init__(self, config: Configuration):
self.configuration = config
is_first = [True]
def begin_call():
if is_first[0]:
print()
is_first[0] = False
self.begin_call = begin_call
def debug_table(self, title: str, activate: bool = True):
"""
Return a context manager that prints a table around what is printed inside the scope.
"""
# pylint: disable=missing-class-docstring
class DebugTableCm:
def __init__(self, title):
self.title = title
self.columns = get_terminal_size()
def __enter__(self):
print(f"{self.title}")
print("-" * self.columns)
def __exit__(self, _exc_type, _exc_value, _exc_tb):
print("-" * self.columns)
print()
class EmptyCm:
def __enter__(self):
pass
def __exit__(self, _exc_type, _exc_value, _exc_tb):
pass
if activate:
self.begin_call()
return DebugTableCm(title)
return EmptyCm()
def show_graph(self) -> bool:
"""
Tell if the configuration involves showing graph.
"""
return (
self.configuration.show_graph
if self.configuration.show_graph is not None
else self.configuration.verbose
)
def show_bit_width_constraints(self) -> bool:
"""
Tell if the configuration involves showing bitwidth constraints.
"""
return (
self.configuration.show_bit_width_constraints
if self.configuration.show_bit_width_constraints is not None
else self.configuration.verbose
)
def show_bit_width_assignments(self) -> bool:
"""
Tell if the configuration involves showing bitwidth assignments.
"""
return (
self.configuration.show_bit_width_assignments
if self.configuration.show_bit_width_assignments is not None
else self.configuration.verbose
)
def show_assigned_graph(self) -> bool:
"""
Tell if the configuration involves showing assigned graph.
"""
return (
self.configuration.show_assigned_graph
if self.configuration.show_assigned_graph is not None
else self.configuration.verbose
)
def show_mlir(self) -> bool:
"""
Tell if the configuration involves showing mlir.
"""
return (
self.configuration.show_mlir
if self.configuration.show_mlir is not None
else self.configuration.verbose
)
def show_optimizer(self) -> bool:
"""
Tell if the configuration involves showing optimizer.
"""
return (
self.configuration.show_optimizer
if self.configuration.show_optimizer is not None
else self.configuration.verbose
)
def show_statistics(self) -> bool:
"""
Tell if the configuration involves showing statistics.
"""
return (
self.configuration.show_statistics
if self.configuration.show_statistics is not None
else self.configuration.verbose
)
def debug_computation_graph(self, name, function_graph):
"""
Print computation graph if configuration tells so.
"""
if (
self.show_graph()
or self.show_bit_width_constraints()
or self.show_bit_width_assignments()
or self.show_assigned_graph()
or self.show_mlir()
or self.show_optimizer()
or self.show_statistics()
):
if self.show_graph():
with self.debug_table(f"Computation Graph for {name}"):
print(function_graph.format())
def debug_bit_width_constaints(self, name, function_graph):
"""
Print bitwidth constraints if configuration tells so.
"""
if self.show_bit_width_constraints():
with self.debug_table(f"Bit-Width Constraints for {name}"):
print(function_graph.format_bit_width_constraints())
def debug_bit_width_assignments(self, name, function_graph):
"""
Print bitwidth assignments if configuration tells so.
"""
if self.show_bit_width_assignments():
with self.debug_table(f"Bit-Width Assignments for {name}"):
print(function_graph.format_bit_width_assignments())
def debug_assigned_graph(self, name, function_graph):
"""
Print assigned graphs if configuration tells so.
"""
if self.show_assigned_graph():
with self.debug_table(f"Bit-Width Assigned Computation Graph for {name}"):
print(function_graph.format(show_assigned_bit_widths=True))
def debug_mlir(self, mlir_str):
"""
Print mlir if configuration tells so.
"""
if self.show_mlir():
with self.debug_table("MLIR"):
print(mlir_str)
def debug_statistics(self, module):
"""
Print statistics if configuration tells so.
"""
if self.show_statistics():
def pretty(d, indent=0): # pragma: no cover
if indent > 0:
print("{")
for key, value in d.items():
if isinstance(value, dict) and len(value) == 0:
continue
print(" " * indent + str(key) + ": ", end="")
if isinstance(value, dict):
pretty(value, indent + 1)
else:
print(value)
if indent > 0:
print(" " * (indent - 1) + "}")
with self.debug_table("Statistics"):
pretty(module.statistics)
class FunctionDebugArtifacts:
"""
An object containing debug artifacts for a certain function in an fhe module.
@@ -236,88 +445,18 @@ class DebugArtifacts:
"""
module_artifacts: ModuleDebugArtifacts
_client_parameters: Optional[bytes]
def __init__(self, output_directory: Union[str, Path] = DEFAULT_OUTPUT_DIRECTORY):
self.module_artifacts = ModuleDebugArtifacts(["main"], output_directory)
self._client_parameters = None
def add_source_code(self, function: Union[str, Callable]):
"""
Add source code of the function being compiled.
Args:
function (Union[str, Callable]):
either the source code of the function or the function itself
"""
self.module_artifacts.functions["main"].add_source_code(function)
def add_parameter_encryption_status(self, name: str, encryption_status: str):
"""
Add parameter encryption status of a parameter of the function being compiled.
Args:
name (str):
name of the parameter
encryption_status (str):
encryption status of the parameter
"""
self.module_artifacts.functions["main"].add_parameter_encryption_status(
name, encryption_status
)
def add_graph(self, name: str, graph: Graph):
"""
Add a representation of the function being compiled.
Args:
name (str):
name of the graph (e.g., initial, optimized, final)
graph (Graph):
a representation of the function being compiled
"""
self.module_artifacts.functions["main"].add_graph(name, graph)
def add_mlir_to_compile(self, mlir: str):
"""
Add textual representation of the resulting MLIR.
Args:
mlir (str):
textual representation of the resulting MLIR
"""
self.module_artifacts.add_mlir_to_compile(mlir)
def add_client_parameters(self, client_parameters: bytes):
"""
Add client parameters used.
Args:
client_parameters (bytes): client parameters
"""
self._client_parameters = client_parameters
self.module_artifacts = ModuleDebugArtifacts([], output_directory)
def export(self):
"""
Export the collected information to `self.output_directory`.
"""
# This is a quick fix before we refactor compiler and module_compiler
# to use the same abstraction.
class _ModuleDebugArtifacts(ModuleDebugArtifacts):
client_parameters = self._client_parameters
self.module_artifacts.__class__ = _ModuleDebugArtifacts
self.module_artifacts.export()
@property
def output_directory(self) -> Path:
def output_directory(self) -> Path: # pragma: no cover
"""
Return the directory to export artifacts to.
"""

View File

@@ -5,25 +5,18 @@ Declaration of `Circuit` class.
# pylint: disable=import-error,no-member,no-name-in-module
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from concrete.compiler import (
CompilationContext,
Parameter,
SimulatedValueDecrypter,
SimulatedValueExporter,
)
from concrete.compiler import CompilationContext, Parameter
from mlir.ir import Module as MlirModule
from ..internal.utils import assert_that
from ..representation import Graph
from .client import Client
from .composition import CompositionRule
from .configuration import Configuration
from .keys import Keys
from .module import FheFunction, FheModule
from .server import Server
from .utils import validate_input_args
from .value import Value
# pylint: enable=import-error,no-member,no-name-in-module
@@ -34,40 +27,20 @@ class Circuit:
Circuit class, to combine computation graph, mlir, client and server into a single object.
"""
configuration: Configuration
_module: FheModule
_name: str
graph: Graph
mlir_module: MlirModule
compilation_context: CompilationContext
composition_rules: Optional[List[CompositionRule]]
def __init__(self, module: FheModule):
assert module.function_count == 1
self._name = next(iter(module.functions().keys()))
self._module = module
client: Client
server: Server
simulator: Server
def __init__(
self,
graph: Graph,
mlir: MlirModule,
compilation_context: CompilationContext,
configuration: Optional[Configuration] = None,
composition_rules: Optional[Iterable[CompositionRule]] = None,
):
self.configuration = configuration if configuration is not None else Configuration()
self.composition_rules = list(composition_rules) if composition_rules else []
self.graph = graph
self.mlir_module = mlir
self.compilation_context = compilation_context
if self.configuration.fhe_simulation:
self.enable_fhe_simulation()
if self.configuration.fhe_execution:
self.enable_fhe_execution()
@property
def _function(self) -> FheFunction:
return getattr(self._module, self._name)
def __str__(self):
return self.graph.format()
return self._function.graph.format()
def draw(
self,
@@ -100,7 +73,7 @@ class Circuit:
path to the drawing
"""
return self.graph.draw(horizontal=horizontal, save_to=save_to, show=show)
return self._function.graph.draw(horizontal=horizontal, save_to=save_to, show=show)
@property
def mlir(self) -> str:
@@ -109,42 +82,19 @@ class Circuit:
Returns:
str: textual representation of the MLIR module
"""
return str(self.mlir_module).strip()
return str(self._module.mlir_module).strip()
def enable_fhe_simulation(self):
"""
Enable FHE simulation.
"""
if not hasattr(self, "simulator"):
self.simulator = Server.create(
self.mlir_module,
self.configuration,
is_simulated=True,
compilation_context=self.compilation_context,
composition_rules=self.composition_rules,
)
self._module.simulation_runtime.init()
def enable_fhe_execution(self):
"""
Enable FHE execution.
"""
if not hasattr(self, "server"):
self.server = Server.create(
self.mlir_module,
self.configuration,
compilation_context=self.compilation_context,
composition_rules=self.composition_rules,
)
keyset_cache_directory = None
if self.configuration.use_insecure_key_cache:
assert_that(self.configuration.enable_unsafe_features)
assert_that(self.configuration.insecure_key_cache_location is not None)
keyset_cache_directory = self.configuration.insecure_key_cache_location
self.client = Client(self.server.client_specs, keyset_cache_directory)
self._module.execution_runtime.init() # pragma: no cover
def simulate(self, *args: Any) -> Any:
"""
@@ -158,58 +108,21 @@ class Circuit:
Any:
result of the simulation
"""
if not hasattr(self, "simulator"): # pragma: no cover
self.enable_fhe_simulation()
ordered_validated_args = validate_input_args(self.simulator.client_specs, *args)
exporter = SimulatedValueExporter.new(self.simulator.client_specs.client_parameters)
exported = [
(
None
if arg is None
else Value(
exporter.export_tensor(position, arg.flatten().tolist(), list(arg.shape))
if isinstance(arg, np.ndarray) and arg.shape != ()
else exporter.export_scalar(position, int(arg))
)
)
for position, arg in enumerate(ordered_validated_args)
]
results = self.simulator.run(*exported)
if not isinstance(results, tuple):
results = (results,)
decrypter = SimulatedValueDecrypter.new(self.simulator.client_specs.client_parameters)
decrypted = tuple(
decrypter.decrypt(position, result.inner) for position, result in enumerate(results)
)
return decrypted if len(decrypted) != 1 else decrypted[0]
return self._function.simulate(*args)
@property
def keys(self) -> Keys:
"""
Get the keys of the circuit.
"""
if not hasattr(self, "client"): # pragma: no cover
self.enable_fhe_execution()
return self.client.keys
return self._module.keys
@keys.setter
def keys(self, new_keys: Keys):
"""
Set the keys of the circuit.
"""
if not hasattr(self, "client"): # pragma: no cover
self.enable_fhe_execution()
self.client.keys = new_keys
self._module.keys = new_keys
def keygen(
self, force: bool = False, seed: Optional[int] = None, encryption_seed: Optional[int] = None
@@ -227,11 +140,7 @@ class Circuit:
encryption_seed (Optional[int], default = None):
seed for encryption randomness
"""
if not hasattr(self, "client"): # pragma: no cover
self.enable_fhe_execution()
self.client.keygen(force, seed, encryption_seed)
self._module.keygen(force=force, seed=seed, encryption_seed=encryption_seed)
def encrypt(
self,
@@ -248,14 +157,7 @@ class Circuit:
Optional[Union[Value, Tuple[Optional[Value], ...]]]:
encrypted argument(s) for evaluation
"""
if self.configuration.simulate_encrypt_run_decrypt:
return args if len(args) != 1 else args[0] # type: ignore
if not hasattr(self, "client"): # pragma: no cover
self.enable_fhe_execution()
return self.client.encrypt(*args)
return self._function.encrypt(*args)
def run(
self,
@@ -273,14 +175,7 @@ class Circuit:
result(s) of evaluation
"""
if self.configuration.simulate_encrypt_run_decrypt:
return self.simulate(*args)
if not hasattr(self, "server"): # pragma: no cover
self.enable_fhe_execution()
self.keygen(force=False)
return self.server.run(*args, evaluation_keys=self.client.evaluation_keys)
return self._function.run(*args)
def decrypt(
self,
@@ -298,13 +193,7 @@ class Circuit:
decrypted result(s) of evaluation
"""
if self.configuration.simulate_encrypt_run_decrypt:
return results if len(results) != 1 else results[0] # type: ignore
if not hasattr(self, "client"): # pragma: no cover
self.enable_fhe_execution()
return self.client.decrypt(*results)
return self._function.decrypt(*results)
def encrypt_run_decrypt(self, *args: Any) -> Any:
"""
@@ -319,101 +208,81 @@ class Circuit:
clear result of homomorphic evaluation
"""
return self.decrypt(self.run(self.encrypt(*args)))
return self._function.encrypt_run_decrypt(*args)
def cleanup(self):
"""
Cleanup the temporary library output directory.
"""
if hasattr(self, "server"): # pragma: no cover
self.server.cleanup()
self._module.cleanup()
# Properties
def _property(self, name: str) -> Any:
"""
Get a property of the circuit by name.
Args:
name (str):
name of the property
Returns:
Any:
statistic
"""
if hasattr(self, "simulator"):
return getattr(self.simulator, name) # pragma: no cover
if not hasattr(self, "server"):
self.enable_fhe_execution() # pragma: no cover
return getattr(self.server, name)
@property
def size_of_secret_keys(self) -> int:
"""
Get size of the secret keys of the circuit.
"""
return self._property("size_of_secret_keys") # pragma: no cover
return self._module.size_of_secret_keys # pragma: no cover
@property
def size_of_bootstrap_keys(self) -> int:
"""
Get size of the bootstrap keys of the circuit.
"""
return self._property("size_of_bootstrap_keys") # pragma: no cover
return self._module.size_of_bootstrap_keys # pragma: no cover
@property
def size_of_keyswitch_keys(self) -> int:
"""
Get size of the key switch keys of the circuit.
"""
return self._property("size_of_keyswitch_keys") # pragma: no cover
return self._module.size_of_keyswitch_keys # pragma: no cover
@property
def size_of_inputs(self) -> int:
"""
Get size of the inputs of the circuit.
"""
return self._property("size_of_inputs")() # pragma: no cover
return self._function.size_of_inputs # pragma: no cover
@property
def size_of_outputs(self) -> int:
"""
Get size of the outputs of the circuit.
"""
return self._property("size_of_outputs")() # pragma: no cover
return self._function.size_of_outputs # pragma: no cover
@property
def p_error(self) -> int:
"""
Get probability of error for each simple TLU (on a scalar).
"""
return self._property("p_error") # pragma: no cover
return self._module.p_error # pragma: no cover
@property
def global_p_error(self) -> int:
"""
Get the probability of having at least one simple TLU error during the entire execution.
"""
return self._property("global_p_error") # pragma: no cover
return self._module.p_error # pragma: no cover
@property
def complexity(self) -> float:
"""
Get complexity of the circuit.
"""
return self._property("complexity") # pragma: no cover
return self._module.complexity # pragma: no cover
@property
def memory_usage_per_location(self) -> Dict[str, int]:
"""
Get the memory usage of operations in the circuit per location.
"""
return self._property("memory_usage_per_location")() # pragma: no cover
return self._function.execution_runtime.val.server.memory_usage_per_location(
self._name
) # pragma: no cover
# Programmable Bootstrap Statistics
@@ -422,30 +291,28 @@ class Circuit:
"""
Get the number of programmable bootstraps in the circuit.
"""
return self._property("programmable_bootstrap_count")() # pragma: no cover
return self._function.programmable_bootstrap_count # pragma: no cover
@property
def programmable_bootstrap_count_per_parameter(self) -> Dict[Parameter, int]:
"""
Get the number of programmable bootstraps per bit width in the circuit.
"""
return self._property("programmable_bootstrap_count_per_parameter")() # pragma: no cover
return self._function.programmable_bootstrap_count_per_parameter # pragma: no cover
@property
def programmable_bootstrap_count_per_tag(self) -> Dict[str, int]:
"""
Get the number of programmable bootstraps per tag in the circuit.
"""
return self._property("programmable_bootstrap_count_per_tag")() # pragma: no cover
return self._function.programmable_bootstrap_count_per_tag # pragma: no cover
@property
def programmable_bootstrap_count_per_tag_per_parameter(self) -> Dict[str, Dict[int, int]]:
"""
Get the number of programmable bootstraps per tag per bit width in the circuit.
"""
return self._property(
"programmable_bootstrap_count_per_tag_per_parameter"
)() # pragma: no cover
return self._function.programmable_bootstrap_count_per_tag_per_parameter # pragma: no cover
# Key Switch Statistics
@@ -454,28 +321,28 @@ class Circuit:
"""
Get the number of key switches in the circuit.
"""
return self._property("key_switch_count")() # pragma: no cover
return self._function.key_switch_count # pragma: no cover
@property
def key_switch_count_per_parameter(self) -> Dict[Parameter, int]:
"""
Get the number of key switches per parameter in the circuit.
"""
return self._property("key_switch_count_per_parameter")() # pragma: no cover
return self._function.key_switch_count_per_parameter # pragma: no cover
@property
def key_switch_count_per_tag(self) -> Dict[str, int]:
"""
Get the number of key switches per tag in the circuit.
"""
return self._property("key_switch_count_per_tag")() # pragma: no cover
return self._function.key_switch_count_per_tag # pragma: no cover
@property
def key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of key switches per tag per parameter in the circuit.
"""
return self._property("key_switch_count_per_tag_per_parameter")() # pragma: no cover
return self._function.key_switch_count_per_tag_per_parameter # pragma: no cover
# Packing Key Switch Statistics
@@ -484,30 +351,28 @@ class Circuit:
"""
Get the number of packing key switches in the circuit.
"""
return self._property("packing_key_switch_count")() # pragma: no cover
return self._function.packing_key_switch_count # pragma: no cover
@property
def packing_key_switch_count_per_parameter(self) -> Dict[Parameter, int]:
"""
Get the number of packing key switches per parameter in the circuit.
"""
return self._property("packing_key_switch_count_per_parameter")() # pragma: no cover
return self._function.packing_key_switch_count_per_parameter # pragma: no cover
@property
def packing_key_switch_count_per_tag(self) -> Dict[str, int]:
"""
Get the number of packing key switches per tag in the circuit.
"""
return self._property("packing_key_switch_count_per_tag")() # pragma: no cover
return self._function.packing_key_switch_count_per_tag # pragma: no cover
@property
def packing_key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of packing key switches per tag per parameter in the circuit.
"""
return self._property(
"packing_key_switch_count_per_tag_per_parameter"
)() # pragma: no cover
return self._function.packing_key_switch_count_per_tag_per_parameter # pragma: no cover
# Clear Addition Statistics
@@ -516,28 +381,28 @@ class Circuit:
"""
Get the number of clear additions in the circuit.
"""
return self._property("clear_addition_count")() # pragma: no cover
return self._function.clear_addition_count # pragma: no cover
@property
def clear_addition_count_per_parameter(self) -> Dict[Parameter, int]:
"""
Get the number of clear additions per parameter in the circuit.
"""
return self._property("clear_addition_count_per_parameter")() # pragma: no cover
return self._function.clear_addition_count_per_parameter # pragma: no cover
@property
def clear_addition_count_per_tag(self) -> Dict[str, int]:
"""
Get the number of clear additions per tag in the circuit.
"""
return self._property("clear_addition_count_per_tag")() # pragma: no cover
return self._function.clear_addition_count_per_tag # pragma: no cover
@property
def clear_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of clear additions per tag per parameter in the circuit.
"""
return self._property("clear_addition_count_per_tag_per_parameter")() # pragma: no cover
return self._function.clear_addition_count_per_tag_per_parameter # pragma: no cover
# Encrypted Addition Statistics
@@ -546,30 +411,28 @@ class Circuit:
"""
Get the number of encrypted additions in the circuit.
"""
return self._property("encrypted_addition_count")() # pragma: no cover
return self._function.encrypted_addition_count # pragma: no cover
@property
def encrypted_addition_count_per_parameter(self) -> Dict[Parameter, int]:
"""
Get the number of encrypted additions per parameter in the circuit.
"""
return self._property("encrypted_addition_count_per_parameter")() # pragma: no cover
return self._function.encrypted_addition_count_per_parameter # pragma: no cover
@property
def encrypted_addition_count_per_tag(self) -> Dict[str, int]:
"""
Get the number of encrypted additions per tag in the circuit.
"""
return self._property("encrypted_addition_count_per_tag")() # pragma: no cover
return self._function.encrypted_addition_count_per_tag # pragma: no cover
@property
def encrypted_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of encrypted additions per tag per parameter in the circuit.
"""
return self._property(
"encrypted_addition_count_per_tag_per_parameter"
)() # pragma: no cover
return self._function.encrypted_addition_count_per_tag_per_parameter # pragma: no cover
# Clear Multiplication Statistics
@@ -578,30 +441,28 @@ class Circuit:
"""
Get the number of clear multiplications in the circuit.
"""
return self._property("clear_multiplication_count")() # pragma: no cover
return self._function.clear_multiplication_count # pragma: no cover
@property
def clear_multiplication_count_per_parameter(self) -> Dict[Parameter, int]:
"""
Get the number of clear multiplications per parameter in the circuit.
"""
return self._property("clear_multiplication_count_per_parameter")() # pragma: no cover
return self._function.clear_multiplication_count_per_parameter # pragma: no cover
@property
def clear_multiplication_count_per_tag(self) -> Dict[str, int]:
"""
Get the number of clear multiplications per tag in the circuit.
"""
return self._property("clear_multiplication_count_per_tag")() # pragma: no cover
return self._function.clear_multiplication_count_per_tag # pragma: no cover
@property
def clear_multiplication_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of clear multiplications per tag per parameter in the circuit.
"""
return self._property(
"clear_multiplication_count_per_tag_per_parameter"
)() # pragma: no cover
return self._function.clear_multiplication_count_per_tag_per_parameter # pragma: no cover
# Encrypted Negation Statistics
@@ -610,36 +471,85 @@ class Circuit:
"""
Get the number of encrypted negations in the circuit.
"""
return self._property("encrypted_negation_count")() # pragma: no cover
return self._function.encrypted_negation_count # pragma: no cover
@property
def encrypted_negation_count_per_parameter(self) -> Dict[Parameter, int]:
"""
Get the number of encrypted negations per parameter in the circuit.
"""
return self._property("encrypted_negation_count_per_parameter")() # pragma: no cover
return self._function.encrypted_negation_count_per_parameter # pragma: no cover
@property
def encrypted_negation_count_per_tag(self) -> Dict[str, int]:
"""
Get the number of encrypted negations per tag in the circuit.
"""
return self._property("encrypted_negation_count_per_tag")() # pragma: no cover
return self._function.encrypted_negation_count_per_tag # pragma: no cover
@property
def encrypted_negation_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of encrypted negations per tag per parameter in the circuit.
"""
return self._property(
"encrypted_negation_count_per_tag_per_parameter"
)() # pragma: no cover
return self._function.encrypted_negation_count_per_tag_per_parameter # pragma: no cover
# All Statistics
@property
def statistics(self) -> Dict:
def statistics(self) -> Dict: # pragma: no cover
"""
Get all statistics of the circuit.
"""
return self._property("statistics") # pragma: no cover
mod_stats = self._module.statistics
func_stats = mod_stats.pop("functions")[self._name]
return mod_stats | func_stats
@property
def configuration(self) -> Configuration:
"""
Return the circuit configuration.
"""
return self._module.configuration
@property
def graph(self) -> Graph:
"""
Return the circuit graph.
"""
return self._function.graph
@property
def mlir_module(self) -> MlirModule:
"""
Return the circuit mlir module.
"""
return self._module.mlir_module
@property
def compilation_context(self) -> CompilationContext:
"""
Return the circuit compilation context.
"""
return self._module.compilation_context
@property
def client(self) -> Client:
"""
Return the circuit client.
"""
return self._module.client
@property
def server(self) -> Server:
"""
Return the circuit server.
"""
return self._module.server
@property
def simulator(self) -> Server:
"""
Return the circuit simulator.
"""
return self._module.simulator

View File

@@ -120,7 +120,7 @@ class Client:
def encrypt(
self,
*args: Optional[Union[int, np.ndarray, List]],
function_name: str = "main",
function_name: Optional[str] = None,
) -> Optional[Union[Value, Tuple[Optional[Value], ...]]]:
"""
Encrypt argument(s) to for evaluation.
@@ -136,6 +136,15 @@ class Client:
encrypted argument(s) for evaluation
"""
if function_name is None:
functions = self.specs.client_parameters.function_list()
if len(functions) == 1:
function_name = functions[0]
else: # pragma: no cover
msg = "The client contains more than one functions. \
Provide a `function_name` keyword argument to disambiguate."
raise TypeError(msg)
ordered_sanitized_args = validate_input_args(self.specs, *args, function_name=function_name)
self.keygen(force=False)
@@ -160,7 +169,7 @@ class Client:
def decrypt(
self,
*results: Union[Value, Tuple[Value, ...]],
function_name: str = "main",
function_name: Optional[str] = None,
) -> Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]:
"""
Decrypt result(s) of evaluation.
@@ -176,6 +185,15 @@ class Client:
decrypted result(s) of evaluation
"""
if function_name is None:
functions = self.specs.client_parameters.function_list()
if len(functions) == 1:
function_name = functions[0]
else: # pragma: no cover
msg = "The client contains more than one functions. \
Provide a `function_name` keyword argument to disambiguate."
raise TypeError(msg)
flattened_results: List[Value] = []
for result in results:
if isinstance(result, tuple): # pragma: no cover

View File

@@ -4,60 +4,31 @@ Declaration of `Compiler` class.
# pylint: disable=import-error,no-name-in-module
import inspect
import os
import traceback
from copy import deepcopy
from enum import Enum, unique
from itertools import product, repeat
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
import numpy as np
from concrete.compiler import CompilationContext
from ..extensions import AutoRounder, AutoTruncator
from ..mlir import GraphConverter
from ..representation import Graph
from ..tracing import Tracer
from ..values import ValueDescription
from .artifacts import DebugArtifacts
from .artifacts import DebugArtifacts, FunctionDebugArtifacts, ModuleDebugArtifacts
from .circuit import Circuit
from .composition import CompositionClause, CompositionRule
from .composition import CompositionPolicy
from .configuration import Configuration
from .utils import fuse, get_terminal_size
from .module_compiler import FunctionDef, ModuleCompiler
from .status import EncryptionStatus
from .wiring import AllComposable, NotComposable, TracedOutput
# pylint: enable=import-error,no-name-in-module
@unique
class EncryptionStatus(str, Enum):
"""
EncryptionStatus enum, to represent encryption status of parameters.
"""
CLEAR = "clear"
ENCRYPTED = "encrypted"
class Compiler:
"""
Compiler class, to glue the compilation pipeline.
"""
function: Callable
parameter_encryption_statuses: Dict[str, EncryptionStatus]
configuration: Configuration
artifacts: Optional[DebugArtifacts]
location: str
inputset: List[Any]
graph: Optional[Graph]
compilation_context: CompilationContext
_is_direct: bool
_parameter_values: Dict[str, ValueDescription]
_module_compiler: ModuleCompiler
_function_name: str
@staticmethod
def assemble(
@@ -97,81 +68,38 @@ class Compiler:
name: "encrypted" if value.is_encrypted else "clear"
for name, value in parameter_values.items()
},
composition=(
AllComposable()
if (configuration.composable if configuration is not None else False)
else NotComposable()
),
)
# pylint: disable=protected-access
compiler._is_direct = True
compiler._parameter_values = parameter_values
compiler._func_def._is_direct = True
compiler._func_def._parameter_values = parameter_values
# pylint: enable=protected-access
return compiler.compile(None, configuration, artifacts, **kwargs)
return compiler.compile([], configuration, artifacts, **kwargs)
def __init__(
self,
function: Callable,
parameter_encryption_statuses: Dict[str, Union[str, EncryptionStatus]],
composition: Optional[Union[NotComposable, AllComposable]] = None,
):
signature = inspect.signature(function)
missing_args = list(signature.parameters)
for arg in parameter_encryption_statuses.keys():
if arg in signature.parameters:
missing_args.remove(arg)
if len(missing_args) != 0:
parameter_str = repr(missing_args[0])
for arg in missing_args[1:-1]:
parameter_str += f", {repr(arg)}"
if len(missing_args) != 1:
parameter_str += f" and {repr(missing_args[-1])}"
message = (
f"Encryption status{'es' if len(missing_args) > 1 else ''} "
f"of parameter{'s' if len(missing_args) > 1 else ''} "
f"{parameter_str} of function '{function.__name__}' "
f"{'are' if len(missing_args) > 1 else 'is'} not provided"
)
raise ValueError(message)
additional_args = list(parameter_encryption_statuses)
for arg in signature.parameters.keys():
if arg in parameter_encryption_statuses:
additional_args.remove(arg)
if len(additional_args) != 0:
parameter_str = repr(additional_args[0])
for arg in additional_args[1:-1]:
parameter_str += f", {repr(arg)}"
if len(additional_args) != 1:
parameter_str += f" and {repr(additional_args[-1])}"
message = (
f"Encryption status{'es' if len(additional_args) > 1 else ''} "
f"of {parameter_str} {'are' if len(additional_args) > 1 else 'is'} provided but "
f"{'they are' if len(additional_args) > 1 else 'it is'} not a parameter "
f"of function '{function.__name__}'"
)
raise ValueError(message)
self.function = function # type: ignore
self.parameter_encryption_statuses = {
param: EncryptionStatus(status.lower())
for param, status in parameter_encryption_statuses.items()
}
self.configuration = Configuration()
self.artifacts = None
self.inputset = []
self.graph = None
self.compilation_context = CompilationContext.new()
self._is_direct = False
self._parameter_values = {}
self.location = (
f"{self.function.__code__.co_filename}:{self.function.__code__.co_firstlineno}"
if composition is None:
composition = NotComposable()
assert isinstance(composition, CompositionPolicy)
func = FunctionDef(
function=function, parameter_encryption_statuses=parameter_encryption_statuses
)
self._module_compiler = ModuleCompiler([func], composition)
self._function_name = function.__name__
@property
def _func_def(self) -> FunctionDef:
return getattr(self._module_compiler, self._function_name)
def __call__(
self,
@@ -182,138 +110,10 @@ class Compiler:
np.integer,
np.floating,
np.ndarray,
Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray], ...],
TracedOutput,
Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray, TracedOutput], ...],
]:
if len(kwargs) != 0:
message = f"Calling function '{self.function.__name__}' with kwargs is not supported"
raise RuntimeError(message)
sample = args[0] if len(args) == 1 else args
if self.graph is None:
self._trace(sample)
assert self.graph is not None
self.inputset.append(sample)
return self.graph(*args)
def _trace(self, sample: Union[Any, Tuple[Any, ...]]):
"""
Trace the function and fuse the resulting graph with a sample input.
Args:
sample (Union[Any, Tuple[Any, ...]]):
sample to use for tracing
"""
if self.artifacts is not None:
self.artifacts.add_source_code(self.function)
for param, encryption_status in self.parameter_encryption_statuses.items():
self.artifacts.add_parameter_encryption_status(param, encryption_status)
parameters = {
param: ValueDescription.of(arg, is_encrypted=(status == EncryptionStatus.ENCRYPTED))
for arg, (param, status) in zip(
(
sample
if len(self.parameter_encryption_statuses) > 1 or isinstance(sample, tuple)
else (sample,)
),
self.parameter_encryption_statuses.items(),
)
}
self.graph = Tracer.trace(self.function, parameters, location=self.location)
if self.artifacts is not None:
self.artifacts.add_graph("initial", self.graph)
fuse(
self.graph,
(self.artifacts.module_artifacts.functions["main"] if self.artifacts else None),
)
def _evaluate(
self,
action: str,
inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]],
):
"""
Trace, fuse, measure bounds, and update values in the resulting graph in one go.
Args:
action (str):
action being performed (e.g., "trace", "compile")
inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]):
optional inputset to extend accumulated inputset before bounds measurement
"""
if self._is_direct:
self.graph = Tracer.trace(
self.function, self._parameter_values, is_direct=True, location=self.location
)
if self.artifacts is not None:
self.artifacts.add_graph("initial", self.graph) # pragma: no cover
fuse(
self.graph,
(self.artifacts.module_artifacts.functions["main"] if self.artifacts else None),
)
if self.artifacts is not None:
self.artifacts.add_graph("final", self.graph) # pragma: no cover
return
if inputset is not None:
previous_inputset_length = len(self.inputset)
for index, sample in enumerate(iter(inputset)):
self.inputset.append(sample)
if not isinstance(sample, tuple):
sample = (sample,)
if len(sample) != len(self.parameter_encryption_statuses):
self.inputset = self.inputset[:previous_inputset_length]
expected = (
"a single value"
if len(self.parameter_encryption_statuses) == 1
else f"a tuple of {len(self.parameter_encryption_statuses)} values"
)
actual = (
"a single value" if len(sample) == 1 else f"a tuple of {len(sample)} values"
)
message = (
f"Input #{index} of your inputset is not well formed "
f"(expected {expected} got {actual})"
)
raise ValueError(message)
if self.configuration.auto_adjust_rounders:
AutoRounder.adjust(self.function, self.inputset)
if self.configuration.auto_adjust_truncators:
AutoTruncator.adjust(self.function, self.inputset)
if self.graph is None:
try:
first_sample = next(iter(self.inputset))
except StopIteration as error:
message = (
f"{action} function '{self.function.__name__}' "
f"without an inputset is not supported"
)
raise RuntimeError(message) from error
self._trace(first_sample)
assert self.graph is not None
bounds = self.graph.measure_bounds(self.inputset)
self.graph.update_with_bounds(bounds)
if self.artifacts is not None:
self.artifacts.add_graph("final", self.graph)
return self._func_def(*args, **kwargs)
def trace(
self,
@@ -342,78 +142,22 @@ class Compiler:
Graph:
computation graph representing the function prior to MLIR conversion
"""
old_configuration = deepcopy(self.configuration)
old_artifacts = deepcopy(self.artifacts)
if configuration is not None:
self.configuration = configuration
if len(kwargs) != 0:
self.configuration = self.configuration.fork(**kwargs)
self.artifacts = (
artifacts
art = (
artifacts.module_artifacts.functions.get(self._function_name, FunctionDebugArtifacts())
if artifacts is not None
else (
DebugArtifacts()
if self.configuration.dump_artifacts_on_unexpected_failures
else None
)
else FunctionDebugArtifacts()
)
conf = (
configuration
if configuration is not None
else self._module_compiler.default_configuration
)
if len(kwargs) != 0:
conf = conf.fork(**kwargs)
try:
self._evaluate("Tracing", inputset)
assert self.graph is not None
if self.configuration.verbose or self.configuration.show_graph:
graph = self.graph.format()
longest_line = max(len(line) for line in graph.split("\n"))
try: # pragma: no cover
# this branch cannot be covered
# because `os.get_terminal_size()`
# raises an exception during tests
columns, _ = os.get_terminal_size()
if columns == 0: # noqa: SIM108
columns = min(longest_line, 80)
else:
columns = min(longest_line, columns)
except OSError: # pragma: no cover
columns = min(longest_line, 80)
print()
print("Computation Graph")
print("-" * columns)
print(graph)
print("-" * columns)
print()
return self.graph
except Exception: # pragma: no cover
# this branch is reserved for unexpected issues and hence it shouldn't be tested
# if it could be tested, we would have fixed the underlying issue
# if the user desires so,
# we need to export all the information we have about the compilation
if self.configuration.dump_artifacts_on_unexpected_failures:
assert self.artifacts is not None
self.artifacts.export()
traceback_path = self.artifacts.output_directory.joinpath("traceback.txt")
with open(traceback_path, "w", encoding="utf-8") as f:
f.write(traceback.format_exc())
raise
finally:
self.configuration = old_configuration
self.artifacts = old_artifacts
self._func_def.evaluate("Tracing", inputset, conf, art)
assert self._func_def.graph is not None
return self._func_def.graph
# pylint: disable=too-many-branches,too-many-statements
@@ -445,240 +189,21 @@ class Compiler:
compiled circuit
"""
old_configuration = deepcopy(self.configuration)
old_artifacts = deepcopy(self.artifacts)
if configuration is not None:
self.configuration = configuration
if len(kwargs) != 0:
self.configuration = self.configuration.fork(**kwargs)
self.artifacts = (
artifacts
if artifacts is not None
else (
DebugArtifacts()
if self.configuration.dump_artifacts_on_unexpected_failures
else None
)
art = artifacts.module_artifacts if artifacts is not None else ModuleDebugArtifacts()
conf = (
configuration
if configuration is not None
else self._module_compiler.default_configuration
)
if len(kwargs) != 0:
conf = conf.fork(**kwargs)
try:
self._evaluate("Compiling", inputset)
assert self.graph is not None
show_graph = (
self.configuration.show_graph
if self.configuration.show_graph is not None
else self.configuration.verbose
)
show_bit_width_constraints = (
self.configuration.show_bit_width_constraints
if self.configuration.show_bit_width_constraints is not None
else self.configuration.verbose
)
show_bit_width_assignments = (
self.configuration.show_bit_width_assignments
if self.configuration.show_bit_width_assignments is not None
else self.configuration.verbose
)
show_assigned_graph = (
self.configuration.show_assigned_graph
if self.configuration.show_assigned_graph is not None
else self.configuration.verbose
)
show_mlir = (
self.configuration.show_mlir
if self.configuration.show_mlir is not None
else self.configuration.verbose
)
show_optimizer = (
self.configuration.show_optimizer
if self.configuration.show_optimizer is not None
else self.configuration.verbose
)
show_statistics = (
self.configuration.show_statistics
if self.configuration.show_statistics is not None
else self.configuration.verbose
)
columns = get_terminal_size()
is_first = True
if (
show_graph
or show_bit_width_constraints
or show_bit_width_assignments
or show_assigned_graph
or show_mlir
or show_optimizer
or show_statistics
):
if show_graph:
if is_first: # pragma: no cover
print()
is_first = False
print("Computation Graph")
print("-" * columns)
print(self.graph.format())
print("-" * columns)
print()
# We generate the composition rules if needed:
composition_rules = []
if self.configuration.composable:
compo_froms = map(
CompositionClause.create,
zip(repeat(self.graph.name), range(len(self.graph.output_nodes))),
)
compo_tos = map(
CompositionClause.create,
zip(repeat(self.graph.name), range(len(self.graph.input_nodes))),
)
composition_rules = list(
map(CompositionRule.create, product(compo_froms, compo_tos))
)
# in-memory MLIR module
mlir_context = self.compilation_context.mlir_context()
mlir_module = GraphConverter(self.configuration, composition_rules).convert(
self.graph, mlir_context
)
# textual representation of the MLIR module
mlir_str = str(mlir_module).strip()
if self.artifacts is not None:
self.artifacts.add_mlir_to_compile(mlir_str)
if show_bit_width_constraints:
if is_first: # pragma: no cover
print()
is_first = False
print("Bit-Width Constraints")
print("-" * columns)
print(self.graph.format_bit_width_constraints())
print("-" * columns)
print()
if show_bit_width_assignments:
if is_first: # pragma: no cover
print()
is_first = False
print("Bit-Width Assignments")
print("-" * columns)
print(self.graph.format_bit_width_assignments())
print("-" * columns)
print()
if show_assigned_graph:
if is_first: # pragma: no cover
print()
is_first = False
print("Bit-Width Assigned Computation Graph")
print("-" * columns)
print(self.graph.format(show_assigned_bit_widths=True))
print("-" * columns)
print()
if show_mlir:
if is_first: # pragma: no cover
print()
is_first = False
print("MLIR")
print("-" * columns)
print(mlir_str)
print("-" * columns)
print()
if show_optimizer:
if is_first: # pragma: no cover
print()
is_first = False
print("Optimizer")
print("-" * columns)
circuit = Circuit(
self.graph,
mlir_module,
self.compilation_context,
self.configuration,
composition_rules,
)
if hasattr(circuit, "client"):
client_parameters = circuit.client.specs.client_parameters
if self.artifacts is not None:
self.artifacts.add_client_parameters(client_parameters.serialize())
if show_optimizer:
print("-" * columns)
print()
if show_statistics:
if is_first: # pragma: no cover
print()
print("Statistics")
print("-" * columns)
def pretty(d, indent=0): # pragma: no cover
if indent > 0:
print("{")
for key, value in d.items():
if isinstance(value, dict) and len(value) == 0:
continue
print(" " * indent + str(key) + ": ", end="")
if isinstance(value, dict):
pretty(value, indent + 1)
else:
print(value)
if indent > 0:
print(" " * (indent - 1) + "}")
pretty(circuit.statistics)
print("-" * columns)
print()
except Exception: # pragma: no cover
# this branch is reserved for unexpected issues and hence it shouldn't be tested
# if it could be tested, we would have fixed the underlying issue
# if the user desires so,
# we need to export all the information we have about the compilation
if self.configuration.dump_artifacts_on_unexpected_failures:
assert self.artifacts is not None
self.artifacts.export()
traceback_path = self.artifacts.output_directory.joinpath("traceback.txt")
with open(traceback_path, "w", encoding="utf-8") as f:
f.write(traceback.format_exc())
raise
finally:
self.configuration = old_configuration
self.artifacts = old_artifacts
return circuit
if conf.composable:
self._module_compiler.composition = AllComposable()
fhe_module = self._module_compiler.compile(
{self._function_name: inputset}, configuration=conf, module_artifacts=art
)
return Circuit(fhe_module)
# pylint: enable=too-many-branches,too-many-statements
@@ -686,5 +211,6 @@ class Compiler:
"""
Reset the compiler so that another compilation with another inputset can be performed.
"""
fresh_compiler = Compiler(self.function, self.parameter_encryption_statuses)
fdef = self._module_compiler.functions[self._function_name]
fresh_compiler = Compiler(fdef.function, fdef.parameter_encryption_statuses)
self.__dict__.update(fresh_compiler.__dict__)

View File

@@ -12,9 +12,11 @@ from ..tracing.typing import ScalarAnnotation
from ..values import ValueDescription
from .artifacts import DebugArtifacts
from .circuit import Circuit
from .compiler import Compiler, EncryptionStatus
from .compiler import Compiler
from .configuration import Configuration
from .module_compiler import AllComposable, CompositionPolicy, FunctionDef, ModuleCompiler
from .module_compiler import CompositionPolicy, FunctionDef, ModuleCompiler
from .status import EncryptionStatus
from .wiring import AllComposable
def circuit(
@@ -151,7 +153,9 @@ class Compilable:
compiled circuit
"""
return self.compiler.compile(inputset, configuration, artifacts, **kwargs)
return self.compiler.compile(
inputset if inputset is not None else [], configuration, artifacts, **kwargs
)
def reset(self):
"""

View File

@@ -166,12 +166,10 @@ class FheFunction:
decrypted = tuple(
decrypter.decrypt(position, result.inner) for position, result in enumerate(results)
)
return decrypted if len(decrypted) != 1 else decrypted[0]
def encrypt(
self,
*args: Optional[Union[int, np.ndarray, List]],
self, *args: Optional[Union[int, np.ndarray, List]]
) -> Optional[Union[Value, Tuple[Optional[Value], ...]]]:
"""
Encrypt argument(s) to for evaluation.
@@ -693,6 +691,7 @@ class FheModule:
"""
Get size of the key switch keys of the module.
"""
return self.execution_runtime.val.server.size_of_keyswitch_keys # pragma: no cover
@property
@@ -758,12 +757,26 @@ class FheModule:
return self.execution_runtime.val.server
@property
def client(self) -> Optional[Client]:
def client(self) -> Client:
"""
Returns the execution client object tied to the module.
"""
return self.execution_runtime.val.client
@property
def simulator(self) -> Server:
"""
Returns the simulation server object tied to the module.
"""
return self.simulation_runtime.val.server
@property
def function_count(self) -> int:
"""
Returns the number of functions in the module.
"""
return len(self.graphs)
def __getattr__(self, item):
if item not in list(self.graphs.keys()):
error = f"No attribute {item}"

View File

@@ -6,23 +6,8 @@ Declaration of `MultiCompiler` class.
import inspect
import traceback
from copy import deepcopy
from itertools import chain, product, repeat
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
NamedTuple,
Optional,
Protocol,
Set,
Tuple,
Union,
runtime_checkable,
)
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
import numpy as np
from concrete.compiler import CompilationContext
@@ -32,12 +17,13 @@ from ..mlir import GraphConverter
from ..representation import Graph
from ..tracing import Tracer
from ..values import ValueDescription
from .artifacts import FunctionDebugArtifacts, ModuleDebugArtifacts
from .compiler import EncryptionStatus
from .composition import CompositionClause, CompositionPolicy, CompositionRule
from .artifacts import DebugManager, FunctionDebugArtifacts, ModuleDebugArtifacts
from .composition import CompositionPolicy
from .configuration import Configuration
from .module import FheModule
from .utils import fuse, get_terminal_size
from .status import EncryptionStatus
from .utils import fuse
from .wiring import Input, Output, TracedOutput, Wire, Wired, WireTracingContextManager
DEFAULT_OUTPUT_DIRECTORY: Path = Path(".artifacts")
@@ -49,13 +35,14 @@ class FunctionDef:
An object representing the definition of a function as used in an fhe module.
"""
name: str
function: Callable
parameter_encryption_statuses: Dict[str, EncryptionStatus]
inputset: List[Any]
graph: Optional[Graph]
_parameter_values: Dict[str, ValueDescription]
location: str
_is_direct: bool
_parameter_values: Dict[str, ValueDescription]
_trace_wires: Optional[Set["Wire"]]
def __init__(
@@ -112,13 +99,18 @@ class FunctionDef:
}
self.inputset = []
self.graph = None
self.name = function.__name__
self._is_direct = False
self._parameter_values = {}
self.location = (
f"{self.function.__code__.co_filename}:{self.function.__code__.co_firstlineno}"
)
self._trace_wires = None
@property
def name(self) -> str:
"""Return the name of the function."""
return self.function.__name__
def trace(
self,
sample: Union[Any, Tuple[Any, ...]],
@@ -151,7 +143,7 @@ class FunctionDef:
)
}
self.graph = Tracer.trace(self.function, parameters, name=self.name, location=self.location)
self.graph = Tracer.trace(self.function, parameters, location=self.location)
if artifacts is not None:
artifacts.add_graph("initial", self.graph)
@@ -181,6 +173,21 @@ class FunctionDef:
artifact object to store informations in
"""
if self._is_direct:
self.graph = Tracer.trace(
self.function,
self._parameter_values,
is_direct=True,
location=self.location,
)
artifacts.add_graph("initial", self.graph) # pragma: no cover
fuse(
self.graph,
artifacts,
)
artifacts.add_graph("final", self.graph) # pragma: no cover
return
if inputset is not None:
previous_inputset_length = len(self.inputset)
for index, sample in enumerate(iter(inputset)):
@@ -320,424 +327,6 @@ class FunctionDef:
return output[0] if len(output) == 1 else output
class NotComposable:
"""
Composition policy that does not allow the forwarding of any output to any input.
"""
def get_rules_iter(self, _) -> Iterable[CompositionRule]:
"""
Return an iterator over composition rules.
"""
return [] # pragma: no cover
class AllComposable:
"""
Composition policy that allows to forward any output of the module to any of its input.
"""
def get_rules_iter(self, funcs: List[Graph]) -> Iterable[CompositionRule]:
"""
Return an iterator over composition rules.
"""
outputs = []
for f in funcs:
for pos, node in f.output_nodes.items():
if node.output.is_encrypted:
outputs.append(CompositionClause.create((f.name, pos)))
inputs = []
for f in funcs:
for pos, node in f.input_nodes.items():
if node.output.is_encrypted:
inputs.append(CompositionClause.create((f.name, pos)))
return map(CompositionRule.create, product(outputs, inputs))
@runtime_checkable
class WireOutput(Protocol):
"""
A protocol for wire outputs.
"""
def get_outputs_iter(self) -> Iterable[CompositionClause]:
"""
Return an iterator over the possible outputs of the wire output.
"""
@runtime_checkable
class WireInput(Protocol):
"""
A protocol for wire inputs.
"""
def get_inputs_iter(self) -> Iterable[CompositionClause]:
"""
Return an iterator over the possible inputs of the wire input.
"""
class Output(NamedTuple):
"""
The output of a given function of a module.
"""
func: FunctionDef
pos: int
def get_outputs_iter(self) -> Iterable[CompositionClause]:
"""
Return an iterator over the possible outputs of the wire output.
"""
return [CompositionClause(self.func.name, self.pos)]
class AllOutputs(NamedTuple):
"""
All the encrypted outputs of a given function of a module.
"""
func: FunctionDef
def get_outputs_iter(self) -> Iterable[CompositionClause]:
"""
Return an iterator over the possible outputs of the wire output.
"""
assert self.func.graph # pragma: no cover
# No need to filter since only encrypted outputs are valid.
return map( # pragma: no cover
CompositionClause.create,
zip(repeat(self.func.name), range(self.func.graph.outputs_count)),
)
class Input(NamedTuple):
"""
The input of a given function of a module.
"""
func: FunctionDef
pos: int
def get_inputs_iter(self) -> Iterable[CompositionClause]:
"""
Return an iterator over the possible inputs of the wire input.
"""
return [CompositionClause(self.func.name, self.pos)]
class AllInputs(NamedTuple):
"""
All the encrypted inputs of a given function of a module.
"""
func: FunctionDef
def get_inputs_iter(self) -> Iterable[CompositionClause]:
"""
Return an iterator over the possible inputs of the wire input.
"""
assert self.func.graph # pragma: no cover
output = []
for i in range(self.func.graph.inputs_count):
if self.func.graph.input_nodes[i].output.is_encrypted:
output.append(CompositionClause.create((self.func.name, i)))
return output
class Wire(NamedTuple):
"""
A forwarding rule between an output and an input.
"""
output: WireOutput
input: WireInput
def get_rules_iter(self, _) -> Iterable[CompositionRule]:
"""
Return an iterator over composition rules.
"""
return map(
CompositionRule.create,
product(self.output.get_outputs_iter(), self.input.get_inputs_iter()),
)
class Wired:
"""
Composition policy which allows the forwarding of certain outputs to certain inputs.
"""
wires: Set[Wire]
def __init__(self, wires: Optional[Set[Wire]] = None):
self.wires = wires if wires else set()
def get_rules_iter(self, funcs: List[Graph]) -> Iterable[CompositionRule]:
"""
Return an iterator over composition rules.
"""
funcsd = {f.name: f for f in funcs}
rules = list(chain(*[w.get_rules_iter(funcs) for w in self.wires]))
# We check that the given rules are legit (they concern only encrypted values)
for rule in rules:
if (
not funcsd[rule.from_.func].output_nodes[rule.from_.pos].output.is_encrypted
): # pragma: no cover
message = f"Invalid composition rule encountered: \
Output {rule.from_.pos} of {rule.from_.func} is not encrypted"
raise RuntimeError(message)
if not funcsd[rule.to.func].input_nodes[rule.to.pos].output.is_encrypted:
message = f"Invalid composition rule encountered: \
Input {rule.from_.pos} of {rule.from_.func} is not encrypted"
raise RuntimeError(message)
return rules
class DebugManager:
"""
A debug manager, allowing streamlined debugging.
"""
configuration: Configuration
begin_call: Callable
def __init__(self, config: Configuration):
self.configuration = config
is_first = [True]
def begin_call():
if is_first[0]:
print()
is_first[0] = False
self.begin_call = begin_call
def debug_table(self, title: str, activate: bool = True):
"""
Return a context manager that prints a table around what is printed inside the scope.
"""
# pylint: disable=missing-class-docstring
class DebugTableCm:
def __init__(self, title):
self.title = title
self.columns = get_terminal_size()
def __enter__(self):
print(f"{self.title}")
print("-" * self.columns)
def __exit__(self, _exc_type, _exc_value, _exc_tb):
print("-" * self.columns)
print()
class EmptyCm:
def __enter__(self):
pass
def __exit__(self, _exc_type, _exc_value, _exc_tb):
pass
if activate:
self.begin_call()
return DebugTableCm(title)
return EmptyCm()
def show_graph(self) -> bool:
"""
Tell if the configuration involves showing graph.
"""
return (
self.configuration.show_graph
if self.configuration.show_graph is not None
else self.configuration.verbose
)
def show_bit_width_constraints(self) -> bool:
"""
Tell if the configuration involves showing bitwidth constraints.
"""
return (
self.configuration.show_bit_width_constraints
if self.configuration.show_bit_width_constraints is not None
else self.configuration.verbose
)
def show_bit_width_assignments(self) -> bool:
"""
Tell if the configuration involves showing bitwidth assignments.
"""
return (
self.configuration.show_bit_width_assignments
if self.configuration.show_bit_width_assignments is not None
else self.configuration.verbose
)
def show_assigned_graph(self) -> bool:
"""
Tell if the configuration involves showing assigned graph.
"""
return (
self.configuration.show_assigned_graph
if self.configuration.show_assigned_graph is not None
else self.configuration.verbose
)
def show_mlir(self) -> bool:
"""
Tell if the configuration involves showing mlir.
"""
return (
self.configuration.show_mlir
if self.configuration.show_mlir is not None
else self.configuration.verbose
)
def show_optimizer(self) -> bool:
"""
Tell if the configuration involves showing optimizer.
"""
return (
self.configuration.show_optimizer
if self.configuration.show_optimizer is not None
else self.configuration.verbose
)
def show_statistics(self) -> bool:
"""
Tell if the configuration involves showing statistics.
"""
return (
self.configuration.show_statistics
if self.configuration.show_statistics is not None
else self.configuration.verbose
)
def debug_computation_graph(self, name, function_graph):
"""
Print computation graph if configuration tells so.
"""
if (
self.show_graph()
or self.show_bit_width_constraints()
or self.show_bit_width_assignments()
or self.show_assigned_graph()
or self.show_mlir()
or self.show_optimizer()
or self.show_statistics()
):
if self.show_graph():
with self.debug_table(f"Computation Graph for {name}"):
print(function_graph.format())
def debug_bit_width_constaints(self, name, function_graph):
"""
Print bitwidth constraints if configuration tells so.
"""
if self.show_bit_width_constraints():
with self.debug_table(f"Bit-Width Constraints for {name}"):
print(function_graph.format_bit_width_constraints())
def debug_bit_width_assignments(self, name, function_graph):
"""
Print bitwidth assignments if configuration tells so.
"""
if self.show_bit_width_assignments():
with self.debug_table(f"Bit-Width Assignments for {name}"):
print(function_graph.format_bit_width_assignments())
def debug_assigned_graph(self, name, function_graph):
"""
Print assigned graphs if configuration tells so.
"""
if self.show_assigned_graph():
with self.debug_table(f"Bit-Width Assigned Computation Graph for {name}"):
print(function_graph.format(show_assigned_bit_widths=True))
def debug_mlir(self, mlir_str):
"""
Print mlir if configuration tells so.
"""
if self.show_mlir():
with self.debug_table("MLIR"):
print(mlir_str)
def debug_statistics(self, module):
"""
Print statistics if configuration tells so.
"""
if self.show_statistics():
def pretty(d, indent=0): # pragma: no cover
if indent > 0:
print("{")
for key, value in d.items():
if isinstance(value, dict) and len(value) == 0:
continue
print(" " * indent + str(key) + ": ", end="")
if isinstance(value, dict):
pretty(value, indent + 1)
else:
print(value)
if indent > 0:
print(" " * (indent - 1) + "}")
with self.debug_table("Statistics"):
pretty(module.statistics)
class TracedOutput(NamedTuple):
"""
A wrapper type used to trace wiring.
Allows to tag an output value coming from an other module function, and binds it with
information about its origin.
"""
output_info: Output
returned_value: Any
class WireTracingContextManager:
"""
A context manager returned by the `wire_pipeline` method.
Activates wire tracing and yields an inputset that can be iterated on for tracing.
"""
def __init__(self, module, inputset):
self.module = module
self.inputset = inputset
def __enter__(self):
for func in self.module.functions.values():
func._trace_wires = self.module.composition.wires
return self.inputset
def __exit__(self, _exc_type, _exc_value, _exc_tb):
for func in self.module.functions.values():
func._trace_wires = None
class ModuleCompiler:
"""
Compiler class for multiple functions, to glue the compilation pipeline.
@@ -766,7 +355,9 @@ class ModuleCompiler:
def compile(
self,
inputsets: Optional[Dict[str, Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]] = None,
inputsets: Optional[
Dict[str, Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]]
] = None,
configuration: Optional[Configuration] = None,
module_artifacts: Optional[ModuleDebugArtifacts] = None,
**kwargs,
@@ -793,7 +384,6 @@ class ModuleCompiler:
"""
configuration = configuration if configuration is not None else self.default_configuration
configuration = deepcopy(configuration)
if len(kwargs) != 0:
configuration = configuration.fork(**kwargs)

View File

@@ -401,7 +401,7 @@ class Server:
self,
*args: Optional[Union[Value, Tuple[Optional[Value], ...]]],
evaluation_keys: Optional[EvaluationKeys] = None,
function_name: str = "main",
function_name: Optional[str] = None,
) -> Union[Value, Tuple[Value, ...]]:
"""
Evaluate.
@@ -421,6 +421,15 @@ class Server:
result(s) of evaluation
"""
if function_name is None:
functions = self.client_specs.client_parameters.function_list()
if len(functions) == 1:
function_name = functions[0]
else: # pragma: no cover
msg = "The client contains more than one functions. \
Provide a `function_name` keyword argument to disambiguate."
raise TypeError(msg)
if evaluation_keys is None and not self.is_simulated:
message = "Expected evaluation keys to be provided when not in simulation mode"
raise RuntimeError(message)
@@ -527,19 +536,19 @@ class Server:
"""
return self._compilation_feedback.complexity
def memory_usage_per_location(self, function: str = "main") -> Dict[str, int]:
def memory_usage_per_location(self, function: str) -> Dict[str, int]:
"""
Get the memory usage of operations per location.
"""
return self._compilation_feedback.circuit(function).memory_usage_per_location
def size_of_inputs(self, function: str = "main") -> int:
def size_of_inputs(self, function: str) -> int:
"""
Get size of the inputs of the compiled program.
"""
return self._compilation_feedback.circuit(function).total_inputs_size
def size_of_outputs(self, function: str = "main") -> int:
def size_of_outputs(self, function: str) -> int:
"""
Get size of the outputs of the compiled program.
"""
@@ -547,7 +556,7 @@ class Server:
# Programmable Bootstrap Statistics
def programmable_bootstrap_count(self, function: str = "main") -> int:
def programmable_bootstrap_count(self, function: str) -> int:
"""
Get the number of programmable bootstraps in the compiled program.
"""
@@ -555,9 +564,7 @@ class Server:
operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS},
)
def programmable_bootstrap_count_per_parameter(
self, function: str = "main"
) -> Dict[Parameter, int]:
def programmable_bootstrap_count_per_parameter(self, function: str) -> Dict[Parameter, int]:
"""
Get the number of programmable bootstraps per parameter in the compiled program.
"""
@@ -567,7 +574,7 @@ class Server:
client_parameters=self.client_specs.client_parameters,
)
def programmable_bootstrap_count_per_tag(self, function: str = "main") -> Dict[str, int]:
def programmable_bootstrap_count_per_tag(self, function: str) -> Dict[str, int]:
"""
Get the number of programmable bootstraps per tag in the compiled program.
"""
@@ -576,7 +583,7 @@ class Server:
)
def programmable_bootstrap_count_per_tag_per_parameter(
self, function: str = "main"
self, function: str
) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of programmable bootstraps per tag per parameter in the compiled program.
@@ -589,7 +596,7 @@ class Server:
# Key Switch Statistics
def key_switch_count(self, function: str = "main") -> int:
def key_switch_count(self, function: str) -> int:
"""
Get the number of key switches in the compiled program.
"""
@@ -597,7 +604,7 @@ class Server:
operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS},
)
def key_switch_count_per_parameter(self, function: str = "main") -> Dict[Parameter, int]:
def key_switch_count_per_parameter(self, function: str) -> Dict[Parameter, int]:
"""
Get the number of key switches per parameter in the compiled program.
"""
@@ -607,7 +614,7 @@ class Server:
client_parameters=self.client_specs.client_parameters,
)
def key_switch_count_per_tag(self, function: str = "main") -> Dict[str, int]:
def key_switch_count_per_tag(self, function: str) -> Dict[str, int]:
"""
Get the number of key switches per tag in the compiled program.
"""
@@ -616,7 +623,7 @@ class Server:
)
def key_switch_count_per_tag_per_parameter(
self, function: str = "main"
self, function: str
) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of key switches per tag per parameter in the compiled program.
@@ -629,7 +636,7 @@ class Server:
# Packing Key Switch Statistics
def packing_key_switch_count(self, function: str = "main") -> int:
def packing_key_switch_count(self, function: str) -> int:
"""
Get the number of packing key switches in the compiled program.
"""
@@ -637,9 +644,7 @@ class Server:
operations={PrimitiveOperation.WOP_PBS}
)
def packing_key_switch_count_per_parameter(
self, function: str = "main"
) -> Dict[Parameter, int]:
def packing_key_switch_count_per_parameter(self, function: str) -> Dict[Parameter, int]:
"""
Get the number of packing key switches per parameter in the compiled program.
"""
@@ -649,7 +654,7 @@ class Server:
client_parameters=self.client_specs.client_parameters,
)
def packing_key_switch_count_per_tag(self, function: str = "main") -> Dict[str, int]:
def packing_key_switch_count_per_tag(self, function: str) -> Dict[str, int]:
"""
Get the number of packing key switches per tag in the compiled program.
"""
@@ -658,7 +663,7 @@ class Server:
)
def packing_key_switch_count_per_tag_per_parameter(
self, function: str = "main"
self, function: str
) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of packing key switches per tag per parameter in the compiled program.
@@ -671,7 +676,7 @@ class Server:
# Clear Addition Statistics
def clear_addition_count(self, function: str = "main") -> int:
def clear_addition_count(self, function: str) -> int:
"""
Get the number of clear additions in the compiled program.
"""
@@ -679,7 +684,7 @@ class Server:
operations={PrimitiveOperation.CLEAR_ADDITION}
)
def clear_addition_count_per_parameter(self, function: str = "main") -> Dict[Parameter, int]:
def clear_addition_count_per_parameter(self, function: str) -> Dict[Parameter, int]:
"""
Get the number of clear additions per parameter in the compiled program.
"""
@@ -689,7 +694,7 @@ class Server:
client_parameters=self.client_specs.client_parameters,
)
def clear_addition_count_per_tag(self, function: str = "main") -> Dict[str, int]:
def clear_addition_count_per_tag(self, function: str) -> Dict[str, int]:
"""
Get the number of clear additions per tag in the compiled program.
"""
@@ -698,7 +703,7 @@ class Server:
)
def clear_addition_count_per_tag_per_parameter(
self, function: str = "main"
self, function: str
) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of clear additions per tag per parameter in the compiled program.
@@ -711,7 +716,7 @@ class Server:
# Encrypted Addition Statistics
def encrypted_addition_count(self, function: str = "main") -> int:
def encrypted_addition_count(self, function: str) -> int:
"""
Get the number of encrypted additions in the compiled program.
"""
@@ -719,9 +724,7 @@ class Server:
operations={PrimitiveOperation.ENCRYPTED_ADDITION}
)
def encrypted_addition_count_per_parameter(
self, function: str = "main"
) -> Dict[Parameter, int]:
def encrypted_addition_count_per_parameter(self, function: str) -> Dict[Parameter, int]:
"""
Get the number of encrypted additions per parameter in the compiled program.
"""
@@ -731,7 +734,7 @@ class Server:
client_parameters=self.client_specs.client_parameters,
)
def encrypted_addition_count_per_tag(self, function: str = "main") -> Dict[str, int]:
def encrypted_addition_count_per_tag(self, function: str) -> Dict[str, int]:
"""
Get the number of encrypted additions per tag in the compiled program.
"""
@@ -740,7 +743,7 @@ class Server:
)
def encrypted_addition_count_per_tag_per_parameter(
self, function: str = "main"
self, function: str
) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of encrypted additions per tag per parameter in the compiled program.
@@ -753,7 +756,7 @@ class Server:
# Clear Multiplication Statistics
def clear_multiplication_count(self, function: str = "main") -> int:
def clear_multiplication_count(self, function: str) -> int:
"""
Get the number of clear multiplications in the compiled program.
"""
@@ -761,9 +764,7 @@ class Server:
operations={PrimitiveOperation.CLEAR_MULTIPLICATION},
)
def clear_multiplication_count_per_parameter(
self, function: str = "main"
) -> Dict[Parameter, int]:
def clear_multiplication_count_per_parameter(self, function: str) -> Dict[Parameter, int]:
"""
Get the number of clear multiplications per parameter in the compiled program.
"""
@@ -773,7 +774,7 @@ class Server:
client_parameters=self.client_specs.client_parameters,
)
def clear_multiplication_count_per_tag(self, function: str = "main") -> Dict[str, int]:
def clear_multiplication_count_per_tag(self, function: str) -> Dict[str, int]:
"""
Get the number of clear multiplications per tag in the compiled program.
"""
@@ -782,7 +783,7 @@ class Server:
)
def clear_multiplication_count_per_tag_per_parameter(
self, function: str = "main"
self, function: str
) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of clear multiplications per tag per parameter in the compiled program.
@@ -795,7 +796,7 @@ class Server:
# Encrypted Negation Statistics
def encrypted_negation_count(self, function: str = "main") -> int:
def encrypted_negation_count(self, function: str) -> int:
"""
Get the number of encrypted negations in the compiled program.
"""
@@ -803,9 +804,7 @@ class Server:
operations={PrimitiveOperation.ENCRYPTED_NEGATION}
)
def encrypted_negation_count_per_parameter(
self, function: str = "main"
) -> Dict[Parameter, int]:
def encrypted_negation_count_per_parameter(self, function: str) -> Dict[Parameter, int]:
"""
Get the number of encrypted negations per parameter in the compiled program.
"""
@@ -815,7 +814,7 @@ class Server:
client_parameters=self.client_specs.client_parameters,
)
def encrypted_negation_count_per_tag(self, function: str = "main") -> Dict[str, int]:
def encrypted_negation_count_per_tag(self, function: str) -> Dict[str, int]:
"""
Get the number of encrypted negations per tag in the compiled program.
"""
@@ -824,7 +823,7 @@ class Server:
)
def encrypted_negation_count_per_tag_per_parameter(
self, function: str = "main"
self, function: str
) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of encrypted negations per tag per parameter in the compiled program.
@@ -834,52 +833,3 @@ class Server:
key_types={KeyType.SECRET},
client_parameters=self.client_specs.client_parameters,
)
# All Statistics
@property
def statistics(self) -> Dict:
"""
Get all statistics of the compiled program.
"""
attributes = [
"size_of_inputs",
"size_of_outputs",
"programmable_bootstrap_count",
"programmable_bootstrap_count_per_parameter",
"programmable_bootstrap_count_per_tag",
"programmable_bootstrap_count_per_tag_per_parameter",
"key_switch_count",
"key_switch_count_per_parameter",
"key_switch_count_per_tag",
"key_switch_count_per_tag_per_parameter",
"packing_key_switch_count",
"packing_key_switch_count_per_parameter",
"packing_key_switch_count_per_tag",
"packing_key_switch_count_per_tag_per_parameter",
"clear_addition_count",
"clear_addition_count_per_parameter",
"clear_addition_count_per_tag",
"clear_addition_count_per_tag_per_parameter",
"encrypted_addition_count",
"encrypted_addition_count_per_parameter",
"encrypted_addition_count_per_tag",
"encrypted_addition_count_per_tag_per_parameter",
"clear_multiplication_count",
"clear_multiplication_count_per_parameter",
"clear_multiplication_count_per_tag",
"clear_multiplication_count_per_tag_per_parameter",
"encrypted_negation_count",
"encrypted_negation_count_per_parameter",
"encrypted_negation_count_per_tag",
"encrypted_negation_count_per_tag_per_parameter",
"memory_usage_per_location",
]
output = {attribute: getattr(self, attribute)() for attribute in attributes}
output["size_of_secret_keys"] = self.size_of_secret_keys
output["size_of_bootstrap_keys"] = self.size_of_bootstrap_keys
output["size_of_keyswitch_keys"] = self.size_of_keyswitch_keys
output["p_error"] = self.p_error
output["global_p_error"] = self.global_p_error
output["complexity"] = self.complexity
return output

View File

@@ -0,0 +1,17 @@
"""
Declaration of `EncryptionStatus` class.
"""
# pylint: disable=import-error,no-name-in-module
from enum import Enum, unique
@unique
class EncryptionStatus(str, Enum):
"""
EncryptionStatus enum, to represent encryption status of parameters.
"""
CLEAR = "clear"
ENCRYPTED = "encrypted"

View File

@@ -7,6 +7,7 @@ import os
import re
from copy import deepcopy
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
@@ -27,9 +28,11 @@ from ..dtypes import Float, Integer, SignedInteger, UnsignedInteger
from ..representation import Graph, Node, Operation
from ..tracing import ScalarAnnotation
from ..values import ValueDescription
from .artifacts import FunctionDebugArtifacts
from .specs import ClientSpecs
if TYPE_CHECKING:
from .artifacts import FunctionDebugArtifacts # pragma: no cover
# ruff: noqa: ERA001
T = TypeVar("T")
@@ -118,7 +121,7 @@ def inputset(
def validate_input_args(
client_specs: ClientSpecs,
*args: Optional[Union[int, np.ndarray, List]],
function_name: str = "main",
function_name: str,
) -> List[Optional[Union[int, np.ndarray]]]:
"""Validate input arguments.
@@ -214,7 +217,7 @@ def validate_input_args(
return ordered_sanitized_args
def fuse(graph: Graph, artifacts: Optional[FunctionDebugArtifacts] = None):
def fuse(graph: Graph, artifacts: Optional["FunctionDebugArtifacts"] = None):
"""
Fuse appropriate subgraphs in a graph to a single Operation.Generic node.
@@ -817,7 +820,7 @@ def convert_subgraph_to_subgraph_node(
original_tag = terminal_node.tag
original_created_at = terminal_node.created_at
subgraph = Graph(nx_subgraph, {0: subgraph_variable_input_node}, {0: terminal_node})
subgraph = Graph(nx_subgraph, {0: subgraph_variable_input_node}, {0: terminal_node}, graph.name)
subgraph_node = Node.generic(
"subgraph",
deepcopy(subgraph_variable_input_node.inputs),

View File

@@ -0,0 +1,235 @@
"""
Declaration of wiring related class.
"""
# pylint: disable=import-error,no-name-in-module
from itertools import chain, product, repeat
from typing import (
TYPE_CHECKING,
Any,
Iterable,
List,
NamedTuple,
Optional,
Protocol,
Set,
runtime_checkable,
)
from ..representation import Graph
from .composition import CompositionClause, CompositionRule
if TYPE_CHECKING:
from .module_compiler import FunctionDef # pragma: no cover
class NotComposable:
"""
Composition policy that does not allow the forwarding of any output to any input.
"""
def get_rules_iter(self, _funcs: List["FunctionDef"]) -> Iterable[CompositionRule]:
"""
Return an iterator over composition rules.
"""
return [] # pragma: no cover
class AllComposable:
"""
Composition policy that allows to forward any output of the module to any of its input.
"""
def get_rules_iter(self, funcs: List[Graph]) -> Iterable[CompositionRule]:
"""
Return an iterator over composition rules.
"""
outputs = []
for f in funcs:
for pos, node in f.output_nodes.items():
if node.output.is_encrypted:
outputs.append(CompositionClause.create((f.name, pos)))
inputs = []
for f in funcs:
for pos, node in f.input_nodes.items():
if node.output.is_encrypted:
inputs.append(CompositionClause.create((f.name, pos)))
return map(CompositionRule.create, product(outputs, inputs))
@runtime_checkable
class WireOutput(Protocol):
"""
A protocol for wire outputs.
"""
def get_outputs_iter(self) -> Iterable[CompositionClause]:
"""
Return an iterator over the possible outputs of the wire output.
"""
@runtime_checkable
class WireInput(Protocol):
"""
A protocol for wire inputs.
"""
def get_inputs_iter(self) -> Iterable[CompositionClause]:
"""
Return an iterator over the possible inputs of the wire input.
"""
class Output(NamedTuple):
"""
The output of a given function of a module.
"""
func: "FunctionDef"
pos: int
def get_outputs_iter(self) -> Iterable[CompositionClause]:
"""
Return an iterator over the possible outputs of the wire output.
"""
return [CompositionClause(self.func.name, self.pos)]
class AllOutputs(NamedTuple):
"""
All the encrypted outputs of a given function of a module.
"""
func: "FunctionDef"
def get_outputs_iter(self) -> Iterable[CompositionClause]:
"""
Return an iterator over the possible outputs of the wire output.
"""
assert self.func.graph # pragma: no cover
# No need to filter since only encrypted outputs are valid.
return map( # pragma: no cover
CompositionClause.create,
zip(repeat(self.func.name), range(self.func.graph.outputs_count)),
)
class Input(NamedTuple):
"""
The input of a given function of a module.
"""
func: "FunctionDef"
pos: int
def get_inputs_iter(self) -> Iterable[CompositionClause]:
"""
Return an iterator over the possible inputs of the wire input.
"""
return [CompositionClause(self.func.name, self.pos)]
class AllInputs(NamedTuple):
"""
All the encrypted inputs of a given function of a module.
"""
func: "FunctionDef"
def get_inputs_iter(self) -> Iterable[CompositionClause]:
"""
Return an iterator over the possible inputs of the wire input.
"""
assert self.func.graph # pragma: no cover
output = []
for i in range(self.func.graph.inputs_count):
if self.func.graph.input_nodes[i].output.is_encrypted:
output.append(CompositionClause.create((self.func.name, i)))
return output
class Wire(NamedTuple):
"""
A forwarding rule between an output and an input.
"""
output: WireOutput
input: WireInput
def get_rules_iter(self, _) -> Iterable[CompositionRule]:
"""
Return an iterator over composition rules.
"""
return map(
CompositionRule.create,
product(self.output.get_outputs_iter(), self.input.get_inputs_iter()),
)
class Wired:
"""
Composition policy which allows the forwarding of certain outputs to certain inputs.
"""
wires: Set[Wire]
def __init__(self, wires: Optional[Set[Wire]] = None):
self.wires = wires if wires else set()
def get_rules_iter(self, funcs: List[Graph]) -> Iterable[CompositionRule]:
"""
Return an iterator over composition rules.
"""
funcsd = {f.name: f for f in funcs}
rules = list(chain(*[w.get_rules_iter(funcs) for w in self.wires]))
# We check that the given rules are legit (they concern only encrypted values)
for rule in rules:
if (
not funcsd[rule.from_.func].output_nodes[rule.from_.pos].output.is_encrypted
): # pragma: no cover
message = f"Invalid composition rule encountered: \
Output {rule.from_.pos} of {rule.from_.func} is not encrypted"
raise RuntimeError(message)
if not funcsd[rule.to.func].input_nodes[rule.to.pos].output.is_encrypted:
message = f"Invalid composition rule encountered: \
Input {rule.from_.pos} of {rule.from_.func} is not encrypted"
raise RuntimeError(message)
return rules
class TracedOutput(NamedTuple):
"""
A wrapper type used to trace wiring.
Allows to tag an output value coming from an other module function, and binds it with
information about its origin.
"""
output_info: Output
returned_value: Any
class WireTracingContextManager:
"""
A context manager returned by the `wire_pipeline` method.
Activates wire tracing and yields an inputset that can be iterated on for tracing.
"""
def __init__(self, module, inputset):
self.module = module
self.inputset = inputset
def __enter__(self):
for func in self.module.functions.values():
func._trace_wires = self.module.composition.wires
return self.inputset
def __exit__(self, _exc_type, _exc_value, _exc_tb):
for func in self.module.functions.values():
func._trace_wires = None

View File

@@ -145,7 +145,6 @@ class Converter:
self,
graph: Graph,
mlir_context: MlirContext,
name: str = "main",
) -> MlirModule:
"""
Convert a computation graph to MLIR.
@@ -157,15 +156,13 @@ class Converter:
mlir_context (MlirContext):
MLIR Context to use for module generation
name (str):
name of the function to convert
Return:
MlirModule:
In-memory MLIR module corresponding to the graph
"""
return self.convert_many({name: graph}, mlir_context)
return self.convert_many({graph.name: graph}, mlir_context)
@staticmethod
def stdout_with_ansi_support() -> bool:

View File

@@ -49,8 +49,8 @@ class Graph:
graph: nx.MultiDiGraph,
input_nodes: Dict[int, Node],
output_nodes: Dict[int, Node],
name: str,
is_direct: bool = False,
name: str = "main",
location: str = "",
):
self.graph = graph
@@ -1057,4 +1057,4 @@ class MultiGraphProcessor(GraphProcessor):
"""
Process a single graph.
"""
return self.apply_many({"main": graph}) # pragma: no cover
return self.apply_many({graph.name: graph}) # pragma: no cover

View File

@@ -187,7 +187,7 @@ def new_bridge(
circuit: "fhe.Circuit",
input_types: Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]],
output_types: Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]],
func_name: str = "main",
func_name,
) -> Bridge:
"""Create a TFHErs bridge from a circuit.
@@ -199,7 +199,7 @@ def new_bridge(
output_types (Union[List[Optional[TFHERSIntegerType]], Optional[TFHERSIntegerType]]): lists
should map every output to a type, while a single element is general for all outputs.
None means a non-tfhers type
func_name (str, optional): name of the function to use. Defaults to "main".
func_name (str): name of the function to use.
Returns:
Bridge: TFHErs bridge

View File

@@ -38,7 +38,6 @@ class Tracer:
function: Callable,
parameters: Dict[str, ValueDescription],
is_direct: bool = False,
name: str = "main",
location: str = "",
) -> Graph:
"""
@@ -55,9 +54,6 @@ class Tracer:
is_direct (bool, default = False):
whether the tracing is done on actual parameters or placeholders
name (str, default = "main"):
the name of the function being traced
Returns:
Graph:
computation graph corresponding to `function`
@@ -169,7 +165,9 @@ class Tracer:
output_idx: tracer.computation for output_idx, tracer in enumerate(output_tracers)
}
return Graph(graph, input_nodes, output_nodes, is_direct, name, location=location)
return Graph(
graph, input_nodes, output_nodes, function.__name__, is_direct, location=location
)
# pylint: enable=too-many-statements

View File

@@ -35,13 +35,13 @@ def test_artifacts_export(helpers):
assert (tmpdir / "environment.txt").exists()
assert (tmpdir / "requirements.txt").exists()
assert (tmpdir / "main.txt").exists()
assert (tmpdir / "main.parameters.txt").exists()
assert (tmpdir / "f.txt").exists()
assert (tmpdir / "f.parameters.txt").exists()
assert (tmpdir / "main.1.initial.graph.txt").exists()
assert (tmpdir / "main.2.after-fusing.graph.txt").exists()
assert (tmpdir / "main.3.after-fusing.graph.txt").exists()
assert (tmpdir / "main.4.final.graph.txt").exists()
assert (tmpdir / "f.1.initial.graph.txt").exists()
assert (tmpdir / "f.2.after-fusing.graph.txt").exists()
assert (tmpdir / "f.3.after-fusing.graph.txt").exists()
assert (tmpdir / "f.4.final.graph.txt").exists()
assert (tmpdir / "mlir.txt").exists()
assert (tmpdir / "client_parameters.json").exists()
@@ -51,13 +51,13 @@ def test_artifacts_export(helpers):
assert (tmpdir / "environment.txt").exists()
assert (tmpdir / "requirements.txt").exists()
assert (tmpdir / "main.txt").exists()
assert (tmpdir / "main.parameters.txt").exists()
assert (tmpdir / "f.txt").exists()
assert (tmpdir / "f.parameters.txt").exists()
assert (tmpdir / "main.1.initial.graph.txt").exists()
assert (tmpdir / "main.2.after-fusing.graph.txt").exists()
assert (tmpdir / "main.3.after-fusing.graph.txt").exists()
assert (tmpdir / "main.4.final.graph.txt").exists()
assert (tmpdir / "f.1.initial.graph.txt").exists()
assert (tmpdir / "f.2.after-fusing.graph.txt").exists()
assert (tmpdir / "f.3.after-fusing.graph.txt").exists()
assert (tmpdir / "f.4.final.graph.txt").exists()
assert (tmpdir / "mlir.txt").exists()
assert (tmpdir / "client_parameters.json").exists()

View File

@@ -7,6 +7,8 @@ from pathlib import Path
import numpy as np
import pytest
from concrete.compiler import CompilationContext
from mlir.ir import Module as MlirModule
from concrete import fhe
from concrete.fhe import Client, ClientSpecs, EvaluationKeys, LookupTable, Server, Value
@@ -89,6 +91,8 @@ def test_circuit_feedback(helpers):
assert isinstance(circuit.size_of_outputs, int)
assert isinstance(circuit.p_error, float)
assert isinstance(circuit.global_p_error, float)
assert isinstance(circuit.mlir_module, MlirModule)
assert isinstance(circuit.compilation_context, CompilationContext)
assert isinstance(circuit.memory_usage_per_location, dict)
assert all(
@@ -836,5 +840,4 @@ def test_simulate_encrypt_run_decrypt(helpers):
assert isinstance(encrypted_x, int)
assert isinstance(encrypted_y, int)
assert hasattr(circuit, "simulator")
assert not hasattr(circuit, "server")
assert isinstance(encrypted_result, int)

View File

@@ -450,7 +450,7 @@ def test_compiler_reset(helpers):
"""
module {
func.func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> {
func.func @f(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> {
%0 = "FHE.add_eint"(%arg0, %arg1) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
return %0 : !FHE.eint<4>
}
@@ -468,7 +468,7 @@ module {
"""
module {
func.func @main(%arg0: !FHE.eint<11>, %arg1: !FHE.eint<11>) -> !FHE.eint<11> {
func.func @f(%arg0: !FHE.eint<11>, %arg1: !FHE.eint<11>) -> !FHE.eint<11> {
%0 = "FHE.add_eint"(%arg0, %arg1) : (!FHE.eint<11>, !FHE.eint<11>) -> !FHE.eint<11>
return %0 : !FHE.eint<11>
}
@@ -486,7 +486,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<3x2x!FHE.eint<3>>, %arg1: tensor<2x!FHE.eint<3>>) -> tensor<3x2x!FHE.eint<3>> {
func.func @f(%arg0: tensor<3x2x!FHE.eint<3>>, %arg1: tensor<2x!FHE.eint<3>>) -> tensor<3x2x!FHE.eint<3>> {
%0 = "FHELinalg.add_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<3>>, tensor<2x!FHE.eint<3>>) -> tensor<3x2x!FHE.eint<3>>
return %0 : tensor<3x2x!FHE.eint<3>>
}

View File

@@ -222,39 +222,39 @@ def test_configuration_bad_fork(kwargs, expected_error, expected_message):
"""
%0:
main.%0 >= 3
<lambda>.%0 >= 3
%1:
main.%1 >= 7
<lambda>.%1 >= 7
%2:
main.%2 >= 2
<lambda>.%2 >= 2
%3:
main.%3 >= 5
<lambda>.%3 >= 5
%4:
main.%4 >= 8
main.%3 == main.%1
main.%1 == main.%4
<lambda>.%4 >= 8
<lambda>.%3 == <lambda>.%1
<lambda>.%1 == <lambda>.%4
""",
(
"""
main.%0 = 3
main.%1 = 8
main.%2 = 2
main.%3 = 8
main.%4 = 8
main.max = 8
<lambda>.%0 = 3
<lambda>.%1 = 8
<lambda>.%2 = 2
<lambda>.%3 = 8
<lambda>.%4 = 8
<lambda>.max = 8
"""
if USE_MULTI_PRECISION
else """
main.%0 = 8
main.%1 = 8
main.%2 = 8
main.%3 = 8
main.%4 = 8
main.max = 8
<lambda>.%0 = 8
<lambda>.%1 = 8
<lambda>.%2 = 8
<lambda>.%3 = 8
<lambda>.%4 = 8
<lambda>.max = 8
"""
),
@@ -275,7 +275,7 @@ def test_configuration_show_bit_width_constraints_and_assignment(
Test compiling with configuration where show_bit_width_(constraints/assignments)=True.
"""
monkeypatch.setattr("concrete.fhe.compilation.compiler.get_terminal_size", lambda: 80)
monkeypatch.setattr("concrete.fhe.compilation.artifacts.get_terminal_size", lambda: 80)
configuration = helpers.configuration()
compiler = fhe.Compiler(function, encryption_status)
@@ -289,12 +289,12 @@ def test_configuration_show_bit_width_constraints_and_assignment(
captured.out.strip(),
f"""
Bit-Width Constraints
Bit-Width Constraints for <lambda>
--------------------------------------------------------------------------------
{expected_bit_width_constraints.lstrip(os.linesep).rstrip()}
--------------------------------------------------------------------------------
Bit-Width Assignments
Bit-Width Assignments for <lambda>
--------------------------------------------------------------------------------
{expected_bit_width_assignment.lstrip(os.linesep).rstrip()}
--------------------------------------------------------------------------------

View File

@@ -28,42 +28,12 @@ def test_compiler_call_and_compile(helpers):
helpers.check_execution(circuit, function, sample)
def test_compiler_verbose_trace(helpers, capsys, monkeypatch):
"""
Test `trace` method of `compiler` decorator with verbose flag.
"""
monkeypatch.setattr("concrete.fhe.compilation.compiler.get_terminal_size", lambda: 80)
configuration = helpers.configuration()
artifacts = fhe.DebugArtifacts()
@fhe.compiler({"x": "encrypted"})
def function(x):
return x + 42
inputset = range(10)
graph = function.trace(inputset, configuration, artifacts, show_graph=True)
captured = capsys.readouterr()
assert captured.out.strip() == (
f"""
Computation Graph
------------------------------------------------------------------
{graph.format()}
------------------------------------------------------------------
""".strip()
)
def test_compiler_verbose_compile(helpers, capsys, monkeypatch):
"""
Test `compile` method of `compiler` decorator with verbose flag.
"""
monkeypatch.setattr("concrete.fhe.compilation.compiler.get_terminal_size", lambda: 80)
monkeypatch.setattr("concrete.fhe.compilation.artifacts.get_terminal_size", lambda: 80)
configuration = helpers.configuration()
artifacts = fhe.DebugArtifacts()
@@ -76,44 +46,46 @@ def test_compiler_verbose_compile(helpers, capsys, monkeypatch):
circuit = function.compile(inputset, configuration, artifacts, verbose=True)
captured = capsys.readouterr()
assert captured.out.strip().startswith(
f"""
Computation Graph
Computation Graph for function
--------------------------------------------------------------------------------
{circuit.graph.format()}
--------------------------------------------------------------------------------
Bit-Width Constraints
--------------------------------------------------------------------------------
%0:
main.%0 >= 4
%1:
main.%1 >= 6
%2:
main.%2 >= 6
main.%0 == main.%1
main.%1 == main.%2
--------------------------------------------------------------------------------
Bit-Width Assignments
--------------------------------------------------------------------------------
main.%0 = 6
main.%1 = 6
main.%2 = 6
main.max = 6
--------------------------------------------------------------------------------
Bit-Width Assigned Computation Graph
--------------------------------------------------------------------------------
{circuit.graph.format(show_assigned_bit_widths=True)}
--------------------------------------------------------------------------------
MLIR
--------------------------------------------------------------------------------
{artifacts.mlir_to_compile}
--------------------------------------------------------------------------------
Bit-Width Constraints for function
--------------------------------------------------------------------------------
%0:
function.%0 >= 4
%1:
function.%1 >= 6
%2:
function.%2 >= 6
function.%0 == function.%1
function.%1 == function.%2
--------------------------------------------------------------------------------
Bit-Width Assignments for function
--------------------------------------------------------------------------------
function.%0 = 6
function.%1 = 6
function.%2 = 6
function.max = 6
--------------------------------------------------------------------------------
Bit-Width Assigned Computation Graph for function
--------------------------------------------------------------------------------
{circuit.graph.format(show_assigned_bit_widths=True)}
--------------------------------------------------------------------------------
Optimizer
--------------------------------------------------------------------------------
@@ -325,7 +297,7 @@ def test_compiler_reset(helpers):
"""
module {
func.func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> {
func.func @compiler(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> {
%0 = "FHE.add_eint"(%arg0, %arg1) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
return %0 : !FHE.eint<4>
}
@@ -343,7 +315,7 @@ module {
"""
module {
func.func @main(%arg0: !FHE.eint<11>, %arg1: !FHE.eint<11>) -> !FHE.eint<11> {
func.func @compiler(%arg0: !FHE.eint<11>, %arg1: !FHE.eint<11>) -> !FHE.eint<11> {
%0 = "FHE.add_eint"(%arg0, %arg1) : (!FHE.eint<11>, !FHE.eint<11>) -> !FHE.eint<11>
return %0 : !FHE.eint<11>
}
@@ -361,7 +333,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<3x2x!FHE.eint<3>>, %arg1: tensor<2x!FHE.eint<3>>) -> tensor<3x2x!FHE.eint<3>> {
func.func @compiler(%arg0: tensor<3x2x!FHE.eint<3>>, %arg1: tensor<2x!FHE.eint<3>>) -> tensor<3x2x!FHE.eint<3>> {
%0 = "FHELinalg.add_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<3>>, tensor<2x!FHE.eint<3>>) -> tensor<3x2x!FHE.eint<3>>
return %0 : tensor<3x2x!FHE.eint<3>>
}

View File

@@ -263,7 +263,7 @@ def test_round_bit_pattern_no_overflow_protection(helpers):
"""
module {
func.func @main(%arg0: !FHE.esint<7>) -> !FHE.eint<11> {
func.func @function(%arg0: !FHE.esint<7>) -> !FHE.eint<11> {
%0 = "FHE.round"(%arg0) : (!FHE.esint<7>) -> !FHE.esint<5>
%c2_i3 = arith.constant 2 : i3
%cst = arith.constant dense<[0, 16, 64, 144, 256, 400, 576, 784, 1024, 1296, 1600, 1936, 2304, 2704, 3136, 3600, 4096, 3600, 3136, 2704, 2304, 1936, 1600, 1296, 1024, 784, 576, 400, 256, 144, 64, 16]> : tensor<32xi64>
@@ -277,7 +277,7 @@ module {
else """
module {
func.func @main(%arg0: !FHE.esint<11>) -> !FHE.eint<11> {
func.func @function(%arg0: !FHE.esint<11>) -> !FHE.eint<11> {
%c16_i12 = arith.constant 16 : i12
%0 = "FHE.mul_eint_int"(%arg0, %c16_i12) : (!FHE.esint<11>, i12) -> !FHE.esint<11>
%1 = "FHE.round"(%0) : (!FHE.esint<11>) -> !FHE.esint<5>
@@ -312,7 +312,7 @@ def test_round_bit_pattern_identity(helpers):
"""
module {
func.func @main(%arg0: !FHE.esint<6>) -> !FHE.esint<7> {
func.func @function(%arg0: !FHE.esint<6>) -> !FHE.esint<7> {
%0 = "FHE.round"(%arg0) : (!FHE.esint<6>) -> !FHE.esint<4>
%cst = arith.constant dense<[0, 4, 8, 12, 16, 20, 24, 28, -32, -28, -24, -20, -16, -12, -8, -4]> : tensor<16xi64>
%1 = "FHE.apply_lookup_table"(%0, %cst) : (!FHE.esint<4>, tensor<16xi64>) -> !FHE.esint<7>
@@ -326,7 +326,7 @@ module {
else """
module {
func.func @main(%arg0: !FHE.esint<7>) -> !FHE.esint<7> {
func.func @function(%arg0: !FHE.esint<7>) -> !FHE.esint<7> {
%c2_i8 = arith.constant 2 : i8
%0 = "FHE.mul_eint_int"(%arg0, %c2_i8) : (!FHE.esint<7>, i8) -> !FHE.esint<7>
%1 = "FHE.round"(%0) : (!FHE.esint<7>) -> !FHE.esint<4>

View File

@@ -382,7 +382,7 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen(
assert (dtype.decode(concrete_encoded_result) == function(*sample)).all()
###### TFHErs Encryption ######################################################
tfhers_bridge = tfhers.new_bridge(circuit, dtype, dtype, func_name="main")
tfhers_bridge = tfhers.new_bridge(circuit, dtype, dtype, func_name="<lambda>")
# serialize key
_, key_path = tempfile.mkstemp()
@@ -577,7 +577,7 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_concrete_keygen(
assert (dtype.decode(concrete_encoded_result) == function(*sample)).all()
###### TFHErs Encryption ######################################################
tfhers_bridge = tfhers.new_bridge(circuit, dtype, dtype, func_name="main")
tfhers_bridge = tfhers.new_bridge(circuit, dtype, dtype, func_name="<lambda>")
# serialize key
_, key_path = tempfile.mkstemp()

View File

@@ -250,7 +250,7 @@ def test_truncate_bit_pattern_identity(helpers, pytestconfig):
"""
module {
func.func @main(%arg0: !FHE.esint<7>) -> !FHE.esint<7> {
func.func @function(%arg0: !FHE.esint<7>) -> !FHE.esint<7> {
%0 = "FHE.lsb"(%arg0) : (!FHE.esint<7>) -> !FHE.esint<7>
%1 = "FHE.sub_eint"(%arg0, %0) : (!FHE.esint<7>, !FHE.esint<7>) -> !FHE.esint<7>
%2 = "FHE.reinterpret_precision"(%1) : (!FHE.esint<7>) -> !FHE.esint<6>
@@ -267,7 +267,7 @@ module {
else """
module {
func.func @main(%arg0: !FHE.esint<7>) -> !FHE.esint<7> {
func.func @function(%arg0: !FHE.esint<7>) -> !FHE.esint<7> {
%0 = "FHE.lsb"(%arg0) : (!FHE.esint<7>) -> !FHE.esint<7>
%1 = "FHE.sub_eint"(%arg0, %0) : (!FHE.esint<7>, !FHE.esint<7>) -> !FHE.esint<7>
%2 = "FHE.reinterpret_precision"(%1) : (!FHE.esint<7>) -> !FHE.esint<6>

View File

@@ -1143,7 +1143,7 @@ def test_converter_bad_convert(
"""
module {
func.func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<6> {
func.func @"<lambda>"(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<6> {
%0 = "FHE.mul_eint"(%arg0, %arg1) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<6>
return %0 : !FHE.eint<6>
}
@@ -1161,7 +1161,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<3x2x!FHE.eint<4>>, %arg1: tensor<3x2x!FHE.eint<4>>) -> tensor<3x2x!FHE.eint<6>> {
func.func @"<lambda>"(%arg0: tensor<3x2x!FHE.eint<4>>, %arg1: tensor<3x2x!FHE.eint<4>>) -> tensor<3x2x!FHE.eint<6>> {
%0 = "FHELinalg.mul_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<4>>, tensor<3x2x!FHE.eint<4>>) -> tensor<3x2x!FHE.eint<6>>
return %0 : tensor<3x2x!FHE.eint<6>>
}
@@ -1179,7 +1179,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<2x!FHE.eint<4>>, %arg1: tensor<2x!FHE.eint<4>>) -> !FHE.eint<7> {
func.func @"<lambda>"(%arg0: tensor<2x!FHE.eint<4>>, %arg1: tensor<2x!FHE.eint<4>>) -> !FHE.eint<7> {
%0 = "FHELinalg.dot_eint_eint"(%arg0, %arg1) : (tensor<2x!FHE.eint<4>>, tensor<2x!FHE.eint<4>>) -> !FHE.eint<7>
return %0 : !FHE.eint<7>
}
@@ -1197,7 +1197,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<3x2x!FHE.eint<4>>, %arg1: tensor<2x4x!FHE.eint<4>>) -> tensor<3x4x!FHE.eint<7>> {
func.func @"<lambda>"(%arg0: tensor<3x2x!FHE.eint<4>>, %arg1: tensor<2x4x!FHE.eint<4>>) -> tensor<3x4x!FHE.eint<7>> {
%0 = "FHELinalg.matmul_eint_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<4>>, tensor<2x4x!FHE.eint<4>>) -> tensor<3x4x!FHE.eint<7>>
return %0 : tensor<3x4x!FHE.eint<7>>
}
@@ -1215,7 +1215,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<10x!FHE.eint<4>>, %arg1: tensor<10x!FHE.eint<4>>) -> !FHE.eint<1> {
func.func @"<lambda>"(%arg0: tensor<10x!FHE.eint<4>>, %arg1: tensor<10x!FHE.eint<4>>) -> !FHE.eint<1> {
%0 = "FHELinalg.to_signed"(%arg0) : (tensor<10x!FHE.eint<4>>) -> tensor<10x!FHE.esint<4>>
%1 = "FHELinalg.to_signed"(%arg1) : (tensor<10x!FHE.eint<4>>) -> tensor<10x!FHE.esint<4>>
%2 = "FHELinalg.sub_eint"(%0, %1) : (tensor<10x!FHE.esint<4>>, tensor<10x!FHE.esint<4>>) -> tensor<10x!FHE.esint<4>>
@@ -1242,7 +1242,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<2x!FHE.eint<2>>, %arg1: !FHE.eint<6>) -> tensor<2x!FHE.eint<6>> {
func.func @"<lambda>"(%arg0: tensor<2x!FHE.eint<2>>, %arg1: !FHE.eint<6>) -> tensor<2x!FHE.eint<6>> {
%c2_i3 = arith.constant 2 : i3
%cst = arith.constant dense<[0, 1, 4, 9]> : tensor<4xi64>
%0 = "FHELinalg.apply_lookup_table"(%arg0, %cst) : (tensor<2x!FHE.eint<2>>, tensor<4xi64>) -> tensor<2x!FHE.eint<6>>
@@ -1263,7 +1263,7 @@ module {
"""
module {
func.func @main(%arg0: !FHE.eint<7>) -> (!FHE.eint<8>, !FHE.eint<7>) {
func.func @"<lambda>"(%arg0: !FHE.eint<7>) -> (!FHE.eint<8>, !FHE.eint<7>) {
%c2_i3 = arith.constant 2 : i3
%c8_i8 = arith.constant 8 : i8
%0 = "FHE.mul_eint_int"(%arg0, %c8_i8) : (!FHE.eint<7>, i8) -> !FHE.eint<7>
@@ -1289,7 +1289,7 @@ module {
"""
module {
func.func @main(%arg0: !FHE.eint<7>) -> (!FHE.eint<8>, !FHE.eint<7>) {
func.func @"<lambda>"(%arg0: !FHE.eint<7>) -> (!FHE.eint<8>, !FHE.eint<7>) {
%c2_i3 = arith.constant 2 : i3
%c12_i8 = arith.constant 12 : i8
%0 = "FHE.sub_eint_int"(%arg0, %c12_i8) : (!FHE.eint<7>, i8) -> !FHE.eint<7>
@@ -1317,7 +1317,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<2x!FHE.eint<4>>) -> tensor<2x!FHE.eint<3>> {
func.func @"<lambda>"(%arg0: tensor<2x!FHE.eint<4>>) -> tensor<2x!FHE.eint<3>> {
%c12_i5 = arith.constant 12 : i5
%from_elements = tensor.from_elements %c12_i5 : tensor<1xi5>
%0 = "FHELinalg.sub_eint_int"(%arg0, %from_elements) : (tensor<2x!FHE.eint<4>>, tensor<1xi5>) -> tensor<2x!FHE.eint<4>>
@@ -1345,7 +1345,7 @@ module {
"""
module {
func.func @main(%arg0: !FHE.eint<7>) -> (!FHE.eint<8>, !FHE.eint<7>) {
func.func @"<lambda>"(%arg0: !FHE.eint<7>) -> (!FHE.eint<8>, !FHE.eint<7>) {
%c2_i3 = arith.constant 2 : i3
%c4_i8 = arith.constant 4 : i8
%0 = "FHE.mul_eint_int"(%arg0, %c4_i8) : (!FHE.eint<7>, i8) -> !FHE.eint<7>
@@ -1371,7 +1371,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<2x!FHE.eint<4>>) -> tensor<2x!FHE.eint<3>> {
func.func @"<lambda>"(%arg0: tensor<2x!FHE.eint<4>>) -> tensor<2x!FHE.eint<3>> {
%c12_i5 = arith.constant 12 : i5
%from_elements = tensor.from_elements %c12_i5 : tensor<1xi5>
%0 = "FHELinalg.sub_eint_int"(%arg0, %from_elements) : (tensor<2x!FHE.eint<4>>, tensor<1xi5>) -> tensor<2x!FHE.eint<4>>
@@ -1399,7 +1399,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<2x!FHE.eint<5>>) -> tensor<2x!FHE.eint<3>> {
func.func @"<lambda>"(%arg0: tensor<2x!FHE.eint<5>>) -> tensor<2x!FHE.eint<3>> {
%cst = arith.constant dense<[0, 1]> : tensor<2xindex>
%cst_0 = arith.constant dense<[[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15], [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10]]> : tensor<2x32xi64>
%0 = "FHELinalg.apply_mapped_lookup_table"(%arg0, %cst_0, %cst) : (tensor<2x!FHE.eint<5>>, tensor<2x32xi64>, tensor<2xindex>) -> tensor<2x!FHE.eint<3>>
@@ -1420,7 +1420,7 @@ module {
"""
module {
func.func @main(%arg0: !FHE.eint<8>) -> (!FHE.eint<5>, !FHE.eint<8>) {
func.func @"<lambda>"(%arg0: !FHE.eint<8>) -> (!FHE.eint<5>, !FHE.eint<8>) {
%c2_i3 = arith.constant 2 : i3
%c4_i9 = arith.constant 4 : i9
%0 = "FHE.mul_eint_int"(%arg0, %c4_i9) : (!FHE.eint<8>, i9) -> !FHE.eint<8>
@@ -1444,7 +1444,7 @@ module {
"""
module {
func.func @main(%arg0: !FHE.esint<7>) -> (!FHE.eint<8>, !FHE.eint<7>) {
func.func @"<lambda>"(%arg0: !FHE.esint<7>) -> (!FHE.eint<8>, !FHE.eint<7>) {
%c2_i3 = arith.constant 2 : i3
%c4_i8 = arith.constant 4 : i8
%0 = "FHE.mul_eint_int"(%arg0, %c4_i8) : (!FHE.esint<7>, i8) -> !FHE.esint<7>
@@ -1471,7 +1471,7 @@ module {
"""
module {
func.func @main(%arg0: !FHE.esint<7>) -> (!FHE.eint<8>, !FHE.eint<7>) {
func.func @"<lambda>"(%arg0: !FHE.esint<7>) -> (!FHE.eint<8>, !FHE.eint<7>) {
%c2_i3 = arith.constant 2 : i3
%c13_i8 = arith.constant 13 : i8
%0 = "FHE.add_eint_int"(%arg0, %c13_i8) : (!FHE.esint<7>, i8) -> !FHE.esint<7>
@@ -1500,7 +1500,7 @@ module {
"""
module {
func.func @main(%arg0: !FHE.esint<8>) -> (!FHE.esint<5>, !FHE.eint<8>) {
func.func @"<lambda>"(%arg0: !FHE.esint<8>) -> (!FHE.esint<5>, !FHE.eint<8>) {
%c2_i3 = arith.constant 2 : i3
%c4_i9 = arith.constant 4 : i9
%0 = "FHE.mul_eint_int"(%arg0, %c4_i9) : (!FHE.esint<8>, i9) -> !FHE.esint<8>
@@ -1527,7 +1527,7 @@ module {
"""
module {
func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
func.func @"<lambda>"(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
%c3_i3 = arith.constant 3 : i3
%c2_i7 = arith.constant 2 : i7
%0 = "FHE.mul_eint_int"(%arg0, %c2_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>
@@ -1552,7 +1552,7 @@ module {
"""
module {
func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
func.func @"<lambda>"(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
%c3_i3 = arith.constant 3 : i3
%cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, 12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 15, 16, 16, 16, 17, 17, 17, 18, 18, 18, 19, 19, 19, 20, 20, 20, 21]> : tensor<64xi64>
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6>
@@ -1574,7 +1574,7 @@ module {
"""
module {
func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
func.func @"<lambda>"(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
%c3_i3 = arith.constant 3 : i3
%c2_i7 = arith.constant 2 : i7
%0 = "FHE.mul_eint_int"(%arg0, %c2_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>
@@ -1599,7 +1599,7 @@ module {
"""
module {
func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
func.func @"<lambda>"(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
%c3_i3 = arith.constant 3 : i3
%c2_i7 = arith.constant 2 : i7
%0 = "FHE.mul_eint_int"(%arg0, %c2_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>
@@ -1624,7 +1624,7 @@ module {
"""
module {
func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
func.func @"<lambda>"(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
%c3_i3 = arith.constant 3 : i3
%cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, 12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 15, 16, 16, 16, 17, 17, 17, 18, 18, 18, 19, 19, 19, 20, 20, 20, 21]> : tensor<64xi64>
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6>
@@ -1646,7 +1646,7 @@ module {
"""
module {
func.func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
func.func @"<lambda>"(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
%c3_i3 = arith.constant 3 : i3
%cst = arith.constant dense<[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, 12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 15, 16, 16, 16, 17, 17, 17, 18, 18, 18, 19, 19, 19, 20, 20, 20, 21]> : tensor<64xi64>
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6>
@@ -1669,7 +1669,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<8x!FHE.eint<4>>, %arg1: i4) -> !FHE.eint<4> {
func.func @"<lambda>"(%arg0: tensor<8x!FHE.eint<4>>, %arg1: i4) -> !FHE.eint<4> {
%0 = arith.index_cast %arg1 : i4 to index
%extracted = tensor.extract %arg0[%0] : tensor<8x!FHE.eint<4>>
return %extracted : !FHE.eint<4>
@@ -1690,7 +1690,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<8x!FHE.eint<4>>, %arg1: i5) -> !FHE.eint<4> {
func.func @"<lambda>"(%arg0: tensor<8x!FHE.eint<4>>, %arg1: i5) -> !FHE.eint<4> {
%c8_i5 = arith.constant 8 : i5
%c0_i5 = arith.constant 0 : i5
%0 = arith.cmpi slt, %arg1, %c0_i5 : i5
@@ -1720,7 +1720,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<8x!FHE.eint<4>>, %arg1: i4) -> !FHE.eint<4> {
func.func @"<lambda>"(%arg0: tensor<8x!FHE.eint<4>>, %arg1: i4) -> !FHE.eint<4> {
%c8_i5 = arith.constant 8 : i5
%0 = arith.extsi %arg1 : i4 to i5
%1 = arith.cmpi sge, %0, %c8_i5 : i5
@@ -1752,7 +1752,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<8x!FHE.eint<4>>, %arg1: i5) -> !FHE.eint<4> {
func.func @"<lambda>"(%arg0: tensor<8x!FHE.eint<4>>, %arg1: i5) -> !FHE.eint<4> {
%c8_i5 = arith.constant 8 : i5
%c0_i5 = arith.constant 0 : i5
%0 = arith.cmpi slt, %arg1, %c0_i5 : i5
@@ -1790,7 +1790,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<8x!FHE.eint<4>>, %arg1: tensor<3x2xi4>) -> tensor<3x2x!FHE.eint<4>> {
func.func @"<lambda>"(%arg0: tensor<8x!FHE.eint<4>>, %arg1: tensor<3x2xi4>) -> tensor<3x2x!FHE.eint<4>> {
%0 = arith.index_cast %arg1 : tensor<3x2xi4> to tensor<3x2xindex>
%1 = "FHELinalg.fancy_index"(%arg0, %0) : (tensor<8x!FHE.eint<4>>, tensor<3x2xindex>) -> tensor<3x2x!FHE.eint<4>>
return %1 : tensor<3x2x!FHE.eint<4>>
@@ -1811,7 +1811,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<8x!FHE.eint<4>>, %arg1: tensor<3x2xi5>) -> tensor<3x2x!FHE.eint<4>> {
func.func @"<lambda>"(%arg0: tensor<8x!FHE.eint<4>>, %arg1: tensor<3x2xi5>) -> tensor<3x2x!FHE.eint<4>> {
%c8_i5 = arith.constant 8 : i5
%generated = tensor.generate {
^bb0(%arg2: index, %arg3: index):
@@ -1846,7 +1846,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<8x!FHE.eint<4>>, %arg1: tensor<3x2xi4>) -> tensor<3x2x!FHE.eint<4>> {
func.func @"<lambda>"(%arg0: tensor<8x!FHE.eint<4>>, %arg1: tensor<3x2xi4>) -> tensor<3x2x!FHE.eint<4>> {
%c8_i5 = arith.constant 8 : i5
%0 = arith.extsi %arg1 : tensor<3x2xi4> to tensor<3x2xi5>
%c0 = arith.constant 0 : index
@@ -1889,7 +1889,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<8x!FHE.eint<4>>, %arg1: tensor<3x2xi5>) -> tensor<3x2x!FHE.eint<4>> {
func.func @"<lambda>"(%arg0: tensor<8x!FHE.eint<4>>, %arg1: tensor<3x2xi5>) -> tensor<3x2x!FHE.eint<4>> {
%c8_i5 = arith.constant 8 : i5
%generated = tensor.generate {
^bb0(%arg2: index, %arg3: index):
@@ -1930,7 +1930,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<102x70x104x!FHE.eint<6>>) -> tensor<102x70x104x!FHE.eint<5>> {
func.func @"<lambda>"(%arg0: tensor<102x70x104x!FHE.eint<6>>) -> tensor<102x70x104x!FHE.eint<5>> {
%c2_i3 = arith.constant 2 : i3
%cst = arith.constant dense<[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24, 25, 25, 26, 26, 27, 27, 28, 28, 29, 29, 30, 30, 31, 31]> : tensor<64xi64>
%0 = "FHELinalg.apply_lookup_table"(%arg0, %cst) : (tensor<102x70x104x!FHE.eint<6>>, tensor<64xi64>) -> tensor<102x70x104x!FHE.eint<5>>
@@ -1949,7 +1949,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<3x!FHE.eint<8>>) -> !FHE.eint<8> {
func.func @"<lambda>"(%arg0: tensor<3x!FHE.eint<8>>) -> !FHE.eint<8> {
%cst = arith.constant dense<[1, 0, 2]> : tensor<3xi3>
%0 = "FHELinalg.dot_eint_int"(%arg0, %cst) : (tensor<3x!FHE.eint<8>>, tensor<3xi3>) -> !FHE.eint<8>
return %0 : !FHE.eint<8>
@@ -1967,7 +1967,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<2x2x!FHE.eint<8>>) -> tensor<2x2x!FHE.eint<8>> {
func.func @"<lambda>"(%arg0: tensor<2x2x!FHE.eint<8>>) -> tensor<2x2x!FHE.eint<8>> {
%cst = arith.constant dense<[[1, 0], [2, 3]]> : tensor<2x2xi3>
%0 = "FHELinalg.matmul_eint_int"(%arg0, %cst) : (tensor<2x2x!FHE.eint<8>>, tensor<2x2xi3>) -> tensor<2x2x!FHE.eint<8>>
return %0 : tensor<2x2x!FHE.eint<8>>
@@ -2016,7 +2016,7 @@ def test_converter_convert_multi_precision(
"""
module {
func.func @main(%arg0: !FHE.eint<6>, %arg1: !FHE.eint<6>) -> !FHE.eint<6> {
func.func @"<lambda>"(%arg0: !FHE.eint<6>, %arg1: !FHE.eint<6>) -> !FHE.eint<6> {
%0 = "FHE.mul_eint"(%arg0, %arg1) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>
return %0 : !FHE.eint<6>
}
@@ -2033,7 +2033,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<3x2x!FHE.eint<6>>, %arg1: tensor<3x2x!FHE.eint<6>>) -> tensor<3x2x!FHE.eint<6>> {
func.func @"<lambda>"(%arg0: tensor<3x2x!FHE.eint<6>>, %arg1: tensor<3x2x!FHE.eint<6>>) -> tensor<3x2x!FHE.eint<6>> {
%0 = "FHELinalg.mul_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<6>>, tensor<3x2x!FHE.eint<6>>) -> tensor<3x2x!FHE.eint<6>>
return %0 : tensor<3x2x!FHE.eint<6>>
}
@@ -2050,7 +2050,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<2x!FHE.eint<7>>, %arg1: tensor<2x!FHE.eint<7>>) -> !FHE.eint<7> {
func.func @"<lambda>"(%arg0: tensor<2x!FHE.eint<7>>, %arg1: tensor<2x!FHE.eint<7>>) -> !FHE.eint<7> {
%0 = "FHELinalg.dot_eint_eint"(%arg0, %arg1) : (tensor<2x!FHE.eint<7>>, tensor<2x!FHE.eint<7>>) -> !FHE.eint<7>
return %0 : !FHE.eint<7>
}
@@ -2067,7 +2067,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<3x2x!FHE.eint<7>>, %arg1: tensor<2x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>> {
func.func @"<lambda>"(%arg0: tensor<3x2x!FHE.eint<7>>, %arg1: tensor<2x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>> {
%0 = "FHELinalg.matmul_eint_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<7>>, tensor<2x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>>
return %0 : tensor<3x4x!FHE.eint<7>>
}
@@ -2084,7 +2084,7 @@ module {
"""
module {
func.func @main(%arg0: tensor<2x!FHE.eint<6>>, %arg1: !FHE.eint<6>) -> tensor<2x!FHE.eint<6>> {
func.func @"<lambda>"(%arg0: tensor<2x!FHE.eint<6>>, %arg1: !FHE.eint<6>) -> tensor<2x!FHE.eint<6>> {
%c2_i3 = arith.constant 2 : i3
%c16_i7 = arith.constant 16 : i7
%from_elements = tensor.from_elements %c16_i7 : tensor<1xi7>
@@ -2134,7 +2134,7 @@ def test_converter_convert_single_precision(function, parameters, expected_mlir,
"""
module {
func.func @main(%arg0: !FHE.eint<9>, %arg1: !FHE.eint<9>) -> !FHE.eint<9> {
func.func @"<lambda>"(%arg0: !FHE.eint<9>, %arg1: !FHE.eint<9>) -> !FHE.eint<9> {
%0 = "FHE.mul_eint"(%arg0, %arg1) : (!FHE.eint<9>, !FHE.eint<9>) -> !FHE.eint<9>
return %0 : !FHE.eint<9>
}
@@ -2197,7 +2197,7 @@ def test_converter_process_multi_precision(function, parameters, expected_graph,
inputset = helpers.generate_inputset(parameters)
graph = compiler.trace(inputset, configuration)
GraphConverter(configuration).process({"main": graph})
GraphConverter(configuration).process({"<lambda>": graph})
for node in graph.query_nodes():
if "original_bit_width" in node.properties:
del node.properties["original_bit_width"]
@@ -2239,7 +2239,7 @@ def test_converter_process_single_precision(function, parameters, expected_graph
inputset = helpers.generate_inputset(parameters)
graph = compiler.trace(inputset, configuration)
GraphConverter(configuration).process({"main": graph})
GraphConverter(configuration).process({"<lambda>": graph})
for node in graph.query_nodes():
if "original_bit_width" in node.properties:
del node.properties["original_bit_width"]