mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(compiler): add more detailed statistics
This commit is contained in:
@@ -18,6 +18,30 @@ namespace concretelang {
|
||||
|
||||
using StringError = ::concretelang::error::StringError;
|
||||
|
||||
enum class PrimitiveOperation {
|
||||
PBS,
|
||||
WOP_PBS,
|
||||
KEY_SWITCH,
|
||||
CLEAR_ADDITION,
|
||||
ENCRYPTED_ADDITION,
|
||||
CLEAR_MULTIPLICATION,
|
||||
ENCRYPTED_NEGATION,
|
||||
};
|
||||
|
||||
enum class KeyType {
|
||||
SECRET,
|
||||
BOOTSTRAP,
|
||||
KEY_SWITCH,
|
||||
PACKING_KEY_SWITCH,
|
||||
};
|
||||
|
||||
struct Statistic {
|
||||
std::string location;
|
||||
PrimitiveOperation operation;
|
||||
std::vector<std::pair<KeyType, size_t>> keys;
|
||||
size_t count;
|
||||
};
|
||||
|
||||
struct CompilationFeedback {
|
||||
double complexity;
|
||||
|
||||
@@ -45,23 +69,8 @@ struct CompilationFeedback {
|
||||
/// @brief crt decomposition of outputs, if crt is not used, empty vectors
|
||||
std::vector<std::vector<int64_t>> crtDecompositionsOfOutputs;
|
||||
|
||||
/// @brief number of programmable bootstraps in the entire circuit
|
||||
uint64_t totalPbsCount = 0;
|
||||
|
||||
/// @brief number of key switches in the entire circuit
|
||||
uint64_t totalKsCount = 0;
|
||||
|
||||
/// @brief number of clear additions in the entire circuit
|
||||
uint64_t totalClearAdditionCount = 0;
|
||||
|
||||
/// @brief number of encrypted additions in the entire circuit
|
||||
uint64_t totalEncryptedAdditionCount = 0;
|
||||
|
||||
/// @brief number of clear multiplications in the entire circuit
|
||||
uint64_t totalClearMultiplicationCount = 0;
|
||||
|
||||
/// @brief number of encrypted negations in the entire circuit
|
||||
uint64_t totalEncryptedNegationCount = 0;
|
||||
/// @brief statistics
|
||||
std::vector<Statistic> statistics;
|
||||
|
||||
/// Fill the sizes from the client parameters.
|
||||
void
|
||||
|
||||
@@ -57,6 +57,7 @@ declare_mlir_python_sources(
|
||||
concrete/compiler/library_compilation_result.py
|
||||
concrete/compiler/library_support.py
|
||||
concrete/compiler/library_lambda.py
|
||||
concrete/compiler/parameter.py
|
||||
concrete/compiler/public_arguments.py
|
||||
concrete/compiler/public_result.py
|
||||
concrete/compiler/evaluation_keys.py
|
||||
|
||||
@@ -109,6 +109,35 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
options.simulate = simulate;
|
||||
});
|
||||
|
||||
pybind11::enum_<mlir::concretelang::PrimitiveOperation>(m,
|
||||
"PrimitiveOperation")
|
||||
.value("PBS", mlir::concretelang::PrimitiveOperation::PBS)
|
||||
.value("WOP_PBS", mlir::concretelang::PrimitiveOperation::WOP_PBS)
|
||||
.value("KEY_SWITCH", mlir::concretelang::PrimitiveOperation::KEY_SWITCH)
|
||||
.value("CLEAR_ADDITION",
|
||||
mlir::concretelang::PrimitiveOperation::CLEAR_ADDITION)
|
||||
.value("ENCRYPTED_ADDITION",
|
||||
mlir::concretelang::PrimitiveOperation::ENCRYPTED_ADDITION)
|
||||
.value("CLEAR_MULTIPLICATION",
|
||||
mlir::concretelang::PrimitiveOperation::CLEAR_MULTIPLICATION)
|
||||
.value("ENCRYPTED_NEGATION",
|
||||
mlir::concretelang::PrimitiveOperation::ENCRYPTED_NEGATION)
|
||||
.export_values();
|
||||
|
||||
pybind11::enum_<mlir::concretelang::KeyType>(m, "KeyType")
|
||||
.value("SECRET", mlir::concretelang::KeyType::SECRET)
|
||||
.value("BOOTSTRAP", mlir::concretelang::KeyType::BOOTSTRAP)
|
||||
.value("KEY_SWITCH", mlir::concretelang::KeyType::KEY_SWITCH)
|
||||
.value("PACKING_KEY_SWITCH",
|
||||
mlir::concretelang::KeyType::PACKING_KEY_SWITCH)
|
||||
.export_values();
|
||||
|
||||
pybind11::class_<mlir::concretelang::Statistic>(m, "Statistic")
|
||||
.def_readonly("operation", &mlir::concretelang::Statistic::operation)
|
||||
.def_readonly("location", &mlir::concretelang::Statistic::location)
|
||||
.def_readonly("keys", &mlir::concretelang::Statistic::keys)
|
||||
.def_readonly("count", &mlir::concretelang::Statistic::count);
|
||||
|
||||
pybind11::class_<mlir::concretelang::CompilationFeedback>(
|
||||
m, "CompilationFeedback")
|
||||
.def_readonly("complexity",
|
||||
@@ -132,22 +161,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
.def_readonly(
|
||||
"crt_decompositions_of_outputs",
|
||||
&mlir::concretelang::CompilationFeedback::crtDecompositionsOfOutputs)
|
||||
.def_readonly("total_pbs_count",
|
||||
&mlir::concretelang::CompilationFeedback::totalPbsCount)
|
||||
.def_readonly("total_ks_count",
|
||||
&mlir::concretelang::CompilationFeedback::totalKsCount)
|
||||
.def_readonly(
|
||||
"total_clear_addition_count",
|
||||
&mlir::concretelang::CompilationFeedback::totalClearAdditionCount)
|
||||
.def_readonly(
|
||||
"total_encrypted_addition_count",
|
||||
&mlir::concretelang::CompilationFeedback::totalEncryptedAdditionCount)
|
||||
.def_readonly("total_clear_multiplication_count",
|
||||
&mlir::concretelang::CompilationFeedback::
|
||||
totalClearMultiplicationCount)
|
||||
.def_readonly("total_encrypted_negation_count",
|
||||
&mlir::concretelang::CompilationFeedback::
|
||||
totalEncryptedNegationCount);
|
||||
.def_readonly("statistics",
|
||||
&mlir::concretelang::CompilationFeedback::statistics);
|
||||
|
||||
pybind11::class_<mlir::concretelang::JitCompilationResult>(
|
||||
m, "JITCompilationResult");
|
||||
@@ -320,6 +335,76 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
pybind11::class_<clientlib::KeySetCache>(m, "KeySetCache")
|
||||
.def(pybind11::init<std::string &>());
|
||||
|
||||
pybind11::class_<::concretelang::clientlib::LweSecretKeyParam>(
|
||||
m, "LweSecretKeyParam")
|
||||
.def_readonly("dimension",
|
||||
&::concretelang::clientlib::LweSecretKeyParam::dimension);
|
||||
|
||||
pybind11::class_<::concretelang::clientlib::BootstrapKeyParam>(
|
||||
m, "BootstrapKeyParam")
|
||||
.def_readonly(
|
||||
"input_secret_key_id",
|
||||
&::concretelang::clientlib::BootstrapKeyParam::inputSecretKeyID)
|
||||
.def_readonly(
|
||||
"output_secret_key_id",
|
||||
&::concretelang::clientlib::BootstrapKeyParam::outputSecretKeyID)
|
||||
.def_readonly("level",
|
||||
&::concretelang::clientlib::BootstrapKeyParam::level)
|
||||
.def_readonly("base_log",
|
||||
&::concretelang::clientlib::BootstrapKeyParam::baseLog)
|
||||
.def_readonly(
|
||||
"glwe_dimension",
|
||||
&::concretelang::clientlib::BootstrapKeyParam::glweDimension)
|
||||
.def_readonly("variance",
|
||||
&::concretelang::clientlib::BootstrapKeyParam::variance)
|
||||
.def_readonly(
|
||||
"polynomial_size",
|
||||
&::concretelang::clientlib::BootstrapKeyParam::polynomialSize)
|
||||
.def_readonly(
|
||||
"input_lwe_dimension",
|
||||
&::concretelang::clientlib::BootstrapKeyParam::inputLweDimension);
|
||||
|
||||
pybind11::class_<::concretelang::clientlib::KeyswitchKeyParam>(
|
||||
m, "KeyswitchKeyParam")
|
||||
.def_readonly(
|
||||
"input_secret_key_id",
|
||||
&::concretelang::clientlib::KeyswitchKeyParam::inputSecretKeyID)
|
||||
.def_readonly(
|
||||
"output_secret_key_id",
|
||||
&::concretelang::clientlib::KeyswitchKeyParam::outputSecretKeyID)
|
||||
.def_readonly("level",
|
||||
&::concretelang::clientlib::KeyswitchKeyParam::level)
|
||||
.def_readonly("base_log",
|
||||
&::concretelang::clientlib::KeyswitchKeyParam::baseLog)
|
||||
.def_readonly("variance",
|
||||
&::concretelang::clientlib::KeyswitchKeyParam::variance);
|
||||
|
||||
pybind11::class_<::concretelang::clientlib::PackingKeyswitchKeyParam>(
|
||||
m, "PackingKeyswitchKeyParam")
|
||||
.def_readonly("input_secret_key_id",
|
||||
&::concretelang::clientlib::PackingKeyswitchKeyParam::
|
||||
inputSecretKeyID)
|
||||
.def_readonly("output_secret_key_id",
|
||||
&::concretelang::clientlib::PackingKeyswitchKeyParam::
|
||||
outputSecretKeyID)
|
||||
.def_readonly("level",
|
||||
&::concretelang::clientlib::PackingKeyswitchKeyParam::level)
|
||||
.def_readonly(
|
||||
"base_log",
|
||||
&::concretelang::clientlib::PackingKeyswitchKeyParam::baseLog)
|
||||
.def_readonly(
|
||||
"glwe_dimension",
|
||||
&::concretelang::clientlib::PackingKeyswitchKeyParam::glweDimension)
|
||||
.def_readonly(
|
||||
"polynomial_size",
|
||||
&::concretelang::clientlib::PackingKeyswitchKeyParam::polynomialSize)
|
||||
.def_readonly("input_lwe_dimension",
|
||||
&::concretelang::clientlib::PackingKeyswitchKeyParam::
|
||||
inputLweDimension)
|
||||
.def_readonly(
|
||||
"variance",
|
||||
&::concretelang::clientlib::PackingKeyswitchKeyParam::variance);
|
||||
|
||||
pybind11::class_<mlir::concretelang::ClientParameters>(m, "ClientParameters")
|
||||
.def_static("deserialize",
|
||||
[](const pybind11::bytes &buffer) {
|
||||
@@ -353,7 +438,16 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
}
|
||||
}
|
||||
return result;
|
||||
});
|
||||
})
|
||||
.def_readonly("secret_keys",
|
||||
&mlir::concretelang::ClientParameters::secretKeys)
|
||||
.def_readonly("bootstrap_keys",
|
||||
&mlir::concretelang::ClientParameters::bootstrapKeys)
|
||||
.def_readonly("keyswitch_keys",
|
||||
&mlir::concretelang::ClientParameters::keyswitchKeys)
|
||||
.def_readonly(
|
||||
"packing_keyswitch_keys",
|
||||
&mlir::concretelang::ClientParameters::packingKeyswitchKeys);
|
||||
|
||||
pybind11::class_<clientlib::KeySet>(m, "KeySet")
|
||||
.def_static("deserialize",
|
||||
|
||||
@@ -39,6 +39,7 @@ from .value_decrypter import ValueDecrypter
|
||||
from .value_exporter import ValueExporter
|
||||
from .simulated_value_decrypter import SimulatedValueDecrypter
|
||||
from .simulated_value_exporter import SimulatedValueExporter
|
||||
from .parameter import Parameter
|
||||
|
||||
|
||||
def init_dfr():
|
||||
|
||||
@@ -3,13 +3,19 @@
|
||||
|
||||
"""Compilation feedback."""
|
||||
|
||||
from typing import Dict, Set
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error,too-many-instance-attributes
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
CompilationFeedback as _CompilationFeedback,
|
||||
KeyType,
|
||||
PrimitiveOperation,
|
||||
)
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
|
||||
from .client_parameters import ClientParameters
|
||||
from .parameter import Parameter
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
|
||||
@@ -41,19 +47,153 @@ class CompilationFeedback(WrapperCpp):
|
||||
self.crt_decompositions_of_outputs = (
|
||||
compilation_feedback.crt_decompositions_of_outputs
|
||||
)
|
||||
self.total_pbs_count = compilation_feedback.total_pbs_count
|
||||
self.total_ks_count = compilation_feedback.total_ks_count
|
||||
self.total_clear_addition_count = (
|
||||
compilation_feedback.total_clear_addition_count
|
||||
)
|
||||
self.total_encrypted_addition_count = (
|
||||
compilation_feedback.total_encrypted_addition_count
|
||||
)
|
||||
self.total_clear_multiplication_count = (
|
||||
compilation_feedback.total_clear_multiplication_count
|
||||
)
|
||||
self.total_encrypted_negation_count = (
|
||||
compilation_feedback.total_encrypted_negation_count
|
||||
)
|
||||
self.statistics = compilation_feedback.statistics
|
||||
|
||||
super().__init__(compilation_feedback)
|
||||
|
||||
def count(self, *, operations: Set[PrimitiveOperation]) -> int:
|
||||
"""
|
||||
Count the amount of specified operations in the program.
|
||||
|
||||
Args:
|
||||
operations (Set[PrimitiveOperation]):
|
||||
set of operations used to filter the statistics
|
||||
|
||||
Returns:
|
||||
int:
|
||||
number of specified operations in the program
|
||||
"""
|
||||
|
||||
return sum(
|
||||
statistic.count
|
||||
for statistic in self.statistics
|
||||
if statistic.operation in operations
|
||||
)
|
||||
|
||||
def count_per_parameter(
|
||||
self,
|
||||
*,
|
||||
operations: Set[PrimitiveOperation],
|
||||
key_types: Set[KeyType],
|
||||
client_parameters: ClientParameters,
|
||||
) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Count the amount of specified operations in the program and group by parameters.
|
||||
|
||||
Args:
|
||||
operations (Set[PrimitiveOperation]):
|
||||
set of operations used to filter the statistics
|
||||
|
||||
key_types (Set[KeyType]):
|
||||
set of key types used to filter the statistics
|
||||
|
||||
client_parameters (ClientParameters):
|
||||
client parameters required for grouping by parameters
|
||||
|
||||
Returns:
|
||||
Dict[Parameter, int]:
|
||||
number of specified operations per parameter in the program
|
||||
"""
|
||||
|
||||
result = {}
|
||||
for statistic in self.statistics:
|
||||
if statistic.operation not in operations:
|
||||
continue
|
||||
|
||||
for key_type, key_index in statistic.keys:
|
||||
if key_type not in key_types:
|
||||
continue
|
||||
|
||||
parameter = Parameter(client_parameters, key_type, key_index)
|
||||
if parameter not in result:
|
||||
result[parameter] = 0
|
||||
result[parameter] += statistic.count
|
||||
|
||||
return result
|
||||
|
||||
def count_per_tag(self, *, operations: Set[PrimitiveOperation]) -> Dict[str, int]:
|
||||
"""
|
||||
Count the amount of specified operations in the program and group by tags.
|
||||
|
||||
Args:
|
||||
operations (Set[PrimitiveOperation]):
|
||||
set of operations used to filter the statistics
|
||||
|
||||
Returns:
|
||||
Dict[str, int]:
|
||||
number of specified operations per tag in the program
|
||||
"""
|
||||
|
||||
result = {}
|
||||
for statistic in self.statistics:
|
||||
if statistic.operation not in operations:
|
||||
continue
|
||||
|
||||
file_and_maybe_tag = statistic.location.split("@")
|
||||
tag = "" if len(file_and_maybe_tag) == 1 else file_and_maybe_tag[1].strip()
|
||||
|
||||
tag_components = tag.split(".")
|
||||
for i in range(1, len(tag_components) + 1):
|
||||
current_tag = ".".join(tag_components[0:i])
|
||||
if current_tag == "":
|
||||
continue
|
||||
|
||||
if current_tag not in result:
|
||||
result[current_tag] = 0
|
||||
|
||||
result[current_tag] += statistic.count
|
||||
|
||||
return result
|
||||
|
||||
def count_per_tag_per_parameter(
|
||||
self,
|
||||
*,
|
||||
operations: Set[PrimitiveOperation],
|
||||
key_types: Set[KeyType],
|
||||
client_parameters: ClientParameters,
|
||||
) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Count the amount of specified operations in the program and group by tags and parameters.
|
||||
|
||||
Args:
|
||||
operations (Set[PrimitiveOperation]):
|
||||
set of operations used to filter the statistics
|
||||
|
||||
key_types (Set[KeyType]):
|
||||
set of key types used to filter the statistics
|
||||
|
||||
client_parameters (ClientParameters):
|
||||
client parameters required for grouping by parameters
|
||||
|
||||
Returns:
|
||||
Dict[str, Dict[Parameter, int]]:
|
||||
number of specified operations per tag per parameter in the program
|
||||
"""
|
||||
|
||||
result: Dict[str, Dict[int, int]] = {}
|
||||
for statistic in self.statistics:
|
||||
if statistic.operation not in operations:
|
||||
continue
|
||||
|
||||
file_and_maybe_tag = statistic.location.split("@")
|
||||
tag = "" if len(file_and_maybe_tag) == 1 else file_and_maybe_tag[1].strip()
|
||||
|
||||
tag_components = tag.split(".")
|
||||
for i in range(1, len(tag_components) + 1):
|
||||
current_tag = ".".join(tag_components[0:i])
|
||||
if current_tag == "":
|
||||
continue
|
||||
|
||||
if current_tag not in result:
|
||||
result[current_tag] = {}
|
||||
|
||||
for key_type, key_index in statistic.keys:
|
||||
if key_type not in key_types:
|
||||
continue
|
||||
|
||||
parameter = Parameter(client_parameters, key_type, key_index)
|
||||
if parameter not in result[current_tag]:
|
||||
result[current_tag][parameter] = 0
|
||||
result[current_tag][parameter] += statistic.count
|
||||
|
||||
return result
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
Parameter.
|
||||
"""
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
|
||||
from typing import Union
|
||||
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
LweSecretKeyParam,
|
||||
BootstrapKeyParam,
|
||||
KeyswitchKeyParam,
|
||||
PackingKeyswitchKeyParam,
|
||||
KeyType,
|
||||
)
|
||||
|
||||
from .client_parameters import ClientParameters
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
|
||||
|
||||
class Parameter:
|
||||
"""
|
||||
An FHE parameter.
|
||||
"""
|
||||
|
||||
_inner: Union[
|
||||
LweSecretKeyParam,
|
||||
BootstrapKeyParam,
|
||||
KeyswitchKeyParam,
|
||||
PackingKeyswitchKeyParam,
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_parameters: ClientParameters,
|
||||
key_type: KeyType,
|
||||
key_index: int,
|
||||
):
|
||||
if key_type == KeyType.SECRET:
|
||||
self._inner = client_parameters.cpp().secret_keys[key_index]
|
||||
elif key_type == KeyType.BOOTSTRAP:
|
||||
self._inner = client_parameters.cpp().bootstrap_keys[key_index]
|
||||
elif key_type == KeyType.KEY_SWITCH:
|
||||
self._inner = client_parameters.cpp().keyswitch_keys[key_index]
|
||||
elif key_type == KeyType.PACKING_KEY_SWITCH:
|
||||
self._inner = client_parameters.cpp().packing_keyswitch_keys[key_index]
|
||||
else:
|
||||
raise ValueError("invalid key type")
|
||||
|
||||
def __getattr__(self, item):
|
||||
return getattr(self._inner, item)
|
||||
|
||||
def __repr__(self):
|
||||
param = self._inner
|
||||
|
||||
if isinstance(param, LweSecretKeyParam):
|
||||
result = f"LweSecretKeyParam(" f"dimension={param.dimension}" f")"
|
||||
|
||||
elif isinstance(param, BootstrapKeyParam):
|
||||
result = (
|
||||
f"BootstrapKeyParam("
|
||||
f"polynomial_size={param.polynomial_size}, "
|
||||
f"glwe_dimension={param.glwe_dimension}, "
|
||||
f"input_lwe_dimension={param.input_lwe_dimension}, "
|
||||
f"level={param.level}, "
|
||||
f"base_log={param.base_log}, "
|
||||
f"variance={param.variance}"
|
||||
f")"
|
||||
)
|
||||
|
||||
elif isinstance(param, KeyswitchKeyParam):
|
||||
result = (
|
||||
f"KeyswitchKeyParam("
|
||||
f"level={param.level}, "
|
||||
f"base_log={param.base_log}, "
|
||||
f"variance={param.variance}"
|
||||
f")"
|
||||
)
|
||||
|
||||
elif isinstance(param, PackingKeyswitchKeyParam):
|
||||
result = (
|
||||
f"PackingKeyswitchKeyParam("
|
||||
f"polynomial_size={param.polynomial_size}, "
|
||||
f"glwe_dimension={param.glwe_dimension}, "
|
||||
f"input_lwe_dimension={param.input_lwe_dimension}"
|
||||
f"level={param.level}, "
|
||||
f"base_log={param.base_log}, "
|
||||
f"variance={param.variance}"
|
||||
f")"
|
||||
)
|
||||
|
||||
else:
|
||||
assert False
|
||||
|
||||
return result
|
||||
|
||||
def __str__(self):
|
||||
return repr(self)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(str(self))
|
||||
|
||||
def __eq__(self, other):
|
||||
return str(self) == str(other)
|
||||
@@ -301,6 +301,7 @@ const CONCRETE_COMPILER_STATIC_LIBS: &[&str] = &[
|
||||
"SDFGDialect",
|
||||
"ExtractSDFGOps",
|
||||
"SDFGToStreamEmulator",
|
||||
"TFHEDialectAnalysis",
|
||||
];
|
||||
|
||||
fn main() {
|
||||
|
||||
@@ -10,6 +10,17 @@ using namespace mlir;
|
||||
|
||||
using TFHE::ExtractTFHEStatisticsPass;
|
||||
|
||||
// #########
|
||||
// Utilities
|
||||
// #########
|
||||
|
||||
template <typename Op> std::string locationOf(Op op) {
|
||||
auto location = std::string();
|
||||
auto locationStream = llvm::raw_string_ostream(location);
|
||||
op.getLoc()->print(locationStream);
|
||||
return location.substr(5, location.size() - 2 - 5); // remove loc(" and ")
|
||||
}
|
||||
|
||||
// #######
|
||||
// scf.for
|
||||
// #######
|
||||
@@ -103,7 +114,24 @@ static std::optional<StringError> on_exit(scf::ForOp &op,
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::AddGLWEOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
pass.feedback.totalEncryptedAdditionCount += pass.iterations;
|
||||
auto resultingKey = op.getType().getKey().getNormalized();
|
||||
|
||||
auto location = locationOf(op);
|
||||
auto operation = PrimitiveOperation::ENCRYPTED_ADDITION;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
@@ -113,7 +141,24 @@ static std::optional<StringError> on_enter(TFHE::AddGLWEOp &op,
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::AddGLWEIntOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
pass.feedback.totalClearAdditionCount += pass.iterations;
|
||||
auto resultingKey = op.getType().getKey().getNormalized();
|
||||
|
||||
auto location = locationOf(op);
|
||||
auto operation = PrimitiveOperation::CLEAR_ADDITION;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
@@ -123,7 +168,24 @@ static std::optional<StringError> on_enter(TFHE::AddGLWEIntOp &op,
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::BootstrapGLWEOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
pass.feedback.totalPbsCount += pass.iterations;
|
||||
auto bsk = op.getKey();
|
||||
|
||||
auto location = locationOf(op);
|
||||
auto operation = PrimitiveOperation::PBS;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex());
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
@@ -133,7 +195,24 @@ static std::optional<StringError> on_enter(TFHE::BootstrapGLWEOp &op,
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::KeySwitchGLWEOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
pass.feedback.totalKsCount += pass.iterations;
|
||||
auto ksk = op.getKey();
|
||||
|
||||
auto location = locationOf(op);
|
||||
auto operation = PrimitiveOperation::KEY_SWITCH;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex());
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
@@ -143,7 +222,24 @@ static std::optional<StringError> on_enter(TFHE::KeySwitchGLWEOp &op,
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::MulGLWEIntOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
pass.feedback.totalClearMultiplicationCount += pass.iterations;
|
||||
auto resultingKey = op.getType().getKey().getNormalized();
|
||||
|
||||
auto location = locationOf(op);
|
||||
auto operation = PrimitiveOperation::CLEAR_MULTIPLICATION;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
@@ -153,7 +249,24 @@ static std::optional<StringError> on_enter(TFHE::MulGLWEIntOp &op,
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::NegGLWEOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
pass.feedback.totalEncryptedNegationCount += pass.iterations;
|
||||
auto resultingKey = op.getType().getKey().getNormalized();
|
||||
|
||||
auto location = locationOf(op);
|
||||
auto operation = PrimitiveOperation::ENCRYPTED_NEGATION;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
@@ -163,9 +276,34 @@ static std::optional<StringError> on_enter(TFHE::NegGLWEOp &op,
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::SubGLWEIntOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
auto resultingKey = op.getType().getKey().getNormalized();
|
||||
|
||||
auto location = locationOf(op);
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
|
||||
keys.push_back(key);
|
||||
|
||||
// clear - encrypted = clear + neg(encrypted)
|
||||
pass.feedback.totalEncryptedNegationCount += pass.iterations;
|
||||
pass.feedback.totalClearAdditionCount += pass.iterations;
|
||||
|
||||
auto operation = PrimitiveOperation::ENCRYPTED_NEGATION;
|
||||
pass.feedback.statistics.push_back(Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
operation = PrimitiveOperation::CLEAR_ADDITION;
|
||||
pass.feedback.statistics.push_back(Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
@@ -175,7 +313,32 @@ static std::optional<StringError> on_enter(TFHE::SubGLWEIntOp &op,
|
||||
|
||||
static std::optional<StringError> on_enter(TFHE::WopPBSGLWEOp &op,
|
||||
ExtractTFHEStatisticsPass &pass) {
|
||||
pass.feedback.totalPbsCount += pass.iterations;
|
||||
auto bsk = op.getBsk();
|
||||
auto ksk = op.getKsk();
|
||||
auto pksk = op.getPksk();
|
||||
|
||||
auto location = locationOf(op);
|
||||
auto operation = PrimitiveOperation::WOP_PBS;
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
auto count = pass.iterations;
|
||||
|
||||
std::pair<KeyType, size_t> key =
|
||||
std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex());
|
||||
keys.push_back(key);
|
||||
|
||||
key = std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex());
|
||||
keys.push_back(key);
|
||||
|
||||
key = std::make_pair(KeyType::PACKING_KEY_SWITCH, (size_t)pksk.getIndex());
|
||||
keys.push_back(key);
|
||||
|
||||
pass.feedback.statistics.push_back(Statistic{
|
||||
location,
|
||||
operation,
|
||||
keys,
|
||||
count,
|
||||
});
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
#include <cassert>
|
||||
#include <fstream>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
#include "llvm/Support/JSON.h"
|
||||
|
||||
#include "concretelang/Support/CompilationFeedback.h"
|
||||
@@ -62,13 +61,6 @@ void CompilationFeedback::fillFromClientParameters(
|
||||
}
|
||||
crtDecompositionsOfOutputs.push_back(decomposition);
|
||||
}
|
||||
// Stats
|
||||
totalPbsCount = 0;
|
||||
totalKsCount = 0;
|
||||
totalClearAdditionCount = 0;
|
||||
totalEncryptedAdditionCount = 0;
|
||||
totalClearMultiplicationCount = 0;
|
||||
totalEncryptedNegationCount = 0;
|
||||
}
|
||||
|
||||
outcome::checked<CompilationFeedback, StringError>
|
||||
@@ -81,7 +73,7 @@ CompilationFeedback::load(std::string jsonPath) {
|
||||
}
|
||||
auto expectedCompFeedback = llvm::json::parse<CompilationFeedback>(content);
|
||||
if (auto err = expectedCompFeedback.takeError()) {
|
||||
return StringError("Cannot open client parameters: ")
|
||||
return StringError("Cannot open compilation feedback: ")
|
||||
<< llvm::toString(std::move(err)) << "\n"
|
||||
<< content << "\n";
|
||||
}
|
||||
@@ -99,34 +91,166 @@ llvm::json::Value toJSON(const mlir::concretelang::CompilationFeedback &v) {
|
||||
{"totalInputsSize", v.totalInputsSize},
|
||||
{"totalOutputsSize", v.totalOutputsSize},
|
||||
{"crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs},
|
||||
{"totalPbsCount", v.totalPbsCount},
|
||||
{"totalKsCount", v.totalKsCount},
|
||||
{"totalClearAdditionCount", v.totalClearAdditionCount},
|
||||
{"totalEncryptedAdditionCount", v.totalEncryptedAdditionCount},
|
||||
{"totalClearMultiplicationCount", v.totalClearMultiplicationCount},
|
||||
{"totalEncryptedNegationCount", v.totalEncryptedNegationCount},
|
||||
};
|
||||
|
||||
auto statisticsJson = llvm::json::Array();
|
||||
for (auto statistic : v.statistics) {
|
||||
auto statisticJson = llvm::json::Object();
|
||||
statisticJson.insert({"location", statistic.location});
|
||||
switch (statistic.operation) {
|
||||
case PrimitiveOperation::PBS:
|
||||
statisticJson.insert({"operation", "PBS"});
|
||||
break;
|
||||
case PrimitiveOperation::WOP_PBS:
|
||||
statisticJson.insert({"operation", "WOP_PBS"});
|
||||
break;
|
||||
case PrimitiveOperation::KEY_SWITCH:
|
||||
statisticJson.insert({"operation", "KEY_SWITCH"});
|
||||
break;
|
||||
case PrimitiveOperation::CLEAR_ADDITION:
|
||||
statisticJson.insert({"operation", "CLEAR_ADDITION"});
|
||||
break;
|
||||
case PrimitiveOperation::ENCRYPTED_ADDITION:
|
||||
statisticJson.insert({"operation", "ENCRYPTED_ADDITION"});
|
||||
break;
|
||||
case PrimitiveOperation::CLEAR_MULTIPLICATION:
|
||||
statisticJson.insert({"operation", "CLEAR_MULTIPLICATION"});
|
||||
break;
|
||||
case PrimitiveOperation::ENCRYPTED_NEGATION:
|
||||
statisticJson.insert({"operation", "ENCRYPTED_NEGATION"});
|
||||
break;
|
||||
}
|
||||
auto keysJson = llvm::json::Array();
|
||||
for (auto &key : statistic.keys) {
|
||||
KeyType type = key.first;
|
||||
size_t index = key.second;
|
||||
|
||||
auto keyJson = llvm::json::Array();
|
||||
switch (type) {
|
||||
case KeyType::SECRET:
|
||||
keyJson.push_back("SECRET");
|
||||
break;
|
||||
case KeyType::BOOTSTRAP:
|
||||
keyJson.push_back("BOOTSTRAP");
|
||||
break;
|
||||
case KeyType::KEY_SWITCH:
|
||||
keyJson.push_back("KEY_SWITCH");
|
||||
break;
|
||||
case KeyType::PACKING_KEY_SWITCH:
|
||||
keyJson.push_back("PACKING_KEY_SWITCH");
|
||||
break;
|
||||
}
|
||||
keyJson.push_back((int64_t)index);
|
||||
|
||||
keysJson.push_back(std::move(keyJson));
|
||||
}
|
||||
statisticJson.insert({"keys", std::move(keysJson)});
|
||||
statisticJson.insert({"count", (int64_t)statistic.count});
|
||||
|
||||
statisticsJson.push_back(std::move(statisticJson));
|
||||
}
|
||||
object.insert({"statistics", std::move(statisticsJson)});
|
||||
|
||||
return object;
|
||||
}
|
||||
|
||||
bool fromJSON(const llvm::json::Value j,
|
||||
mlir::concretelang::CompilationFeedback &v, llvm::json::Path p) {
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("complexity", v.complexity) && O.map("pError", v.pError) &&
|
||||
O.map("globalPError", v.globalPError) &&
|
||||
O.map("totalSecretKeysSize", v.totalSecretKeysSize) &&
|
||||
O.map("totalBootstrapKeysSize", v.totalBootstrapKeysSize) &&
|
||||
O.map("totalKeyswitchKeysSize", v.totalKeyswitchKeysSize) &&
|
||||
O.map("totalInputsSize", v.totalInputsSize) &&
|
||||
O.map("totalOutputsSize", v.totalOutputsSize) &&
|
||||
O.map("crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs) &&
|
||||
O.map("totalPbsCount", v.totalPbsCount) &&
|
||||
O.map("totalKsCount", v.totalKsCount) &&
|
||||
O.map("totalClearAdditionCount", v.totalClearAdditionCount) &&
|
||||
O.map("totalEncryptedAdditionCount", v.totalEncryptedAdditionCount) &&
|
||||
O.map("totalClearMultiplicationCount",
|
||||
v.totalClearMultiplicationCount) &&
|
||||
O.map("totalEncryptedNegationCount", v.totalEncryptedNegationCount);
|
||||
|
||||
bool is_success =
|
||||
O && O.map("complexity", v.complexity) && O.map("pError", v.pError) &&
|
||||
O.map("globalPError", v.globalPError) &&
|
||||
O.map("totalSecretKeysSize", v.totalSecretKeysSize) &&
|
||||
O.map("totalBootstrapKeysSize", v.totalBootstrapKeysSize) &&
|
||||
O.map("totalKeyswitchKeysSize", v.totalKeyswitchKeysSize) &&
|
||||
O.map("totalInputsSize", v.totalInputsSize) &&
|
||||
O.map("totalOutputsSize", v.totalOutputsSize) &&
|
||||
O.map("crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs);
|
||||
|
||||
if (!is_success) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto object = j.getAsObject();
|
||||
if (!object) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto statistics = object->getArray("statistics");
|
||||
if (!statistics) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto statisticValue : *statistics) {
|
||||
auto statistic = statisticValue.getAsObject();
|
||||
if (!statistic) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto location = statistic->getString("location");
|
||||
auto operationStr = statistic->getString("operation");
|
||||
auto keysArray = statistic->getArray("keys");
|
||||
auto count = statistic->getInteger("count");
|
||||
|
||||
if (!operationStr || !location || !keysArray || !count) {
|
||||
return false;
|
||||
}
|
||||
|
||||
PrimitiveOperation operation;
|
||||
if (operationStr.value() == "PBS") {
|
||||
operation = PrimitiveOperation::PBS;
|
||||
} else if (operationStr.value() == "KEY_SWITCH") {
|
||||
operation = PrimitiveOperation::KEY_SWITCH;
|
||||
} else if (operationStr.value() == "WOP_PBS") {
|
||||
operation = PrimitiveOperation::WOP_PBS;
|
||||
} else if (operationStr.value() == "CLEAR_ADDITION") {
|
||||
operation = PrimitiveOperation::CLEAR_ADDITION;
|
||||
} else if (operationStr.value() == "ENCRYPTED_ADDITION") {
|
||||
operation = PrimitiveOperation::ENCRYPTED_ADDITION;
|
||||
} else if (operationStr.value() == "CLEAR_MULTIPLICATION") {
|
||||
operation = PrimitiveOperation::CLEAR_MULTIPLICATION;
|
||||
} else if (operationStr.value() == "ENCRYPTED_NEGATION") {
|
||||
operation = PrimitiveOperation::ENCRYPTED_NEGATION;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto keys = std::vector<std::pair<KeyType, size_t>>();
|
||||
for (auto keyValue : *keysArray) {
|
||||
llvm::json::Array *keyArray = keyValue.getAsArray();
|
||||
if (!keyArray || keyArray->size() != 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto typeStr = keyArray->front().getAsString();
|
||||
auto index = keyArray->back().getAsInteger();
|
||||
|
||||
if (!typeStr || !index) {
|
||||
return false;
|
||||
}
|
||||
|
||||
KeyType type;
|
||||
if (typeStr.value() == "SECRET") {
|
||||
type = KeyType::SECRET;
|
||||
} else if (typeStr.value() == "BOOTSTRAP") {
|
||||
type = KeyType::BOOTSTRAP;
|
||||
} else if (typeStr.value() == "KEY_SWITCH") {
|
||||
type = KeyType::KEY_SWITCH;
|
||||
} else if (typeStr.value() == "PACKING_KEY_SWITCH") {
|
||||
type = KeyType::PACKING_KEY_SWITCH;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
keys.push_back(std::make_pair(type, (size_t)*index));
|
||||
}
|
||||
|
||||
v.statistics.push_back(
|
||||
Statistic{location->str(), operation, keys, (uint64_t)*count});
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
|
||||
@@ -4,7 +4,7 @@ Concrete.
|
||||
|
||||
# pylint: disable=import-error,no-name-in-module
|
||||
|
||||
from concrete.compiler import EvaluationKeys, PublicArguments, PublicResult
|
||||
from concrete.compiler import EvaluationKeys, Parameter, PublicArguments, PublicResult
|
||||
|
||||
from .compilation import (
|
||||
DEFAULT_GLOBAL_P_ERROR,
|
||||
|
||||
@@ -4,10 +4,15 @@ Declaration of `Circuit` class.
|
||||
|
||||
# pylint: disable=import-error,no-member,no-name-in-module
|
||||
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from concrete.compiler import CompilationContext, SimulatedValueDecrypter, SimulatedValueExporter
|
||||
from concrete.compiler import (
|
||||
CompilationContext,
|
||||
Parameter,
|
||||
SimulatedValueDecrypter,
|
||||
SimulatedValueExporter,
|
||||
)
|
||||
from mlir.ir import Module as MlirModule
|
||||
|
||||
from ..internal.utils import assert_that
|
||||
@@ -266,118 +271,15 @@ class Circuit:
|
||||
if hasattr(self, "server"): # pragma: no cover
|
||||
self.server.cleanup()
|
||||
|
||||
@property
|
||||
def size_of_secret_keys(self) -> int:
|
||||
"""
|
||||
Get size of the secret keys of the circuit.
|
||||
"""
|
||||
return self._statistic("size_of_secret_keys")
|
||||
# Properties
|
||||
|
||||
@property
|
||||
def size_of_bootstrap_keys(self) -> int:
|
||||
def _property(self, name: str) -> Any:
|
||||
"""
|
||||
Get size of the bootstrap keys of the circuit.
|
||||
"""
|
||||
return self._statistic("size_of_bootstrap_keys")
|
||||
|
||||
@property
|
||||
def size_of_keyswitch_keys(self) -> int:
|
||||
"""
|
||||
Get size of the key switch keys of the circuit.
|
||||
"""
|
||||
return self._statistic("size_of_keyswitch_keys")
|
||||
|
||||
@property
|
||||
def size_of_inputs(self) -> int:
|
||||
"""
|
||||
Get size of the inputs of the circuit.
|
||||
"""
|
||||
return self._statistic("size_of_inputs")
|
||||
|
||||
@property
|
||||
def size_of_outputs(self) -> int:
|
||||
"""
|
||||
Get size of the outputs of the circuit.
|
||||
"""
|
||||
return self._statistic("size_of_outputs")
|
||||
|
||||
@property
|
||||
def p_error(self) -> int:
|
||||
"""
|
||||
Get probability of error for each simple TLU (on a scalar).
|
||||
"""
|
||||
return self._statistic("p_error")
|
||||
|
||||
@property
|
||||
def global_p_error(self) -> int:
|
||||
"""
|
||||
Get the probability of having at least one simple TLU error during the entire execution.
|
||||
"""
|
||||
return self._statistic("global_p_error")
|
||||
|
||||
@property
|
||||
def complexity(self) -> float:
|
||||
"""
|
||||
Get complexity of the circuit.
|
||||
"""
|
||||
return self._statistic("complexity")
|
||||
|
||||
@property
|
||||
def total_pbs_count(self) -> int:
|
||||
"""
|
||||
Get the total number of programmable bootstraps in the circuit.
|
||||
"""
|
||||
return self._statistic("total_pbs_count")
|
||||
|
||||
@property
|
||||
def total_ks_count(self) -> int:
|
||||
"""
|
||||
Get the total number of key switches in the circuit.
|
||||
"""
|
||||
return self._statistic("total_ks_count")
|
||||
|
||||
@property
|
||||
def total_clear_addition_count(self) -> int:
|
||||
"""
|
||||
Get the total number of clear additions in the circuit.
|
||||
"""
|
||||
return self._statistic("total_clear_addition_count")
|
||||
|
||||
@property
|
||||
def total_encrypted_addition_count(self) -> int:
|
||||
"""
|
||||
Get the total number of encrypted additions in the circuit.
|
||||
"""
|
||||
return self._statistic("total_encrypted_addition_count")
|
||||
|
||||
@property
|
||||
def total_clear_multiplication_count(self) -> int:
|
||||
"""
|
||||
Get the total number of clear multiplications in the circuit.
|
||||
"""
|
||||
return self._statistic("total_clear_multiplication_count")
|
||||
|
||||
@property
|
||||
def total_encrypted_negation_count(self) -> int:
|
||||
"""
|
||||
Get the total number of encrypted negations in the circuit.
|
||||
"""
|
||||
return self._statistic("total_encrypted_negation_count")
|
||||
|
||||
@property
|
||||
def statistics(self) -> dict:
|
||||
"""
|
||||
Get all circuit statistics in a dict.
|
||||
"""
|
||||
return self._statistic("statistics")
|
||||
|
||||
def _statistic(self, name: str) -> Any:
|
||||
"""
|
||||
Get a statistic of the circuit by name.
|
||||
Get a property of the circuit by name.
|
||||
|
||||
Args:
|
||||
name (str):
|
||||
name of the statistic
|
||||
name of the property
|
||||
|
||||
Returns:
|
||||
Any:
|
||||
@@ -391,3 +293,282 @@ class Circuit:
|
||||
self.enable_fhe_execution() # pragma: no cover
|
||||
|
||||
return getattr(self.server, name)
|
||||
|
||||
@property
|
||||
def size_of_secret_keys(self) -> int:
|
||||
"""
|
||||
Get size of the secret keys of the circuit.
|
||||
"""
|
||||
return self._property("size_of_secret_keys") # pragma: no cover
|
||||
|
||||
@property
|
||||
def size_of_bootstrap_keys(self) -> int:
|
||||
"""
|
||||
Get size of the bootstrap keys of the circuit.
|
||||
"""
|
||||
return self._property("size_of_bootstrap_keys") # pragma: no cover
|
||||
|
||||
@property
|
||||
def size_of_keyswitch_keys(self) -> int:
|
||||
"""
|
||||
Get size of the key switch keys of the circuit.
|
||||
"""
|
||||
return self._property("size_of_keyswitch_keys") # pragma: no cover
|
||||
|
||||
@property
|
||||
def size_of_inputs(self) -> int:
|
||||
"""
|
||||
Get size of the inputs of the circuit.
|
||||
"""
|
||||
return self._property("size_of_inputs") # pragma: no cover
|
||||
|
||||
@property
|
||||
def size_of_outputs(self) -> int:
|
||||
"""
|
||||
Get size of the outputs of the circuit.
|
||||
"""
|
||||
return self._property("size_of_outputs") # pragma: no cover
|
||||
|
||||
@property
|
||||
def p_error(self) -> int:
|
||||
"""
|
||||
Get probability of error for each simple TLU (on a scalar).
|
||||
"""
|
||||
return self._property("p_error") # pragma: no cover
|
||||
|
||||
@property
|
||||
def global_p_error(self) -> int:
|
||||
"""
|
||||
Get the probability of having at least one simple TLU error during the entire execution.
|
||||
"""
|
||||
return self._property("global_p_error") # pragma: no cover
|
||||
|
||||
@property
|
||||
def complexity(self) -> float:
|
||||
"""
|
||||
Get complexity of the circuit.
|
||||
"""
|
||||
return self._property("complexity") # pragma: no cover
|
||||
|
||||
# Programmable Bootstrap Statistics
|
||||
|
||||
@property
|
||||
def programmable_bootstrap_count(self) -> int:
|
||||
"""
|
||||
Get the number of programmable bootstraps in the circuit.
|
||||
"""
|
||||
return self._property("programmable_bootstrap_count") # pragma: no cover
|
||||
|
||||
@property
|
||||
def programmable_bootstrap_count_per_parameter(self) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of programmable bootstraps per bit width in the circuit.
|
||||
"""
|
||||
return self._property("programmable_bootstrap_count_per_parameter") # pragma: no cover
|
||||
|
||||
@property
|
||||
def programmable_bootstrap_count_per_tag(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of programmable bootstraps per tag in the circuit.
|
||||
"""
|
||||
return self._property("programmable_bootstrap_count_per_tag") # pragma: no cover
|
||||
|
||||
@property
|
||||
def programmable_bootstrap_count_per_tag_per_parameter(self) -> Dict[str, Dict[int, int]]:
|
||||
"""
|
||||
Get the number of programmable bootstraps per tag per bit width in the circuit.
|
||||
"""
|
||||
return self._property(
|
||||
"programmable_bootstrap_count_per_tag_per_parameter"
|
||||
) # pragma: no cover
|
||||
|
||||
# Key Switch Statistics
|
||||
|
||||
@property
|
||||
def key_switch_count(self) -> int:
|
||||
"""
|
||||
Get the number of key switches in the circuit.
|
||||
"""
|
||||
return self._property("key_switch_count") # pragma: no cover
|
||||
|
||||
@property
|
||||
def key_switch_count_per_parameter(self) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of key switches per parameter in the circuit.
|
||||
"""
|
||||
return self._property("key_switch_count_per_parameter") # pragma: no cover
|
||||
|
||||
@property
|
||||
def key_switch_count_per_tag(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of key switches per tag in the circuit.
|
||||
"""
|
||||
return self._property("key_switch_count_per_tag") # pragma: no cover
|
||||
|
||||
@property
|
||||
def key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of key switches per tag per parameter in the circuit.
|
||||
"""
|
||||
return self._property("key_switch_count_per_tag_per_parameter") # pragma: no cover
|
||||
|
||||
# Packing Key Switch Statistics
|
||||
|
||||
@property
|
||||
def packing_key_switch_count(self) -> int:
|
||||
"""
|
||||
Get the number of packing key switches in the circuit.
|
||||
"""
|
||||
return self._property("packing_key_switch_count") # pragma: no cover
|
||||
|
||||
@property
|
||||
def packing_key_switch_count_per_parameter(self) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of packing key switches per parameter in the circuit.
|
||||
"""
|
||||
return self._property("packing_key_switch_count_per_parameter") # pragma: no cover
|
||||
|
||||
@property
|
||||
def packing_key_switch_count_per_tag(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of packing key switches per tag in the circuit.
|
||||
"""
|
||||
return self._property("packing_key_switch_count_per_tag") # pragma: no cover
|
||||
|
||||
@property
|
||||
def packing_key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of packing key switches per tag per parameter in the circuit.
|
||||
"""
|
||||
return self._property("packing_key_switch_count_per_tag_per_parameter") # pragma: no cover
|
||||
|
||||
# Clear Addition Statistics
|
||||
|
||||
@property
|
||||
def clear_addition_count(self) -> int:
|
||||
"""
|
||||
Get the number of clear additions in the circuit.
|
||||
"""
|
||||
return self._property("clear_addition_count") # pragma: no cover
|
||||
|
||||
@property
|
||||
def clear_addition_count_per_parameter(self) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of clear additions per parameter in the circuit.
|
||||
"""
|
||||
return self._property("clear_addition_count_per_parameter") # pragma: no cover
|
||||
|
||||
@property
|
||||
def clear_addition_count_per_tag(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of clear additions per tag in the circuit.
|
||||
"""
|
||||
return self._property("clear_addition_count_per_tag") # pragma: no cover
|
||||
|
||||
@property
|
||||
def clear_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of clear additions per tag per parameter in the circuit.
|
||||
"""
|
||||
return self._property("clear_addition_count_per_tag_per_parameter") # pragma: no cover
|
||||
|
||||
# Encrypted Addition Statistics
|
||||
|
||||
@property
|
||||
def encrypted_addition_count(self) -> int:
|
||||
"""
|
||||
Get the number of encrypted additions in the circuit.
|
||||
"""
|
||||
return self._property("encrypted_addition_count") # pragma: no cover
|
||||
|
||||
@property
|
||||
def encrypted_addition_count_per_parameter(self) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of encrypted additions per parameter in the circuit.
|
||||
"""
|
||||
return self._property("encrypted_addition_count_per_parameter") # pragma: no cover
|
||||
|
||||
@property
|
||||
def encrypted_addition_count_per_tag(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of encrypted additions per tag in the circuit.
|
||||
"""
|
||||
return self._property("encrypted_addition_count_per_tag") # pragma: no cover
|
||||
|
||||
@property
|
||||
def encrypted_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of encrypted additions per tag per parameter in the circuit.
|
||||
"""
|
||||
return self._property("encrypted_addition_count_per_tag_per_parameter") # pragma: no cover
|
||||
|
||||
# Clear Multiplication Statistics
|
||||
|
||||
@property
|
||||
def clear_multiplication_count(self) -> int:
|
||||
"""
|
||||
Get the number of clear multiplications in the circuit.
|
||||
"""
|
||||
return self._property("clear_multiplication_count") # pragma: no cover
|
||||
|
||||
@property
|
||||
def clear_multiplication_count_per_parameter(self) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of clear multiplications per parameter in the circuit.
|
||||
"""
|
||||
return self._property("clear_multiplication_count_per_parameter") # pragma: no cover
|
||||
|
||||
@property
|
||||
def clear_multiplication_count_per_tag(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of clear multiplications per tag in the circuit.
|
||||
"""
|
||||
return self._property("clear_multiplication_count_per_tag") # pragma: no cover
|
||||
|
||||
@property
|
||||
def clear_multiplication_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of clear multiplications per tag per parameter in the circuit.
|
||||
"""
|
||||
return self._property(
|
||||
"clear_multiplication_count_per_tag_per_parameter"
|
||||
) # pragma: no cover
|
||||
|
||||
# Encrypted Negation Statistics
|
||||
|
||||
@property
|
||||
def encrypted_negation_count(self) -> int:
|
||||
"""
|
||||
Get the number of encrypted negations in the circuit.
|
||||
"""
|
||||
return self._property("encrypted_negation_count") # pragma: no cover
|
||||
|
||||
@property
|
||||
def encrypted_negation_count_per_parameter(self) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of encrypted negations per parameter in the circuit.
|
||||
"""
|
||||
return self._property("encrypted_negation_count_per_parameter") # pragma: no cover
|
||||
|
||||
@property
|
||||
def encrypted_negation_count_per_tag(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of encrypted negations per tag in the circuit.
|
||||
"""
|
||||
return self._property("encrypted_negation_count_per_tag") # pragma: no cover
|
||||
|
||||
@property
|
||||
def encrypted_negation_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of encrypted negations per tag per parameter in the circuit.
|
||||
"""
|
||||
return self._property("encrypted_negation_count_per_tag_per_parameter") # pragma: no cover
|
||||
|
||||
# All Statistics
|
||||
|
||||
@property
|
||||
def statistics(self) -> Dict:
|
||||
"""
|
||||
Get all statistics of the circuit.
|
||||
"""
|
||||
return self._property("statistics") # pragma: no cover
|
||||
|
||||
@@ -485,7 +485,7 @@ class Compiler:
|
||||
)
|
||||
|
||||
columns = 0
|
||||
if show_graph or show_mlir or show_optimizer:
|
||||
if show_graph or show_mlir or show_optimizer or show_statistics:
|
||||
graph = (
|
||||
self.graph.format()
|
||||
if self.configuration.verbose or self.configuration.show_graph
|
||||
@@ -556,8 +556,27 @@ class Compiler:
|
||||
|
||||
print("Statistics")
|
||||
print("-" * columns)
|
||||
for name, value in circuit.statistics.items():
|
||||
print(f"{name}: {value}")
|
||||
|
||||
def pretty(d, indent=0): # pragma: no cover
|
||||
if indent > 0:
|
||||
print("{")
|
||||
|
||||
for key, value in d.items():
|
||||
if isinstance(value, dict) and len(value) == 0:
|
||||
continue
|
||||
|
||||
print(" " * indent + str(key) + ": ", end="")
|
||||
|
||||
if isinstance(value, dict):
|
||||
pretty(value, indent + 1)
|
||||
else:
|
||||
print(value)
|
||||
|
||||
if indent > 0:
|
||||
print(" " * (indent - 1) + "}")
|
||||
|
||||
pretty(circuit.statistics)
|
||||
|
||||
print("-" * columns)
|
||||
|
||||
print()
|
||||
|
||||
@@ -8,7 +8,7 @@ import json
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
# mypy: disable-error-code=attr-defined
|
||||
import concrete.compiler
|
||||
@@ -23,11 +23,12 @@ from concrete.compiler import (
|
||||
LibraryCompilationResult,
|
||||
LibraryLambda,
|
||||
LibrarySupport,
|
||||
Parameter,
|
||||
PublicArguments,
|
||||
set_compiler_logging,
|
||||
set_llvm_debug_flag,
|
||||
)
|
||||
from mlir._mlir_libs._concretelang._compiler import OptimizerStrategy
|
||||
from mlir._mlir_libs._concretelang._compiler import KeyType, OptimizerStrategy, PrimitiveOperation
|
||||
from mlir.ir import Module as MlirModule
|
||||
|
||||
from ..internal.utils import assert_that
|
||||
@@ -435,66 +436,333 @@ class Server:
|
||||
"""
|
||||
return self._compilation_feedback.complexity
|
||||
|
||||
@property
|
||||
def total_pbs_count(self) -> int:
|
||||
"""
|
||||
Get the total number of programmable bootstraps in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.total_pbs_count
|
||||
# Programmable Bootstrap Statistics
|
||||
|
||||
@property
|
||||
def total_ks_count(self) -> int:
|
||||
def programmable_bootstrap_count(self) -> int:
|
||||
"""
|
||||
Get the total number of key switches in the compiled program.
|
||||
Get the number of programmable bootstraps in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.total_ks_count
|
||||
return self._compilation_feedback.count(
|
||||
operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS},
|
||||
)
|
||||
|
||||
@property
|
||||
def total_clear_addition_count(self) -> int:
|
||||
def programmable_bootstrap_count_per_parameter(self) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the total number of clear additions in the compiled program.
|
||||
Get the number of programmable bootstraps per parameter in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.total_clear_addition_count
|
||||
return self._compilation_feedback.count_per_parameter(
|
||||
operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS},
|
||||
key_types={KeyType.BOOTSTRAP},
|
||||
client_parameters=self.client_specs.client_parameters,
|
||||
)
|
||||
|
||||
@property
|
||||
def total_encrypted_addition_count(self) -> int:
|
||||
def programmable_bootstrap_count_per_tag(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the total number of encrypted additions in the compiled program.
|
||||
Get the number of programmable bootstraps per tag in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.total_encrypted_addition_count
|
||||
return self._compilation_feedback.count_per_tag(
|
||||
operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS},
|
||||
)
|
||||
|
||||
@property
|
||||
def total_clear_multiplication_count(self) -> int:
|
||||
def programmable_bootstrap_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the total number of clear multiplications in the compiled program.
|
||||
Get the number of programmable bootstraps per tag per parameter in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.total_clear_multiplication_count
|
||||
return self._compilation_feedback.count_per_tag_per_parameter(
|
||||
operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS},
|
||||
key_types={KeyType.BOOTSTRAP},
|
||||
client_parameters=self.client_specs.client_parameters,
|
||||
)
|
||||
|
||||
# Key Switch Statistics
|
||||
|
||||
@property
|
||||
def total_encrypted_negation_count(self) -> int:
|
||||
def key_switch_count(self) -> int:
|
||||
"""
|
||||
Get the total number of encrypted negations in the compiled program.
|
||||
Get the number of key switches in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.total_encrypted_negation_count
|
||||
return self._compilation_feedback.count(
|
||||
operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS},
|
||||
)
|
||||
|
||||
@property
|
||||
def statistics(self) -> dict:
|
||||
def key_switch_count_per_parameter(self) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get all program statistics in a dict.
|
||||
Get the number of key switches per parameter in the compiled program.
|
||||
"""
|
||||
return {
|
||||
"size_of_secret_keys": self.size_of_secret_keys,
|
||||
"size_of_bootstrap_keys": self.size_of_bootstrap_keys,
|
||||
"size_of_keyswitch_keys": self.size_of_keyswitch_keys,
|
||||
"size_of_inputs": self.size_of_inputs,
|
||||
"size_of_outputs": self.size_of_outputs,
|
||||
"p_error": self.p_error,
|
||||
"global_p_error": self.global_p_error,
|
||||
"complexity": self.complexity,
|
||||
"total_pbs_count": self.total_pbs_count,
|
||||
"total_ks_count": self.total_ks_count,
|
||||
"total_clear_addition_count": self.total_clear_addition_count,
|
||||
"total_encrypted_addition_count": self.total_encrypted_addition_count,
|
||||
"total_clear_multiplication_count": self.total_clear_multiplication_count,
|
||||
"total_encrypted_negation_count": self.total_encrypted_negation_count,
|
||||
}
|
||||
return self._compilation_feedback.count_per_parameter(
|
||||
operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS},
|
||||
key_types={KeyType.KEY_SWITCH},
|
||||
client_parameters=self.client_specs.client_parameters,
|
||||
)
|
||||
|
||||
@property
|
||||
def key_switch_count_per_tag(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of key switches per tag in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.count_per_tag(
|
||||
operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS},
|
||||
)
|
||||
|
||||
@property
|
||||
def key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of key switches per tag per parameter in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.count_per_tag_per_parameter(
|
||||
operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS},
|
||||
key_types={KeyType.KEY_SWITCH},
|
||||
client_parameters=self.client_specs.client_parameters,
|
||||
)
|
||||
|
||||
# Packing Key Switch Statistics
|
||||
|
||||
@property
|
||||
def packing_key_switch_count(self) -> int:
|
||||
"""
|
||||
Get the number of packing key switches in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.count(operations={PrimitiveOperation.WOP_PBS})
|
||||
|
||||
@property
|
||||
def packing_key_switch_count_per_parameter(self) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of packing key switches per parameter in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.count_per_parameter(
|
||||
operations={PrimitiveOperation.WOP_PBS},
|
||||
key_types={KeyType.PACKING_KEY_SWITCH},
|
||||
client_parameters=self.client_specs.client_parameters,
|
||||
)
|
||||
|
||||
@property
|
||||
def packing_key_switch_count_per_tag(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of packing key switches per tag in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.count_per_tag(operations={PrimitiveOperation.WOP_PBS})
|
||||
|
||||
@property
|
||||
def packing_key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of packing key switches per tag per parameter in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.count_per_tag_per_parameter(
|
||||
operations={PrimitiveOperation.WOP_PBS},
|
||||
key_types={KeyType.PACKING_KEY_SWITCH},
|
||||
client_parameters=self.client_specs.client_parameters,
|
||||
)
|
||||
|
||||
# Clear Addition Statistics
|
||||
|
||||
@property
|
||||
def clear_addition_count(self) -> int:
|
||||
"""
|
||||
Get the number of clear additions in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.count(operations={PrimitiveOperation.CLEAR_ADDITION})
|
||||
|
||||
@property
|
||||
def clear_addition_count_per_parameter(self) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of clear additions per parameter in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.count_per_parameter(
|
||||
operations={PrimitiveOperation.CLEAR_ADDITION},
|
||||
key_types={KeyType.SECRET},
|
||||
client_parameters=self.client_specs.client_parameters,
|
||||
)
|
||||
|
||||
@property
|
||||
def clear_addition_count_per_tag(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of clear additions per tag in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.count_per_tag(
|
||||
operations={PrimitiveOperation.CLEAR_ADDITION},
|
||||
)
|
||||
|
||||
@property
|
||||
def clear_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of clear additions per tag per parameter in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.count_per_tag_per_parameter(
|
||||
operations={PrimitiveOperation.CLEAR_ADDITION},
|
||||
key_types={KeyType.SECRET},
|
||||
client_parameters=self.client_specs.client_parameters,
|
||||
)
|
||||
|
||||
# Encrypted Addition Statistics
|
||||
|
||||
@property
|
||||
def encrypted_addition_count(self) -> int:
|
||||
"""
|
||||
Get the number of encrypted additions in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.count(operations={PrimitiveOperation.ENCRYPTED_ADDITION})
|
||||
|
||||
@property
|
||||
def encrypted_addition_count_per_parameter(self) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of encrypted additions per parameter in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.count_per_parameter(
|
||||
operations={PrimitiveOperation.ENCRYPTED_ADDITION},
|
||||
key_types={KeyType.SECRET},
|
||||
client_parameters=self.client_specs.client_parameters,
|
||||
)
|
||||
|
||||
@property
|
||||
def encrypted_addition_count_per_tag(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of encrypted additions per tag in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.count_per_tag(
|
||||
operations={PrimitiveOperation.ENCRYPTED_ADDITION},
|
||||
)
|
||||
|
||||
@property
|
||||
def encrypted_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of encrypted additions per tag per parameter in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.count_per_tag_per_parameter(
|
||||
operations={PrimitiveOperation.ENCRYPTED_ADDITION},
|
||||
key_types={KeyType.SECRET},
|
||||
client_parameters=self.client_specs.client_parameters,
|
||||
)
|
||||
|
||||
# Clear Multiplication Statistics
|
||||
|
||||
@property
|
||||
def clear_multiplication_count(self) -> int:
|
||||
"""
|
||||
Get the number of clear multiplications in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.count(
|
||||
operations={PrimitiveOperation.CLEAR_MULTIPLICATION},
|
||||
)
|
||||
|
||||
@property
|
||||
def clear_multiplication_count_per_parameter(self) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of clear multiplications per parameter in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.count_per_parameter(
|
||||
operations={PrimitiveOperation.CLEAR_MULTIPLICATION},
|
||||
key_types={KeyType.SECRET},
|
||||
client_parameters=self.client_specs.client_parameters,
|
||||
)
|
||||
|
||||
@property
|
||||
def clear_multiplication_count_per_tag(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of clear multiplications per tag in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.count_per_tag(
|
||||
operations={PrimitiveOperation.CLEAR_MULTIPLICATION},
|
||||
)
|
||||
|
||||
@property
|
||||
def clear_multiplication_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of clear multiplications per tag per parameter in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.count_per_tag_per_parameter(
|
||||
operations={PrimitiveOperation.CLEAR_MULTIPLICATION},
|
||||
key_types={KeyType.SECRET},
|
||||
client_parameters=self.client_specs.client_parameters,
|
||||
)
|
||||
|
||||
# Encrypted Negation Statistics
|
||||
|
||||
@property
|
||||
def encrypted_negation_count(self) -> int:
|
||||
"""
|
||||
Get the number of encrypted negations in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.count(operations={PrimitiveOperation.ENCRYPTED_NEGATION})
|
||||
|
||||
@property
|
||||
def encrypted_negation_count_per_parameter(self) -> Dict[Parameter, int]:
|
||||
"""
|
||||
Get the number of encrypted negations per parameter in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.count_per_parameter(
|
||||
operations={PrimitiveOperation.ENCRYPTED_NEGATION},
|
||||
key_types={KeyType.SECRET},
|
||||
client_parameters=self.client_specs.client_parameters,
|
||||
)
|
||||
|
||||
@property
|
||||
def encrypted_negation_count_per_tag(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the number of encrypted negations per tag in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.count_per_tag(
|
||||
operations={PrimitiveOperation.ENCRYPTED_NEGATION},
|
||||
)
|
||||
|
||||
@property
|
||||
def encrypted_negation_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
|
||||
"""
|
||||
Get the number of encrypted negations per tag per parameter in the compiled program.
|
||||
"""
|
||||
return self._compilation_feedback.count_per_tag_per_parameter(
|
||||
operations={PrimitiveOperation.ENCRYPTED_NEGATION},
|
||||
key_types={KeyType.SECRET},
|
||||
client_parameters=self.client_specs.client_parameters,
|
||||
)
|
||||
|
||||
# All Statistics
|
||||
|
||||
@property
|
||||
def statistics(self) -> Dict:
|
||||
"""
|
||||
Get all statistics of the compiled program.
|
||||
"""
|
||||
attributes = [
|
||||
"size_of_secret_keys",
|
||||
"size_of_bootstrap_keys",
|
||||
"size_of_keyswitch_keys",
|
||||
"size_of_inputs",
|
||||
"size_of_outputs",
|
||||
"p_error",
|
||||
"global_p_error",
|
||||
"complexity",
|
||||
"programmable_bootstrap_count",
|
||||
"programmable_bootstrap_count_per_parameter",
|
||||
"programmable_bootstrap_count_per_tag",
|
||||
"programmable_bootstrap_count_per_tag_per_parameter",
|
||||
"key_switch_count",
|
||||
"key_switch_count_per_parameter",
|
||||
"key_switch_count_per_tag",
|
||||
"key_switch_count_per_tag_per_parameter",
|
||||
"packing_key_switch_count",
|
||||
"packing_key_switch_count_per_parameter",
|
||||
"packing_key_switch_count_per_tag",
|
||||
"packing_key_switch_count_per_tag_per_parameter",
|
||||
"clear_addition_count",
|
||||
"clear_addition_count_per_parameter",
|
||||
"clear_addition_count_per_tag",
|
||||
"clear_addition_count_per_tag_per_parameter",
|
||||
"encrypted_addition_count",
|
||||
"encrypted_addition_count_per_parameter",
|
||||
"encrypted_addition_count_per_tag",
|
||||
"encrypted_addition_count_per_tag_per_parameter",
|
||||
"clear_multiplication_count",
|
||||
"clear_multiplication_count_per_parameter",
|
||||
"clear_multiplication_count_per_tag",
|
||||
"clear_multiplication_count_per_tag_per_parameter",
|
||||
"encrypted_negation_count",
|
||||
"encrypted_negation_count_per_parameter",
|
||||
"encrypted_negation_count_per_tag",
|
||||
"encrypted_negation_count_per_tag_per_parameter",
|
||||
]
|
||||
return {attribute: getattr(self, attribute) for attribute in attributes}
|
||||
|
||||
@@ -523,6 +523,20 @@ def test_circuit_compile_sim_only(helpers):
|
||||
assert f(*inputset[0]) == circuit.simulate(*inputset[0])
|
||||
|
||||
|
||||
def tagged_function(x, y, z):
|
||||
"""
|
||||
A tagged function to test statistics.
|
||||
"""
|
||||
with fhe.tag("a"):
|
||||
x = fhe.univariate(lambda v: v)(x)
|
||||
with fhe.tag("b"):
|
||||
y = fhe.univariate(lambda v: v)(y)
|
||||
with fhe.tag("c"):
|
||||
z = fhe.univariate(lambda v: v)(z)
|
||||
|
||||
return x + y + z
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,expected_statistics",
|
||||
[
|
||||
@@ -532,12 +546,11 @@ def test_circuit_compile_sim_only(helpers):
|
||||
"x": {"status": "encrypted", "range": [0, 10], "shape": ()},
|
||||
},
|
||||
{
|
||||
"total_pbs_count": 1,
|
||||
"total_ks_count": 1,
|
||||
"total_clear_addition_count": 0,
|
||||
"total_encrypted_addition_count": 0,
|
||||
"total_clear_multiplication_count": 0,
|
||||
"total_encrypted_negation_count": 0,
|
||||
"programmable_bootstrap_count": 1,
|
||||
"clear_addition_count": 0,
|
||||
"encrypted_addition_count": 0,
|
||||
"clear_multiplication_count": 0,
|
||||
"encrypted_negation_count": 0,
|
||||
},
|
||||
id="x**2 | x.is_encrypted | x.shape == ()",
|
||||
),
|
||||
@@ -547,11 +560,11 @@ def test_circuit_compile_sim_only(helpers):
|
||||
"x": {"status": "encrypted", "range": [0, 10], "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"total_pbs_count": 3,
|
||||
"total_clear_addition_count": 0,
|
||||
"total_encrypted_addition_count": 0,
|
||||
"total_clear_multiplication_count": 0,
|
||||
"total_encrypted_negation_count": 0,
|
||||
"programmable_bootstrap_count": 3,
|
||||
"clear_addition_count": 0,
|
||||
"encrypted_addition_count": 0,
|
||||
"clear_multiplication_count": 0,
|
||||
"encrypted_negation_count": 0,
|
||||
},
|
||||
id="x**2 | x.is_encrypted | x.shape == (3,)",
|
||||
),
|
||||
@@ -561,11 +574,11 @@ def test_circuit_compile_sim_only(helpers):
|
||||
"x": {"status": "encrypted", "range": [0, 10], "shape": (3, 2)},
|
||||
},
|
||||
{
|
||||
"total_pbs_count": 3 * 2,
|
||||
"total_clear_addition_count": 0,
|
||||
"total_encrypted_addition_count": 0,
|
||||
"total_clear_multiplication_count": 0,
|
||||
"total_encrypted_negation_count": 0,
|
||||
"programmable_bootstrap_count": 3 * 2,
|
||||
"clear_addition_count": 0,
|
||||
"encrypted_addition_count": 0,
|
||||
"clear_multiplication_count": 0,
|
||||
"encrypted_negation_count": 0,
|
||||
},
|
||||
id="x**2 | x.is_encrypted | x.shape == (3, 2)",
|
||||
),
|
||||
@@ -576,11 +589,11 @@ def test_circuit_compile_sim_only(helpers):
|
||||
"y": {"status": "encrypted", "range": [0, 10], "shape": ()},
|
||||
},
|
||||
{
|
||||
"total_pbs_count": 2,
|
||||
"total_clear_addition_count": 1,
|
||||
"total_encrypted_addition_count": 3,
|
||||
"total_clear_multiplication_count": 0,
|
||||
"total_encrypted_negation_count": 2,
|
||||
"programmable_bootstrap_count": 2,
|
||||
"clear_addition_count": 1,
|
||||
"encrypted_addition_count": 3,
|
||||
"clear_multiplication_count": 0,
|
||||
"encrypted_negation_count": 2,
|
||||
},
|
||||
id="x * y | x.is_encrypted | x.shape == () | y.is_encrypted | y.shape == ()",
|
||||
),
|
||||
@@ -591,11 +604,11 @@ def test_circuit_compile_sim_only(helpers):
|
||||
"y": {"status": "encrypted", "range": [0, 10], "shape": (3,)},
|
||||
},
|
||||
{
|
||||
"total_pbs_count": 3 * 2,
|
||||
"total_clear_addition_count": 3 * 1,
|
||||
"total_encrypted_addition_count": 3 * 3,
|
||||
"total_clear_multiplication_count": 0,
|
||||
"total_encrypted_negation_count": 3 * 2,
|
||||
"programmable_bootstrap_count": 3 * 2,
|
||||
"clear_addition_count": 3 * 1,
|
||||
"encrypted_addition_count": 3 * 3,
|
||||
"clear_multiplication_count": 0,
|
||||
"encrypted_negation_count": 3 * 2,
|
||||
},
|
||||
id="x * y | x.is_encrypted | x.shape == (3,) | y.is_encrypted | y.shape == (3,)",
|
||||
),
|
||||
@@ -606,14 +619,30 @@ def test_circuit_compile_sim_only(helpers):
|
||||
"y": {"status": "encrypted", "range": [0, 10], "shape": (3, 2)},
|
||||
},
|
||||
{
|
||||
"total_pbs_count": 3 * 2 * 2,
|
||||
"total_clear_addition_count": 3 * 2 * 1,
|
||||
"total_encrypted_addition_count": 3 * 2 * 3,
|
||||
"total_clear_multiplication_count": 0,
|
||||
"total_encrypted_negation_count": 3 * 2 * 2,
|
||||
"programmable_bootstrap_count": 3 * 2 * 2,
|
||||
"clear_addition_count": 3 * 2 * 1,
|
||||
"encrypted_addition_count": 3 * 2 * 3,
|
||||
"clear_multiplication_count": 0,
|
||||
"encrypted_negation_count": 3 * 2 * 2,
|
||||
},
|
||||
id="x * y | x.is_encrypted | x.shape == (3, 2) | y.is_encrypted | y.shape == (3, 2)",
|
||||
),
|
||||
pytest.param(
|
||||
tagged_function,
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [0, 2**3 - 1], "shape": ()},
|
||||
"y": {"status": "encrypted", "range": [0, 2**4 - 1], "shape": ()},
|
||||
"z": {"status": "encrypted", "range": [0, 2**5 - 1], "shape": ()},
|
||||
},
|
||||
{
|
||||
"programmable_bootstrap_count_per_tag": {
|
||||
"a": 3,
|
||||
"a.b": 2,
|
||||
"a.b.c": 1,
|
||||
},
|
||||
},
|
||||
id="tagged_function",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_statistics(function, parameters, expected_statistics, helpers):
|
||||
|
||||
Reference in New Issue
Block a user