feat(compiler): add more detailed statistics

This commit is contained in:
Umut
2023-07-26 14:27:32 +02:00
parent dae31f0f26
commit ade83d5335
14 changed files with 1407 additions and 272 deletions

View File

@@ -18,6 +18,30 @@ namespace concretelang {
using StringError = ::concretelang::error::StringError;
enum class PrimitiveOperation {
PBS,
WOP_PBS,
KEY_SWITCH,
CLEAR_ADDITION,
ENCRYPTED_ADDITION,
CLEAR_MULTIPLICATION,
ENCRYPTED_NEGATION,
};
enum class KeyType {
SECRET,
BOOTSTRAP,
KEY_SWITCH,
PACKING_KEY_SWITCH,
};
struct Statistic {
std::string location;
PrimitiveOperation operation;
std::vector<std::pair<KeyType, size_t>> keys;
size_t count;
};
struct CompilationFeedback {
double complexity;
@@ -45,23 +69,8 @@ struct CompilationFeedback {
/// @brief crt decomposition of outputs, if crt is not used, empty vectors
std::vector<std::vector<int64_t>> crtDecompositionsOfOutputs;
/// @brief number of programmable bootstraps in the entire circuit
uint64_t totalPbsCount = 0;
/// @brief number of key switches in the entire circuit
uint64_t totalKsCount = 0;
/// @brief number of clear additions in the entire circuit
uint64_t totalClearAdditionCount = 0;
/// @brief number of encrypted additions in the entire circuit
uint64_t totalEncryptedAdditionCount = 0;
/// @brief number of clear multiplications in the entire circuit
uint64_t totalClearMultiplicationCount = 0;
/// @brief number of encrypted negations in the entire circuit
uint64_t totalEncryptedNegationCount = 0;
/// @brief statistics
std::vector<Statistic> statistics;
/// Fill the sizes from the client parameters.
void

View File

@@ -57,6 +57,7 @@ declare_mlir_python_sources(
concrete/compiler/library_compilation_result.py
concrete/compiler/library_support.py
concrete/compiler/library_lambda.py
concrete/compiler/parameter.py
concrete/compiler/public_arguments.py
concrete/compiler/public_result.py
concrete/compiler/evaluation_keys.py

View File

@@ -109,6 +109,35 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
options.simulate = simulate;
});
pybind11::enum_<mlir::concretelang::PrimitiveOperation>(m,
"PrimitiveOperation")
.value("PBS", mlir::concretelang::PrimitiveOperation::PBS)
.value("WOP_PBS", mlir::concretelang::PrimitiveOperation::WOP_PBS)
.value("KEY_SWITCH", mlir::concretelang::PrimitiveOperation::KEY_SWITCH)
.value("CLEAR_ADDITION",
mlir::concretelang::PrimitiveOperation::CLEAR_ADDITION)
.value("ENCRYPTED_ADDITION",
mlir::concretelang::PrimitiveOperation::ENCRYPTED_ADDITION)
.value("CLEAR_MULTIPLICATION",
mlir::concretelang::PrimitiveOperation::CLEAR_MULTIPLICATION)
.value("ENCRYPTED_NEGATION",
mlir::concretelang::PrimitiveOperation::ENCRYPTED_NEGATION)
.export_values();
pybind11::enum_<mlir::concretelang::KeyType>(m, "KeyType")
.value("SECRET", mlir::concretelang::KeyType::SECRET)
.value("BOOTSTRAP", mlir::concretelang::KeyType::BOOTSTRAP)
.value("KEY_SWITCH", mlir::concretelang::KeyType::KEY_SWITCH)
.value("PACKING_KEY_SWITCH",
mlir::concretelang::KeyType::PACKING_KEY_SWITCH)
.export_values();
pybind11::class_<mlir::concretelang::Statistic>(m, "Statistic")
.def_readonly("operation", &mlir::concretelang::Statistic::operation)
.def_readonly("location", &mlir::concretelang::Statistic::location)
.def_readonly("keys", &mlir::concretelang::Statistic::keys)
.def_readonly("count", &mlir::concretelang::Statistic::count);
pybind11::class_<mlir::concretelang::CompilationFeedback>(
m, "CompilationFeedback")
.def_readonly("complexity",
@@ -132,22 +161,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.def_readonly(
"crt_decompositions_of_outputs",
&mlir::concretelang::CompilationFeedback::crtDecompositionsOfOutputs)
.def_readonly("total_pbs_count",
&mlir::concretelang::CompilationFeedback::totalPbsCount)
.def_readonly("total_ks_count",
&mlir::concretelang::CompilationFeedback::totalKsCount)
.def_readonly(
"total_clear_addition_count",
&mlir::concretelang::CompilationFeedback::totalClearAdditionCount)
.def_readonly(
"total_encrypted_addition_count",
&mlir::concretelang::CompilationFeedback::totalEncryptedAdditionCount)
.def_readonly("total_clear_multiplication_count",
&mlir::concretelang::CompilationFeedback::
totalClearMultiplicationCount)
.def_readonly("total_encrypted_negation_count",
&mlir::concretelang::CompilationFeedback::
totalEncryptedNegationCount);
.def_readonly("statistics",
&mlir::concretelang::CompilationFeedback::statistics);
pybind11::class_<mlir::concretelang::JitCompilationResult>(
m, "JITCompilationResult");
@@ -320,6 +335,76 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
pybind11::class_<clientlib::KeySetCache>(m, "KeySetCache")
.def(pybind11::init<std::string &>());
pybind11::class_<::concretelang::clientlib::LweSecretKeyParam>(
m, "LweSecretKeyParam")
.def_readonly("dimension",
&::concretelang::clientlib::LweSecretKeyParam::dimension);
pybind11::class_<::concretelang::clientlib::BootstrapKeyParam>(
m, "BootstrapKeyParam")
.def_readonly(
"input_secret_key_id",
&::concretelang::clientlib::BootstrapKeyParam::inputSecretKeyID)
.def_readonly(
"output_secret_key_id",
&::concretelang::clientlib::BootstrapKeyParam::outputSecretKeyID)
.def_readonly("level",
&::concretelang::clientlib::BootstrapKeyParam::level)
.def_readonly("base_log",
&::concretelang::clientlib::BootstrapKeyParam::baseLog)
.def_readonly(
"glwe_dimension",
&::concretelang::clientlib::BootstrapKeyParam::glweDimension)
.def_readonly("variance",
&::concretelang::clientlib::BootstrapKeyParam::variance)
.def_readonly(
"polynomial_size",
&::concretelang::clientlib::BootstrapKeyParam::polynomialSize)
.def_readonly(
"input_lwe_dimension",
&::concretelang::clientlib::BootstrapKeyParam::inputLweDimension);
pybind11::class_<::concretelang::clientlib::KeyswitchKeyParam>(
m, "KeyswitchKeyParam")
.def_readonly(
"input_secret_key_id",
&::concretelang::clientlib::KeyswitchKeyParam::inputSecretKeyID)
.def_readonly(
"output_secret_key_id",
&::concretelang::clientlib::KeyswitchKeyParam::outputSecretKeyID)
.def_readonly("level",
&::concretelang::clientlib::KeyswitchKeyParam::level)
.def_readonly("base_log",
&::concretelang::clientlib::KeyswitchKeyParam::baseLog)
.def_readonly("variance",
&::concretelang::clientlib::KeyswitchKeyParam::variance);
pybind11::class_<::concretelang::clientlib::PackingKeyswitchKeyParam>(
m, "PackingKeyswitchKeyParam")
.def_readonly("input_secret_key_id",
&::concretelang::clientlib::PackingKeyswitchKeyParam::
inputSecretKeyID)
.def_readonly("output_secret_key_id",
&::concretelang::clientlib::PackingKeyswitchKeyParam::
outputSecretKeyID)
.def_readonly("level",
&::concretelang::clientlib::PackingKeyswitchKeyParam::level)
.def_readonly(
"base_log",
&::concretelang::clientlib::PackingKeyswitchKeyParam::baseLog)
.def_readonly(
"glwe_dimension",
&::concretelang::clientlib::PackingKeyswitchKeyParam::glweDimension)
.def_readonly(
"polynomial_size",
&::concretelang::clientlib::PackingKeyswitchKeyParam::polynomialSize)
.def_readonly("input_lwe_dimension",
&::concretelang::clientlib::PackingKeyswitchKeyParam::
inputLweDimension)
.def_readonly(
"variance",
&::concretelang::clientlib::PackingKeyswitchKeyParam::variance);
pybind11::class_<mlir::concretelang::ClientParameters>(m, "ClientParameters")
.def_static("deserialize",
[](const pybind11::bytes &buffer) {
@@ -353,7 +438,16 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
}
}
return result;
});
})
.def_readonly("secret_keys",
&mlir::concretelang::ClientParameters::secretKeys)
.def_readonly("bootstrap_keys",
&mlir::concretelang::ClientParameters::bootstrapKeys)
.def_readonly("keyswitch_keys",
&mlir::concretelang::ClientParameters::keyswitchKeys)
.def_readonly(
"packing_keyswitch_keys",
&mlir::concretelang::ClientParameters::packingKeyswitchKeys);
pybind11::class_<clientlib::KeySet>(m, "KeySet")
.def_static("deserialize",

View File

@@ -39,6 +39,7 @@ from .value_decrypter import ValueDecrypter
from .value_exporter import ValueExporter
from .simulated_value_decrypter import SimulatedValueDecrypter
from .simulated_value_exporter import SimulatedValueExporter
from .parameter import Parameter
def init_dfr():

View File

@@ -3,13 +3,19 @@
"""Compilation feedback."""
from typing import Dict, Set
# pylint: disable=no-name-in-module,import-error,too-many-instance-attributes
from mlir._mlir_libs._concretelang._compiler import (
CompilationFeedback as _CompilationFeedback,
KeyType,
PrimitiveOperation,
)
# pylint: enable=no-name-in-module,import-error
from .client_parameters import ClientParameters
from .parameter import Parameter
from .wrapper import WrapperCpp
@@ -41,19 +47,153 @@ class CompilationFeedback(WrapperCpp):
self.crt_decompositions_of_outputs = (
compilation_feedback.crt_decompositions_of_outputs
)
self.total_pbs_count = compilation_feedback.total_pbs_count
self.total_ks_count = compilation_feedback.total_ks_count
self.total_clear_addition_count = (
compilation_feedback.total_clear_addition_count
)
self.total_encrypted_addition_count = (
compilation_feedback.total_encrypted_addition_count
)
self.total_clear_multiplication_count = (
compilation_feedback.total_clear_multiplication_count
)
self.total_encrypted_negation_count = (
compilation_feedback.total_encrypted_negation_count
)
self.statistics = compilation_feedback.statistics
super().__init__(compilation_feedback)
def count(self, *, operations: Set[PrimitiveOperation]) -> int:
"""
Count the amount of specified operations in the program.
Args:
operations (Set[PrimitiveOperation]):
set of operations used to filter the statistics
Returns:
int:
number of specified operations in the program
"""
return sum(
statistic.count
for statistic in self.statistics
if statistic.operation in operations
)
def count_per_parameter(
self,
*,
operations: Set[PrimitiveOperation],
key_types: Set[KeyType],
client_parameters: ClientParameters,
) -> Dict[Parameter, int]:
"""
Count the amount of specified operations in the program and group by parameters.
Args:
operations (Set[PrimitiveOperation]):
set of operations used to filter the statistics
key_types (Set[KeyType]):
set of key types used to filter the statistics
client_parameters (ClientParameters):
client parameters required for grouping by parameters
Returns:
Dict[Parameter, int]:
number of specified operations per parameter in the program
"""
result = {}
for statistic in self.statistics:
if statistic.operation not in operations:
continue
for key_type, key_index in statistic.keys:
if key_type not in key_types:
continue
parameter = Parameter(client_parameters, key_type, key_index)
if parameter not in result:
result[parameter] = 0
result[parameter] += statistic.count
return result
def count_per_tag(self, *, operations: Set[PrimitiveOperation]) -> Dict[str, int]:
"""
Count the amount of specified operations in the program and group by tags.
Args:
operations (Set[PrimitiveOperation]):
set of operations used to filter the statistics
Returns:
Dict[str, int]:
number of specified operations per tag in the program
"""
result = {}
for statistic in self.statistics:
if statistic.operation not in operations:
continue
file_and_maybe_tag = statistic.location.split("@")
tag = "" if len(file_and_maybe_tag) == 1 else file_and_maybe_tag[1].strip()
tag_components = tag.split(".")
for i in range(1, len(tag_components) + 1):
current_tag = ".".join(tag_components[0:i])
if current_tag == "":
continue
if current_tag not in result:
result[current_tag] = 0
result[current_tag] += statistic.count
return result
def count_per_tag_per_parameter(
self,
*,
operations: Set[PrimitiveOperation],
key_types: Set[KeyType],
client_parameters: ClientParameters,
) -> Dict[str, Dict[Parameter, int]]:
"""
Count the amount of specified operations in the program and group by tags and parameters.
Args:
operations (Set[PrimitiveOperation]):
set of operations used to filter the statistics
key_types (Set[KeyType]):
set of key types used to filter the statistics
client_parameters (ClientParameters):
client parameters required for grouping by parameters
Returns:
Dict[str, Dict[Parameter, int]]:
number of specified operations per tag per parameter in the program
"""
result: Dict[str, Dict[int, int]] = {}
for statistic in self.statistics:
if statistic.operation not in operations:
continue
file_and_maybe_tag = statistic.location.split("@")
tag = "" if len(file_and_maybe_tag) == 1 else file_and_maybe_tag[1].strip()
tag_components = tag.split(".")
for i in range(1, len(tag_components) + 1):
current_tag = ".".join(tag_components[0:i])
if current_tag == "":
continue
if current_tag not in result:
result[current_tag] = {}
for key_type, key_index in statistic.keys:
if key_type not in key_types:
continue
parameter = Parameter(client_parameters, key_type, key_index)
if parameter not in result[current_tag]:
result[current_tag][parameter] = 0
result[current_tag][parameter] += statistic.count
return result

View File

@@ -0,0 +1,105 @@
"""
Parameter.
"""
# pylint: disable=no-name-in-module,import-error
from typing import Union
from mlir._mlir_libs._concretelang._compiler import (
LweSecretKeyParam,
BootstrapKeyParam,
KeyswitchKeyParam,
PackingKeyswitchKeyParam,
KeyType,
)
from .client_parameters import ClientParameters
# pylint: enable=no-name-in-module,import-error
class Parameter:
"""
An FHE parameter.
"""
_inner: Union[
LweSecretKeyParam,
BootstrapKeyParam,
KeyswitchKeyParam,
PackingKeyswitchKeyParam,
]
def __init__(
self,
client_parameters: ClientParameters,
key_type: KeyType,
key_index: int,
):
if key_type == KeyType.SECRET:
self._inner = client_parameters.cpp().secret_keys[key_index]
elif key_type == KeyType.BOOTSTRAP:
self._inner = client_parameters.cpp().bootstrap_keys[key_index]
elif key_type == KeyType.KEY_SWITCH:
self._inner = client_parameters.cpp().keyswitch_keys[key_index]
elif key_type == KeyType.PACKING_KEY_SWITCH:
self._inner = client_parameters.cpp().packing_keyswitch_keys[key_index]
else:
raise ValueError("invalid key type")
def __getattr__(self, item):
return getattr(self._inner, item)
def __repr__(self):
param = self._inner
if isinstance(param, LweSecretKeyParam):
result = f"LweSecretKeyParam(" f"dimension={param.dimension}" f")"
elif isinstance(param, BootstrapKeyParam):
result = (
f"BootstrapKeyParam("
f"polynomial_size={param.polynomial_size}, "
f"glwe_dimension={param.glwe_dimension}, "
f"input_lwe_dimension={param.input_lwe_dimension}, "
f"level={param.level}, "
f"base_log={param.base_log}, "
f"variance={param.variance}"
f")"
)
elif isinstance(param, KeyswitchKeyParam):
result = (
f"KeyswitchKeyParam("
f"level={param.level}, "
f"base_log={param.base_log}, "
f"variance={param.variance}"
f")"
)
elif isinstance(param, PackingKeyswitchKeyParam):
result = (
f"PackingKeyswitchKeyParam("
f"polynomial_size={param.polynomial_size}, "
f"glwe_dimension={param.glwe_dimension}, "
f"input_lwe_dimension={param.input_lwe_dimension}"
f"level={param.level}, "
f"base_log={param.base_log}, "
f"variance={param.variance}"
f")"
)
else:
assert False
return result
def __str__(self):
return repr(self)
def __hash__(self):
return hash(str(self))
def __eq__(self, other):
return str(self) == str(other)

View File

@@ -301,6 +301,7 @@ const CONCRETE_COMPILER_STATIC_LIBS: &[&str] = &[
"SDFGDialect",
"ExtractSDFGOps",
"SDFGToStreamEmulator",
"TFHEDialectAnalysis",
];
fn main() {

View File

@@ -10,6 +10,17 @@ using namespace mlir;
using TFHE::ExtractTFHEStatisticsPass;
// #########
// Utilities
// #########
template <typename Op> std::string locationOf(Op op) {
auto location = std::string();
auto locationStream = llvm::raw_string_ostream(location);
op.getLoc()->print(locationStream);
return location.substr(5, location.size() - 2 - 5); // remove loc(" and ")
}
// #######
// scf.for
// #######
@@ -103,7 +114,24 @@ static std::optional<StringError> on_exit(scf::ForOp &op,
static std::optional<StringError> on_enter(TFHE::AddGLWEOp &op,
ExtractTFHEStatisticsPass &pass) {
pass.feedback.totalEncryptedAdditionCount += pass.iterations;
auto resultingKey = op.getType().getKey().getNormalized();
auto location = locationOf(op);
auto operation = PrimitiveOperation::ENCRYPTED_ADDITION;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
keys.push_back(key);
pass.feedback.statistics.push_back(Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
@@ -113,7 +141,24 @@ static std::optional<StringError> on_enter(TFHE::AddGLWEOp &op,
static std::optional<StringError> on_enter(TFHE::AddGLWEIntOp &op,
ExtractTFHEStatisticsPass &pass) {
pass.feedback.totalClearAdditionCount += pass.iterations;
auto resultingKey = op.getType().getKey().getNormalized();
auto location = locationOf(op);
auto operation = PrimitiveOperation::CLEAR_ADDITION;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
keys.push_back(key);
pass.feedback.statistics.push_back(Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
@@ -123,7 +168,24 @@ static std::optional<StringError> on_enter(TFHE::AddGLWEIntOp &op,
static std::optional<StringError> on_enter(TFHE::BootstrapGLWEOp &op,
ExtractTFHEStatisticsPass &pass) {
pass.feedback.totalPbsCount += pass.iterations;
auto bsk = op.getKey();
auto location = locationOf(op);
auto operation = PrimitiveOperation::PBS;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex());
keys.push_back(key);
pass.feedback.statistics.push_back(Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
@@ -133,7 +195,24 @@ static std::optional<StringError> on_enter(TFHE::BootstrapGLWEOp &op,
static std::optional<StringError> on_enter(TFHE::KeySwitchGLWEOp &op,
ExtractTFHEStatisticsPass &pass) {
pass.feedback.totalKsCount += pass.iterations;
auto ksk = op.getKey();
auto location = locationOf(op);
auto operation = PrimitiveOperation::KEY_SWITCH;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex());
keys.push_back(key);
pass.feedback.statistics.push_back(Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
@@ -143,7 +222,24 @@ static std::optional<StringError> on_enter(TFHE::KeySwitchGLWEOp &op,
static std::optional<StringError> on_enter(TFHE::MulGLWEIntOp &op,
ExtractTFHEStatisticsPass &pass) {
pass.feedback.totalClearMultiplicationCount += pass.iterations;
auto resultingKey = op.getType().getKey().getNormalized();
auto location = locationOf(op);
auto operation = PrimitiveOperation::CLEAR_MULTIPLICATION;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
keys.push_back(key);
pass.feedback.statistics.push_back(Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
@@ -153,7 +249,24 @@ static std::optional<StringError> on_enter(TFHE::MulGLWEIntOp &op,
static std::optional<StringError> on_enter(TFHE::NegGLWEOp &op,
ExtractTFHEStatisticsPass &pass) {
pass.feedback.totalEncryptedNegationCount += pass.iterations;
auto resultingKey = op.getType().getKey().getNormalized();
auto location = locationOf(op);
auto operation = PrimitiveOperation::ENCRYPTED_NEGATION;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
keys.push_back(key);
pass.feedback.statistics.push_back(Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
@@ -163,9 +276,34 @@ static std::optional<StringError> on_enter(TFHE::NegGLWEOp &op,
static std::optional<StringError> on_enter(TFHE::SubGLWEIntOp &op,
ExtractTFHEStatisticsPass &pass) {
auto resultingKey = op.getType().getKey().getNormalized();
auto location = locationOf(op);
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
keys.push_back(key);
// clear - encrypted = clear + neg(encrypted)
pass.feedback.totalEncryptedNegationCount += pass.iterations;
pass.feedback.totalClearAdditionCount += pass.iterations;
auto operation = PrimitiveOperation::ENCRYPTED_NEGATION;
pass.feedback.statistics.push_back(Statistic{
location,
operation,
keys,
count,
});
operation = PrimitiveOperation::CLEAR_ADDITION;
pass.feedback.statistics.push_back(Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}
@@ -175,7 +313,32 @@ static std::optional<StringError> on_enter(TFHE::SubGLWEIntOp &op,
static std::optional<StringError> on_enter(TFHE::WopPBSGLWEOp &op,
ExtractTFHEStatisticsPass &pass) {
pass.feedback.totalPbsCount += pass.iterations;
auto bsk = op.getBsk();
auto ksk = op.getKsk();
auto pksk = op.getPksk();
auto location = locationOf(op);
auto operation = PrimitiveOperation::WOP_PBS;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex());
keys.push_back(key);
key = std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex());
keys.push_back(key);
key = std::make_pair(KeyType::PACKING_KEY_SWITCH, (size_t)pksk.getIndex());
keys.push_back(key);
pass.feedback.statistics.push_back(Statistic{
location,
operation,
keys,
count,
});
return std::nullopt;
}

View File

@@ -6,7 +6,6 @@
#include <cassert>
#include <fstream>
#include "boost/outcome.h"
#include "llvm/Support/JSON.h"
#include "concretelang/Support/CompilationFeedback.h"
@@ -62,13 +61,6 @@ void CompilationFeedback::fillFromClientParameters(
}
crtDecompositionsOfOutputs.push_back(decomposition);
}
// Stats
totalPbsCount = 0;
totalKsCount = 0;
totalClearAdditionCount = 0;
totalEncryptedAdditionCount = 0;
totalClearMultiplicationCount = 0;
totalEncryptedNegationCount = 0;
}
outcome::checked<CompilationFeedback, StringError>
@@ -81,7 +73,7 @@ CompilationFeedback::load(std::string jsonPath) {
}
auto expectedCompFeedback = llvm::json::parse<CompilationFeedback>(content);
if (auto err = expectedCompFeedback.takeError()) {
return StringError("Cannot open client parameters: ")
return StringError("Cannot open compilation feedback: ")
<< llvm::toString(std::move(err)) << "\n"
<< content << "\n";
}
@@ -99,34 +91,166 @@ llvm::json::Value toJSON(const mlir::concretelang::CompilationFeedback &v) {
{"totalInputsSize", v.totalInputsSize},
{"totalOutputsSize", v.totalOutputsSize},
{"crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs},
{"totalPbsCount", v.totalPbsCount},
{"totalKsCount", v.totalKsCount},
{"totalClearAdditionCount", v.totalClearAdditionCount},
{"totalEncryptedAdditionCount", v.totalEncryptedAdditionCount},
{"totalClearMultiplicationCount", v.totalClearMultiplicationCount},
{"totalEncryptedNegationCount", v.totalEncryptedNegationCount},
};
auto statisticsJson = llvm::json::Array();
for (auto statistic : v.statistics) {
auto statisticJson = llvm::json::Object();
statisticJson.insert({"location", statistic.location});
switch (statistic.operation) {
case PrimitiveOperation::PBS:
statisticJson.insert({"operation", "PBS"});
break;
case PrimitiveOperation::WOP_PBS:
statisticJson.insert({"operation", "WOP_PBS"});
break;
case PrimitiveOperation::KEY_SWITCH:
statisticJson.insert({"operation", "KEY_SWITCH"});
break;
case PrimitiveOperation::CLEAR_ADDITION:
statisticJson.insert({"operation", "CLEAR_ADDITION"});
break;
case PrimitiveOperation::ENCRYPTED_ADDITION:
statisticJson.insert({"operation", "ENCRYPTED_ADDITION"});
break;
case PrimitiveOperation::CLEAR_MULTIPLICATION:
statisticJson.insert({"operation", "CLEAR_MULTIPLICATION"});
break;
case PrimitiveOperation::ENCRYPTED_NEGATION:
statisticJson.insert({"operation", "ENCRYPTED_NEGATION"});
break;
}
auto keysJson = llvm::json::Array();
for (auto &key : statistic.keys) {
KeyType type = key.first;
size_t index = key.second;
auto keyJson = llvm::json::Array();
switch (type) {
case KeyType::SECRET:
keyJson.push_back("SECRET");
break;
case KeyType::BOOTSTRAP:
keyJson.push_back("BOOTSTRAP");
break;
case KeyType::KEY_SWITCH:
keyJson.push_back("KEY_SWITCH");
break;
case KeyType::PACKING_KEY_SWITCH:
keyJson.push_back("PACKING_KEY_SWITCH");
break;
}
keyJson.push_back((int64_t)index);
keysJson.push_back(std::move(keyJson));
}
statisticJson.insert({"keys", std::move(keysJson)});
statisticJson.insert({"count", (int64_t)statistic.count});
statisticsJson.push_back(std::move(statisticJson));
}
object.insert({"statistics", std::move(statisticsJson)});
return object;
}
bool fromJSON(const llvm::json::Value j,
mlir::concretelang::CompilationFeedback &v, llvm::json::Path p) {
llvm::json::ObjectMapper O(j, p);
return O && O.map("complexity", v.complexity) && O.map("pError", v.pError) &&
O.map("globalPError", v.globalPError) &&
O.map("totalSecretKeysSize", v.totalSecretKeysSize) &&
O.map("totalBootstrapKeysSize", v.totalBootstrapKeysSize) &&
O.map("totalKeyswitchKeysSize", v.totalKeyswitchKeysSize) &&
O.map("totalInputsSize", v.totalInputsSize) &&
O.map("totalOutputsSize", v.totalOutputsSize) &&
O.map("crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs) &&
O.map("totalPbsCount", v.totalPbsCount) &&
O.map("totalKsCount", v.totalKsCount) &&
O.map("totalClearAdditionCount", v.totalClearAdditionCount) &&
O.map("totalEncryptedAdditionCount", v.totalEncryptedAdditionCount) &&
O.map("totalClearMultiplicationCount",
v.totalClearMultiplicationCount) &&
O.map("totalEncryptedNegationCount", v.totalEncryptedNegationCount);
bool is_success =
O && O.map("complexity", v.complexity) && O.map("pError", v.pError) &&
O.map("globalPError", v.globalPError) &&
O.map("totalSecretKeysSize", v.totalSecretKeysSize) &&
O.map("totalBootstrapKeysSize", v.totalBootstrapKeysSize) &&
O.map("totalKeyswitchKeysSize", v.totalKeyswitchKeysSize) &&
O.map("totalInputsSize", v.totalInputsSize) &&
O.map("totalOutputsSize", v.totalOutputsSize) &&
O.map("crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs);
if (!is_success) {
return false;
}
auto object = j.getAsObject();
if (!object) {
return false;
}
auto statistics = object->getArray("statistics");
if (!statistics) {
return false;
}
for (auto statisticValue : *statistics) {
auto statistic = statisticValue.getAsObject();
if (!statistic) {
return false;
}
auto location = statistic->getString("location");
auto operationStr = statistic->getString("operation");
auto keysArray = statistic->getArray("keys");
auto count = statistic->getInteger("count");
if (!operationStr || !location || !keysArray || !count) {
return false;
}
PrimitiveOperation operation;
if (operationStr.value() == "PBS") {
operation = PrimitiveOperation::PBS;
} else if (operationStr.value() == "KEY_SWITCH") {
operation = PrimitiveOperation::KEY_SWITCH;
} else if (operationStr.value() == "WOP_PBS") {
operation = PrimitiveOperation::WOP_PBS;
} else if (operationStr.value() == "CLEAR_ADDITION") {
operation = PrimitiveOperation::CLEAR_ADDITION;
} else if (operationStr.value() == "ENCRYPTED_ADDITION") {
operation = PrimitiveOperation::ENCRYPTED_ADDITION;
} else if (operationStr.value() == "CLEAR_MULTIPLICATION") {
operation = PrimitiveOperation::CLEAR_MULTIPLICATION;
} else if (operationStr.value() == "ENCRYPTED_NEGATION") {
operation = PrimitiveOperation::ENCRYPTED_NEGATION;
} else {
return false;
}
auto keys = std::vector<std::pair<KeyType, size_t>>();
for (auto keyValue : *keysArray) {
llvm::json::Array *keyArray = keyValue.getAsArray();
if (!keyArray || keyArray->size() != 2) {
return false;
}
auto typeStr = keyArray->front().getAsString();
auto index = keyArray->back().getAsInteger();
if (!typeStr || !index) {
return false;
}
KeyType type;
if (typeStr.value() == "SECRET") {
type = KeyType::SECRET;
} else if (typeStr.value() == "BOOTSTRAP") {
type = KeyType::BOOTSTRAP;
} else if (typeStr.value() == "KEY_SWITCH") {
type = KeyType::KEY_SWITCH;
} else if (typeStr.value() == "PACKING_KEY_SWITCH") {
type = KeyType::PACKING_KEY_SWITCH;
} else {
return false;
}
keys.push_back(std::make_pair(type, (size_t)*index));
}
v.statistics.push_back(
Statistic{location->str(), operation, keys, (uint64_t)*count});
}
return true;
}
} // namespace concretelang

View File

@@ -4,7 +4,7 @@ Concrete.
# pylint: disable=import-error,no-name-in-module
from concrete.compiler import EvaluationKeys, PublicArguments, PublicResult
from concrete.compiler import EvaluationKeys, Parameter, PublicArguments, PublicResult
from .compilation import (
DEFAULT_GLOBAL_P_ERROR,

View File

@@ -4,10 +4,15 @@ Declaration of `Circuit` class.
# pylint: disable=import-error,no-member,no-name-in-module
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from concrete.compiler import CompilationContext, SimulatedValueDecrypter, SimulatedValueExporter
from concrete.compiler import (
CompilationContext,
Parameter,
SimulatedValueDecrypter,
SimulatedValueExporter,
)
from mlir.ir import Module as MlirModule
from ..internal.utils import assert_that
@@ -266,118 +271,15 @@ class Circuit:
if hasattr(self, "server"): # pragma: no cover
self.server.cleanup()
@property
def size_of_secret_keys(self) -> int:
"""
Get size of the secret keys of the circuit.
"""
return self._statistic("size_of_secret_keys")
# Properties
@property
def size_of_bootstrap_keys(self) -> int:
def _property(self, name: str) -> Any:
"""
Get size of the bootstrap keys of the circuit.
"""
return self._statistic("size_of_bootstrap_keys")
@property
def size_of_keyswitch_keys(self) -> int:
"""
Get size of the key switch keys of the circuit.
"""
return self._statistic("size_of_keyswitch_keys")
@property
def size_of_inputs(self) -> int:
"""
Get size of the inputs of the circuit.
"""
return self._statistic("size_of_inputs")
@property
def size_of_outputs(self) -> int:
"""
Get size of the outputs of the circuit.
"""
return self._statistic("size_of_outputs")
@property
def p_error(self) -> int:
"""
Get probability of error for each simple TLU (on a scalar).
"""
return self._statistic("p_error")
@property
def global_p_error(self) -> int:
"""
Get the probability of having at least one simple TLU error during the entire execution.
"""
return self._statistic("global_p_error")
@property
def complexity(self) -> float:
"""
Get complexity of the circuit.
"""
return self._statistic("complexity")
@property
def total_pbs_count(self) -> int:
"""
Get the total number of programmable bootstraps in the circuit.
"""
return self._statistic("total_pbs_count")
@property
def total_ks_count(self) -> int:
"""
Get the total number of key switches in the circuit.
"""
return self._statistic("total_ks_count")
@property
def total_clear_addition_count(self) -> int:
"""
Get the total number of clear additions in the circuit.
"""
return self._statistic("total_clear_addition_count")
@property
def total_encrypted_addition_count(self) -> int:
"""
Get the total number of encrypted additions in the circuit.
"""
return self._statistic("total_encrypted_addition_count")
@property
def total_clear_multiplication_count(self) -> int:
"""
Get the total number of clear multiplications in the circuit.
"""
return self._statistic("total_clear_multiplication_count")
@property
def total_encrypted_negation_count(self) -> int:
"""
Get the total number of encrypted negations in the circuit.
"""
return self._statistic("total_encrypted_negation_count")
@property
def statistics(self) -> dict:
"""
Get all circuit statistics in a dict.
"""
return self._statistic("statistics")
def _statistic(self, name: str) -> Any:
"""
Get a statistic of the circuit by name.
Get a property of the circuit by name.
Args:
name (str):
name of the statistic
name of the property
Returns:
Any:
@@ -391,3 +293,282 @@ class Circuit:
self.enable_fhe_execution() # pragma: no cover
return getattr(self.server, name)
@property
def size_of_secret_keys(self) -> int:
"""
Get size of the secret keys of the circuit.
"""
return self._property("size_of_secret_keys") # pragma: no cover
@property
def size_of_bootstrap_keys(self) -> int:
"""
Get size of the bootstrap keys of the circuit.
"""
return self._property("size_of_bootstrap_keys") # pragma: no cover
@property
def size_of_keyswitch_keys(self) -> int:
"""
Get size of the key switch keys of the circuit.
"""
return self._property("size_of_keyswitch_keys") # pragma: no cover
@property
def size_of_inputs(self) -> int:
"""
Get size of the inputs of the circuit.
"""
return self._property("size_of_inputs") # pragma: no cover
@property
def size_of_outputs(self) -> int:
"""
Get size of the outputs of the circuit.
"""
return self._property("size_of_outputs") # pragma: no cover
@property
def p_error(self) -> int:
"""
Get probability of error for each simple TLU (on a scalar).
"""
return self._property("p_error") # pragma: no cover
@property
def global_p_error(self) -> int:
"""
Get the probability of having at least one simple TLU error during the entire execution.
"""
return self._property("global_p_error") # pragma: no cover
@property
def complexity(self) -> float:
"""
Get complexity of the circuit.
"""
return self._property("complexity") # pragma: no cover
# Programmable Bootstrap Statistics
@property
def programmable_bootstrap_count(self) -> int:
"""
Get the number of programmable bootstraps in the circuit.
"""
return self._property("programmable_bootstrap_count") # pragma: no cover
@property
def programmable_bootstrap_count_per_parameter(self) -> Dict[Parameter, int]:
"""
Get the number of programmable bootstraps per bit width in the circuit.
"""
return self._property("programmable_bootstrap_count_per_parameter") # pragma: no cover
@property
def programmable_bootstrap_count_per_tag(self) -> Dict[str, int]:
"""
Get the number of programmable bootstraps per tag in the circuit.
"""
return self._property("programmable_bootstrap_count_per_tag") # pragma: no cover
@property
def programmable_bootstrap_count_per_tag_per_parameter(self) -> Dict[str, Dict[int, int]]:
"""
Get the number of programmable bootstraps per tag per bit width in the circuit.
"""
return self._property(
"programmable_bootstrap_count_per_tag_per_parameter"
) # pragma: no cover
# Key Switch Statistics
@property
def key_switch_count(self) -> int:
"""
Get the number of key switches in the circuit.
"""
return self._property("key_switch_count") # pragma: no cover
@property
def key_switch_count_per_parameter(self) -> Dict[Parameter, int]:
"""
Get the number of key switches per parameter in the circuit.
"""
return self._property("key_switch_count_per_parameter") # pragma: no cover
@property
def key_switch_count_per_tag(self) -> Dict[str, int]:
"""
Get the number of key switches per tag in the circuit.
"""
return self._property("key_switch_count_per_tag") # pragma: no cover
@property
def key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of key switches per tag per parameter in the circuit.
"""
return self._property("key_switch_count_per_tag_per_parameter") # pragma: no cover
# Packing Key Switch Statistics
@property
def packing_key_switch_count(self) -> int:
"""
Get the number of packing key switches in the circuit.
"""
return self._property("packing_key_switch_count") # pragma: no cover
@property
def packing_key_switch_count_per_parameter(self) -> Dict[Parameter, int]:
"""
Get the number of packing key switches per parameter in the circuit.
"""
return self._property("packing_key_switch_count_per_parameter") # pragma: no cover
@property
def packing_key_switch_count_per_tag(self) -> Dict[str, int]:
"""
Get the number of packing key switches per tag in the circuit.
"""
return self._property("packing_key_switch_count_per_tag") # pragma: no cover
@property
def packing_key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of packing key switches per tag per parameter in the circuit.
"""
return self._property("packing_key_switch_count_per_tag_per_parameter") # pragma: no cover
# Clear Addition Statistics
@property
def clear_addition_count(self) -> int:
"""
Get the number of clear additions in the circuit.
"""
return self._property("clear_addition_count") # pragma: no cover
@property
def clear_addition_count_per_parameter(self) -> Dict[Parameter, int]:
"""
Get the number of clear additions per parameter in the circuit.
"""
return self._property("clear_addition_count_per_parameter") # pragma: no cover
@property
def clear_addition_count_per_tag(self) -> Dict[str, int]:
"""
Get the number of clear additions per tag in the circuit.
"""
return self._property("clear_addition_count_per_tag") # pragma: no cover
@property
def clear_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of clear additions per tag per parameter in the circuit.
"""
return self._property("clear_addition_count_per_tag_per_parameter") # pragma: no cover
# Encrypted Addition Statistics
@property
def encrypted_addition_count(self) -> int:
"""
Get the number of encrypted additions in the circuit.
"""
return self._property("encrypted_addition_count") # pragma: no cover
@property
def encrypted_addition_count_per_parameter(self) -> Dict[Parameter, int]:
"""
Get the number of encrypted additions per parameter in the circuit.
"""
return self._property("encrypted_addition_count_per_parameter") # pragma: no cover
@property
def encrypted_addition_count_per_tag(self) -> Dict[str, int]:
"""
Get the number of encrypted additions per tag in the circuit.
"""
return self._property("encrypted_addition_count_per_tag") # pragma: no cover
@property
def encrypted_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of encrypted additions per tag per parameter in the circuit.
"""
return self._property("encrypted_addition_count_per_tag_per_parameter") # pragma: no cover
# Clear Multiplication Statistics
@property
def clear_multiplication_count(self) -> int:
"""
Get the number of clear multiplications in the circuit.
"""
return self._property("clear_multiplication_count") # pragma: no cover
@property
def clear_multiplication_count_per_parameter(self) -> Dict[Parameter, int]:
"""
Get the number of clear multiplications per parameter in the circuit.
"""
return self._property("clear_multiplication_count_per_parameter") # pragma: no cover
@property
def clear_multiplication_count_per_tag(self) -> Dict[str, int]:
"""
Get the number of clear multiplications per tag in the circuit.
"""
return self._property("clear_multiplication_count_per_tag") # pragma: no cover
@property
def clear_multiplication_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of clear multiplications per tag per parameter in the circuit.
"""
return self._property(
"clear_multiplication_count_per_tag_per_parameter"
) # pragma: no cover
# Encrypted Negation Statistics
@property
def encrypted_negation_count(self) -> int:
"""
Get the number of encrypted negations in the circuit.
"""
return self._property("encrypted_negation_count") # pragma: no cover
@property
def encrypted_negation_count_per_parameter(self) -> Dict[Parameter, int]:
"""
Get the number of encrypted negations per parameter in the circuit.
"""
return self._property("encrypted_negation_count_per_parameter") # pragma: no cover
@property
def encrypted_negation_count_per_tag(self) -> Dict[str, int]:
"""
Get the number of encrypted negations per tag in the circuit.
"""
return self._property("encrypted_negation_count_per_tag") # pragma: no cover
@property
def encrypted_negation_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of encrypted negations per tag per parameter in the circuit.
"""
return self._property("encrypted_negation_count_per_tag_per_parameter") # pragma: no cover
# All Statistics
@property
def statistics(self) -> Dict:
"""
Get all statistics of the circuit.
"""
return self._property("statistics") # pragma: no cover

View File

@@ -485,7 +485,7 @@ class Compiler:
)
columns = 0
if show_graph or show_mlir or show_optimizer:
if show_graph or show_mlir or show_optimizer or show_statistics:
graph = (
self.graph.format()
if self.configuration.verbose or self.configuration.show_graph
@@ -556,8 +556,27 @@ class Compiler:
print("Statistics")
print("-" * columns)
for name, value in circuit.statistics.items():
print(f"{name}: {value}")
def pretty(d, indent=0): # pragma: no cover
if indent > 0:
print("{")
for key, value in d.items():
if isinstance(value, dict) and len(value) == 0:
continue
print(" " * indent + str(key) + ": ", end="")
if isinstance(value, dict):
pretty(value, indent + 1)
else:
print(value)
if indent > 0:
print(" " * (indent - 1) + "}")
pretty(circuit.statistics)
print("-" * columns)
print()

View File

@@ -8,7 +8,7 @@ import json
import shutil
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
# mypy: disable-error-code=attr-defined
import concrete.compiler
@@ -23,11 +23,12 @@ from concrete.compiler import (
LibraryCompilationResult,
LibraryLambda,
LibrarySupport,
Parameter,
PublicArguments,
set_compiler_logging,
set_llvm_debug_flag,
)
from mlir._mlir_libs._concretelang._compiler import OptimizerStrategy
from mlir._mlir_libs._concretelang._compiler import KeyType, OptimizerStrategy, PrimitiveOperation
from mlir.ir import Module as MlirModule
from ..internal.utils import assert_that
@@ -435,66 +436,333 @@ class Server:
"""
return self._compilation_feedback.complexity
@property
def total_pbs_count(self) -> int:
"""
Get the total number of programmable bootstraps in the compiled program.
"""
return self._compilation_feedback.total_pbs_count
# Programmable Bootstrap Statistics
@property
def total_ks_count(self) -> int:
def programmable_bootstrap_count(self) -> int:
"""
Get the total number of key switches in the compiled program.
Get the number of programmable bootstraps in the compiled program.
"""
return self._compilation_feedback.total_ks_count
return self._compilation_feedback.count(
operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS},
)
@property
def total_clear_addition_count(self) -> int:
def programmable_bootstrap_count_per_parameter(self) -> Dict[Parameter, int]:
"""
Get the total number of clear additions in the compiled program.
Get the number of programmable bootstraps per parameter in the compiled program.
"""
return self._compilation_feedback.total_clear_addition_count
return self._compilation_feedback.count_per_parameter(
operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS},
key_types={KeyType.BOOTSTRAP},
client_parameters=self.client_specs.client_parameters,
)
@property
def total_encrypted_addition_count(self) -> int:
def programmable_bootstrap_count_per_tag(self) -> Dict[str, int]:
"""
Get the total number of encrypted additions in the compiled program.
Get the number of programmable bootstraps per tag in the compiled program.
"""
return self._compilation_feedback.total_encrypted_addition_count
return self._compilation_feedback.count_per_tag(
operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS},
)
@property
def total_clear_multiplication_count(self) -> int:
def programmable_bootstrap_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
"""
Get the total number of clear multiplications in the compiled program.
Get the number of programmable bootstraps per tag per parameter in the compiled program.
"""
return self._compilation_feedback.total_clear_multiplication_count
return self._compilation_feedback.count_per_tag_per_parameter(
operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS},
key_types={KeyType.BOOTSTRAP},
client_parameters=self.client_specs.client_parameters,
)
# Key Switch Statistics
@property
def total_encrypted_negation_count(self) -> int:
def key_switch_count(self) -> int:
"""
Get the total number of encrypted negations in the compiled program.
Get the number of key switches in the compiled program.
"""
return self._compilation_feedback.total_encrypted_negation_count
return self._compilation_feedback.count(
operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS},
)
@property
def statistics(self) -> dict:
def key_switch_count_per_parameter(self) -> Dict[Parameter, int]:
"""
Get all program statistics in a dict.
Get the number of key switches per parameter in the compiled program.
"""
return {
"size_of_secret_keys": self.size_of_secret_keys,
"size_of_bootstrap_keys": self.size_of_bootstrap_keys,
"size_of_keyswitch_keys": self.size_of_keyswitch_keys,
"size_of_inputs": self.size_of_inputs,
"size_of_outputs": self.size_of_outputs,
"p_error": self.p_error,
"global_p_error": self.global_p_error,
"complexity": self.complexity,
"total_pbs_count": self.total_pbs_count,
"total_ks_count": self.total_ks_count,
"total_clear_addition_count": self.total_clear_addition_count,
"total_encrypted_addition_count": self.total_encrypted_addition_count,
"total_clear_multiplication_count": self.total_clear_multiplication_count,
"total_encrypted_negation_count": self.total_encrypted_negation_count,
}
return self._compilation_feedback.count_per_parameter(
operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS},
key_types={KeyType.KEY_SWITCH},
client_parameters=self.client_specs.client_parameters,
)
@property
def key_switch_count_per_tag(self) -> Dict[str, int]:
"""
Get the number of key switches per tag in the compiled program.
"""
return self._compilation_feedback.count_per_tag(
operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS},
)
@property
def key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of key switches per tag per parameter in the compiled program.
"""
return self._compilation_feedback.count_per_tag_per_parameter(
operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS},
key_types={KeyType.KEY_SWITCH},
client_parameters=self.client_specs.client_parameters,
)
# Packing Key Switch Statistics
@property
def packing_key_switch_count(self) -> int:
"""
Get the number of packing key switches in the compiled program.
"""
return self._compilation_feedback.count(operations={PrimitiveOperation.WOP_PBS})
@property
def packing_key_switch_count_per_parameter(self) -> Dict[Parameter, int]:
"""
Get the number of packing key switches per parameter in the compiled program.
"""
return self._compilation_feedback.count_per_parameter(
operations={PrimitiveOperation.WOP_PBS},
key_types={KeyType.PACKING_KEY_SWITCH},
client_parameters=self.client_specs.client_parameters,
)
@property
def packing_key_switch_count_per_tag(self) -> Dict[str, int]:
"""
Get the number of packing key switches per tag in the compiled program.
"""
return self._compilation_feedback.count_per_tag(operations={PrimitiveOperation.WOP_PBS})
@property
def packing_key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of packing key switches per tag per parameter in the compiled program.
"""
return self._compilation_feedback.count_per_tag_per_parameter(
operations={PrimitiveOperation.WOP_PBS},
key_types={KeyType.PACKING_KEY_SWITCH},
client_parameters=self.client_specs.client_parameters,
)
# Clear Addition Statistics
@property
def clear_addition_count(self) -> int:
"""
Get the number of clear additions in the compiled program.
"""
return self._compilation_feedback.count(operations={PrimitiveOperation.CLEAR_ADDITION})
@property
def clear_addition_count_per_parameter(self) -> Dict[Parameter, int]:
"""
Get the number of clear additions per parameter in the compiled program.
"""
return self._compilation_feedback.count_per_parameter(
operations={PrimitiveOperation.CLEAR_ADDITION},
key_types={KeyType.SECRET},
client_parameters=self.client_specs.client_parameters,
)
@property
def clear_addition_count_per_tag(self) -> Dict[str, int]:
"""
Get the number of clear additions per tag in the compiled program.
"""
return self._compilation_feedback.count_per_tag(
operations={PrimitiveOperation.CLEAR_ADDITION},
)
@property
def clear_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of clear additions per tag per parameter in the compiled program.
"""
return self._compilation_feedback.count_per_tag_per_parameter(
operations={PrimitiveOperation.CLEAR_ADDITION},
key_types={KeyType.SECRET},
client_parameters=self.client_specs.client_parameters,
)
# Encrypted Addition Statistics
@property
def encrypted_addition_count(self) -> int:
"""
Get the number of encrypted additions in the compiled program.
"""
return self._compilation_feedback.count(operations={PrimitiveOperation.ENCRYPTED_ADDITION})
@property
def encrypted_addition_count_per_parameter(self) -> Dict[Parameter, int]:
"""
Get the number of encrypted additions per parameter in the compiled program.
"""
return self._compilation_feedback.count_per_parameter(
operations={PrimitiveOperation.ENCRYPTED_ADDITION},
key_types={KeyType.SECRET},
client_parameters=self.client_specs.client_parameters,
)
@property
def encrypted_addition_count_per_tag(self) -> Dict[str, int]:
"""
Get the number of encrypted additions per tag in the compiled program.
"""
return self._compilation_feedback.count_per_tag(
operations={PrimitiveOperation.ENCRYPTED_ADDITION},
)
@property
def encrypted_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of encrypted additions per tag per parameter in the compiled program.
"""
return self._compilation_feedback.count_per_tag_per_parameter(
operations={PrimitiveOperation.ENCRYPTED_ADDITION},
key_types={KeyType.SECRET},
client_parameters=self.client_specs.client_parameters,
)
# Clear Multiplication Statistics
@property
def clear_multiplication_count(self) -> int:
"""
Get the number of clear multiplications in the compiled program.
"""
return self._compilation_feedback.count(
operations={PrimitiveOperation.CLEAR_MULTIPLICATION},
)
@property
def clear_multiplication_count_per_parameter(self) -> Dict[Parameter, int]:
"""
Get the number of clear multiplications per parameter in the compiled program.
"""
return self._compilation_feedback.count_per_parameter(
operations={PrimitiveOperation.CLEAR_MULTIPLICATION},
key_types={KeyType.SECRET},
client_parameters=self.client_specs.client_parameters,
)
@property
def clear_multiplication_count_per_tag(self) -> Dict[str, int]:
"""
Get the number of clear multiplications per tag in the compiled program.
"""
return self._compilation_feedback.count_per_tag(
operations={PrimitiveOperation.CLEAR_MULTIPLICATION},
)
@property
def clear_multiplication_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of clear multiplications per tag per parameter in the compiled program.
"""
return self._compilation_feedback.count_per_tag_per_parameter(
operations={PrimitiveOperation.CLEAR_MULTIPLICATION},
key_types={KeyType.SECRET},
client_parameters=self.client_specs.client_parameters,
)
# Encrypted Negation Statistics
@property
def encrypted_negation_count(self) -> int:
"""
Get the number of encrypted negations in the compiled program.
"""
return self._compilation_feedback.count(operations={PrimitiveOperation.ENCRYPTED_NEGATION})
@property
def encrypted_negation_count_per_parameter(self) -> Dict[Parameter, int]:
"""
Get the number of encrypted negations per parameter in the compiled program.
"""
return self._compilation_feedback.count_per_parameter(
operations={PrimitiveOperation.ENCRYPTED_NEGATION},
key_types={KeyType.SECRET},
client_parameters=self.client_specs.client_parameters,
)
@property
def encrypted_negation_count_per_tag(self) -> Dict[str, int]:
"""
Get the number of encrypted negations per tag in the compiled program.
"""
return self._compilation_feedback.count_per_tag(
operations={PrimitiveOperation.ENCRYPTED_NEGATION},
)
@property
def encrypted_negation_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]:
"""
Get the number of encrypted negations per tag per parameter in the compiled program.
"""
return self._compilation_feedback.count_per_tag_per_parameter(
operations={PrimitiveOperation.ENCRYPTED_NEGATION},
key_types={KeyType.SECRET},
client_parameters=self.client_specs.client_parameters,
)
# All Statistics
@property
def statistics(self) -> Dict:
"""
Get all statistics of the compiled program.
"""
attributes = [
"size_of_secret_keys",
"size_of_bootstrap_keys",
"size_of_keyswitch_keys",
"size_of_inputs",
"size_of_outputs",
"p_error",
"global_p_error",
"complexity",
"programmable_bootstrap_count",
"programmable_bootstrap_count_per_parameter",
"programmable_bootstrap_count_per_tag",
"programmable_bootstrap_count_per_tag_per_parameter",
"key_switch_count",
"key_switch_count_per_parameter",
"key_switch_count_per_tag",
"key_switch_count_per_tag_per_parameter",
"packing_key_switch_count",
"packing_key_switch_count_per_parameter",
"packing_key_switch_count_per_tag",
"packing_key_switch_count_per_tag_per_parameter",
"clear_addition_count",
"clear_addition_count_per_parameter",
"clear_addition_count_per_tag",
"clear_addition_count_per_tag_per_parameter",
"encrypted_addition_count",
"encrypted_addition_count_per_parameter",
"encrypted_addition_count_per_tag",
"encrypted_addition_count_per_tag_per_parameter",
"clear_multiplication_count",
"clear_multiplication_count_per_parameter",
"clear_multiplication_count_per_tag",
"clear_multiplication_count_per_tag_per_parameter",
"encrypted_negation_count",
"encrypted_negation_count_per_parameter",
"encrypted_negation_count_per_tag",
"encrypted_negation_count_per_tag_per_parameter",
]
return {attribute: getattr(self, attribute) for attribute in attributes}

View File

@@ -523,6 +523,20 @@ def test_circuit_compile_sim_only(helpers):
assert f(*inputset[0]) == circuit.simulate(*inputset[0])
def tagged_function(x, y, z):
"""
A tagged function to test statistics.
"""
with fhe.tag("a"):
x = fhe.univariate(lambda v: v)(x)
with fhe.tag("b"):
y = fhe.univariate(lambda v: v)(y)
with fhe.tag("c"):
z = fhe.univariate(lambda v: v)(z)
return x + y + z
@pytest.mark.parametrize(
"function,parameters,expected_statistics",
[
@@ -532,12 +546,11 @@ def test_circuit_compile_sim_only(helpers):
"x": {"status": "encrypted", "range": [0, 10], "shape": ()},
},
{
"total_pbs_count": 1,
"total_ks_count": 1,
"total_clear_addition_count": 0,
"total_encrypted_addition_count": 0,
"total_clear_multiplication_count": 0,
"total_encrypted_negation_count": 0,
"programmable_bootstrap_count": 1,
"clear_addition_count": 0,
"encrypted_addition_count": 0,
"clear_multiplication_count": 0,
"encrypted_negation_count": 0,
},
id="x**2 | x.is_encrypted | x.shape == ()",
),
@@ -547,11 +560,11 @@ def test_circuit_compile_sim_only(helpers):
"x": {"status": "encrypted", "range": [0, 10], "shape": (3,)},
},
{
"total_pbs_count": 3,
"total_clear_addition_count": 0,
"total_encrypted_addition_count": 0,
"total_clear_multiplication_count": 0,
"total_encrypted_negation_count": 0,
"programmable_bootstrap_count": 3,
"clear_addition_count": 0,
"encrypted_addition_count": 0,
"clear_multiplication_count": 0,
"encrypted_negation_count": 0,
},
id="x**2 | x.is_encrypted | x.shape == (3,)",
),
@@ -561,11 +574,11 @@ def test_circuit_compile_sim_only(helpers):
"x": {"status": "encrypted", "range": [0, 10], "shape": (3, 2)},
},
{
"total_pbs_count": 3 * 2,
"total_clear_addition_count": 0,
"total_encrypted_addition_count": 0,
"total_clear_multiplication_count": 0,
"total_encrypted_negation_count": 0,
"programmable_bootstrap_count": 3 * 2,
"clear_addition_count": 0,
"encrypted_addition_count": 0,
"clear_multiplication_count": 0,
"encrypted_negation_count": 0,
},
id="x**2 | x.is_encrypted | x.shape == (3, 2)",
),
@@ -576,11 +589,11 @@ def test_circuit_compile_sim_only(helpers):
"y": {"status": "encrypted", "range": [0, 10], "shape": ()},
},
{
"total_pbs_count": 2,
"total_clear_addition_count": 1,
"total_encrypted_addition_count": 3,
"total_clear_multiplication_count": 0,
"total_encrypted_negation_count": 2,
"programmable_bootstrap_count": 2,
"clear_addition_count": 1,
"encrypted_addition_count": 3,
"clear_multiplication_count": 0,
"encrypted_negation_count": 2,
},
id="x * y | x.is_encrypted | x.shape == () | y.is_encrypted | y.shape == ()",
),
@@ -591,11 +604,11 @@ def test_circuit_compile_sim_only(helpers):
"y": {"status": "encrypted", "range": [0, 10], "shape": (3,)},
},
{
"total_pbs_count": 3 * 2,
"total_clear_addition_count": 3 * 1,
"total_encrypted_addition_count": 3 * 3,
"total_clear_multiplication_count": 0,
"total_encrypted_negation_count": 3 * 2,
"programmable_bootstrap_count": 3 * 2,
"clear_addition_count": 3 * 1,
"encrypted_addition_count": 3 * 3,
"clear_multiplication_count": 0,
"encrypted_negation_count": 3 * 2,
},
id="x * y | x.is_encrypted | x.shape == (3,) | y.is_encrypted | y.shape == (3,)",
),
@@ -606,14 +619,30 @@ def test_circuit_compile_sim_only(helpers):
"y": {"status": "encrypted", "range": [0, 10], "shape": (3, 2)},
},
{
"total_pbs_count": 3 * 2 * 2,
"total_clear_addition_count": 3 * 2 * 1,
"total_encrypted_addition_count": 3 * 2 * 3,
"total_clear_multiplication_count": 0,
"total_encrypted_negation_count": 3 * 2 * 2,
"programmable_bootstrap_count": 3 * 2 * 2,
"clear_addition_count": 3 * 2 * 1,
"encrypted_addition_count": 3 * 2 * 3,
"clear_multiplication_count": 0,
"encrypted_negation_count": 3 * 2 * 2,
},
id="x * y | x.is_encrypted | x.shape == (3, 2) | y.is_encrypted | y.shape == (3, 2)",
),
pytest.param(
tagged_function,
{
"x": {"status": "encrypted", "range": [0, 2**3 - 1], "shape": ()},
"y": {"status": "encrypted", "range": [0, 2**4 - 1], "shape": ()},
"z": {"status": "encrypted", "range": [0, 2**5 - 1], "shape": ()},
},
{
"programmable_bootstrap_count_per_tag": {
"a": 3,
"a.b": 2,
"a.b.c": 1,
},
},
id="tagged_function",
),
],
)
def test_statistics(function, parameters, expected_statistics, helpers):