feat(compiler): support multi-circuit compilation

This commit is contained in:
Alexandre Péré
2024-02-09 11:39:06 +01:00
committed by Alexandre Péré
parent 3247a28d9d
commit 9b5a2e46da
78 changed files with 1200 additions and 865 deletions

View File

@@ -286,7 +286,6 @@ struct LibraryCompilationResult {
/// The output directory path where the compilation artifacts have been
/// generated.
std::string outputDirPath;
std::string funcName;
};
class LibrarySupport {
@@ -318,14 +317,8 @@ public:
return std::move(err);
}
if (!options.mainFuncName.has_value()) {
return StreamStringError("Need to have a funcname to compile library");
}
this->funcName = options.mainFuncName.value();
auto result = std::make_unique<LibraryCompilationResult>();
result->outputDirPath = outputPath;
result->funcName = *options.mainFuncName;
return std::move(result);
}
@@ -356,20 +349,15 @@ public:
return std::move(err);
}
if (!options.mainFuncName.has_value()) {
return StreamStringError("Need to have a funcname to compile library");
}
this->funcName = options.mainFuncName.value();
auto result = std::make_unique<LibraryCompilationResult>();
result->outputDirPath = outputPath;
result->funcName = *options.mainFuncName;
return std::move(result);
}
/// Load the server lambda from the compilation result.
llvm::Expected<::concretelang::serverlib::ServerLambda>
loadServerLambda(LibraryCompilationResult &result, bool useSimulation) {
loadServerLambda(LibraryCompilationResult &result, std::string circuitName,
bool useSimulation) {
EXPECTED_TRY(auto programInfo, getProgramInfo());
EXPECTED_TRY(ServerProgram serverProgram,
outcomeToExpected(ServerProgram::load(programInfo.asReader(),
@@ -377,7 +365,7 @@ public:
useSimulation)));
EXPECTED_TRY(
ServerCircuit serverCircuit,
outcomeToExpected(serverProgram.getServerCircuit(result.funcName)));
outcomeToExpected(serverProgram.getServerCircuit(circuitName)));
return ::concretelang::serverlib::ServerLambda{serverCircuit,
useSimulation};
}
@@ -386,13 +374,6 @@ public:
llvm::Expected<::concretelang::clientlib::ClientParameters>
loadClientParameters(LibraryCompilationResult &result) {
EXPECTED_TRY(auto programInfo, getProgramInfo());
if (programInfo.asReader().getCircuits().size() > 1) {
return StreamStringError("ClientLambda: Provided program info contains "
"more than one circuit.");
}
if (programInfo.asReader().getCircuits()[0].getName() != result.funcName) {
return StreamStringError("Unexpected circuit name in program info");
}
auto secretKeys =
std::vector<::concretelang::clientlib::LweSecretKeyParam>();
for (auto key : programInfo.asReader().getKeyset().getLweSecretKeys()) {
@@ -441,15 +422,14 @@ public:
loadCompilationResult() {
auto result = std::make_unique<LibraryCompilationResult>();
result->outputDirPath = outputPath;
result->funcName = funcName;
return std::move(result);
}
llvm::Expected<CompilationFeedback>
llvm::Expected<ProgramCompilationFeedback>
loadCompilationFeedback(LibraryCompilationResult &result) {
auto path = CompilerEngine::Library::getCompilationFeedbackPath(
result.outputDirPath);
auto feedback = CompilationFeedback::load(path);
auto feedback = ProgramCompilationFeedback::load(path);
if (feedback.has_error()) {
return StreamStringError(feedback.error().mesg);
}
@@ -499,7 +479,6 @@ public:
private:
std::string outputPath;
std::string funcName;
std::string runtimeLibraryPath;
/// Flags to select generated artifacts
bool generateSharedLib;

View File

@@ -240,6 +240,7 @@ private:
template struct Message<concreteprotocol::ProgramInfo>;
template struct Message<concreteprotocol::CircuitEncodingInfo>;
template struct Message<concreteprotocol::ProgramEncodingInfo>;
template struct Message<concreteprotocol::Value>;
template struct Message<concreteprotocol::GateInfo>;

View File

@@ -15,7 +15,7 @@ namespace mlir {
namespace concretelang {
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createMemoryUsagePass(CompilationFeedback &feedback);
createMemoryUsagePass(ProgramCompilationFeedback &feedback);
} // namespace concretelang
} // namespace mlir

View File

@@ -15,7 +15,7 @@ namespace mlir {
namespace concretelang {
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createStatisticExtractionPass(CompilationFeedback &feedback);
createStatisticExtractionPass(ProgramCompilationFeedback &feedback);
} // namespace concretelang
} // namespace mlir

View File

@@ -44,27 +44,13 @@ enum class KeyType {
struct Statistic {
std::string location;
PrimitiveOperation operation;
std::vector<std::pair<KeyType, size_t>> keys;
size_t count;
std::vector<std::pair<KeyType, int64_t>> keys;
int64_t count;
};
struct CompilationFeedback {
double complexity;
/// @brief Probability of error for every PBS.
double pError;
/// @brief Probability of error for the whole programs.
double globalPError;
/// @brief the total number of bytes of secret keys
uint64_t totalSecretKeysSize;
/// @brief the total number of bytes of bootstrap keys
uint64_t totalBootstrapKeysSize;
/// @brief the total number of bytes of keyswitch keys
uint64_t totalKeyswitchKeysSize;
struct CircuitCompilationFeedback {
/// @brief the name of circuit.
std::string name;
/// @brief the total number of bytes of inputs
uint64_t totalInputsSize;
@@ -81,25 +67,52 @@ struct CompilationFeedback {
/// @brief memory usage per location
std::map<std::string, int64_t> memoryUsagePerLoc;
/// Fill the sizes from the program info.
void fillFromCircuitInfo(concreteprotocol::CircuitInfo::Reader params);
};
struct ProgramCompilationFeedback {
double complexity;
/// @brief Probability of error for every PBS.
double pError;
/// @brief Probability of error for the whole programs.
double globalPError;
/// @brief the total number of bytes of secret keys
uint64_t totalSecretKeysSize;
/// @brief the total number of bytes of bootstrap keys
uint64_t totalBootstrapKeysSize;
/// @brief the total number of bytes of keyswitch keys
uint64_t totalKeyswitchKeysSize;
/// @brief the feedback for each circuit
std::vector<CircuitCompilationFeedback> circuitFeedbacks;
/// Fill the sizes from the program info.
void fillFromProgramInfo(const Message<protocol::ProgramInfo> &params);
/// Load the compilation feedback from a path
static outcome::checked<CompilationFeedback, StringError>
static outcome::checked<ProgramCompilationFeedback, StringError>
load(std::string path);
};
llvm::json::Value toJSON(const mlir::concretelang::CompilationFeedback &);
llvm::json::Value
toJSON(const mlir::concretelang::ProgramCompilationFeedback &);
bool fromJSON(const llvm::json::Value,
mlir::concretelang::CompilationFeedback &, llvm::json::Path);
mlir::concretelang::ProgramCompilationFeedback &,
llvm::json::Path);
} // namespace concretelang
} // namespace mlir
static inline llvm::raw_ostream &
operator<<(llvm::raw_string_ostream &OS,
mlir::concretelang::CompilationFeedback cp) {
mlir::concretelang::ProgramCompilationFeedback cp) {
return OS << llvm::formatv("{0:2}", toJSON(cp));
}
#endif

View File

@@ -79,8 +79,6 @@ struct CompilationOptions {
std::optional<std::vector<int64_t>> fhelinalgTileSizes;
std::optional<std::string> mainFuncName;
optimizer::Config optimizerConfig;
/// When decomposing big integers into chunks, chunkSize is the total number
@@ -92,7 +90,9 @@ struct CompilationOptions {
/// When compiling from a dialect lower than FHE, one needs to provide
/// encodings info manually to allow the client lib to be generated.
std::optional<Message<concreteprotocol::CircuitEncodingInfo>> encodings;
std::optional<Message<concreteprotocol::ProgramEncodingInfo>> encodings;
bool skipProgramInfo;
bool compressEvaluationKeys;
@@ -102,20 +102,14 @@ struct CompilationOptions {
maxBatchSize(std::numeric_limits<int64_t>::max()), emitSDFGOps(false),
unrollLoopsWithSDFGConvertibleOps(false), dataflowParallelize(false),
optimizeTFHE(true), simulate(false), emitGPUOps(false),
mainFuncName(std::nullopt), optimizerConfig(optimizer::DEFAULT_CONFIG),
chunkIntegers(false), chunkSize(4), chunkWidth(2),
encodings(std::nullopt), compressEvaluationKeys(false){};
CompilationOptions(std::string funcname) : CompilationOptions() {
mainFuncName = funcname;
}
optimizerConfig(optimizer::DEFAULT_CONFIG), chunkIntegers(false),
chunkSize(4), chunkWidth(2), encodings(std::nullopt),
skipProgramInfo(false), compressEvaluationKeys(false){};
/// @brief Constructor for CompilationOptions with default parameters for a
/// specific backend.
/// @param funcname The name of the function to compile.
/// @param backend The backend to target.
CompilationOptions(std::string funcname, enum Backend backend)
: CompilationOptions(funcname) {
CompilationOptions(enum Backend backend) : CompilationOptions() {
switch (backend) {
case Backend::CPU:
loopParallelize = true;
@@ -143,7 +137,7 @@ public:
std::optional<mlir::OwningOpRef<mlir::ModuleOp>> mlirModuleRef;
std::optional<Message<concreteprotocol::ProgramInfo>> programInfo;
std::optional<CompilationFeedback> feedback;
std::optional<ProgramCompilationFeedback> feedback;
std::unique_ptr<llvm::Module> llvmModule;
std::optional<mlir::concretelang::V0FHEContext> fheContext;
@@ -157,7 +151,7 @@ public:
/// Path to the runtime library. Will be linked to the output library if set
std::string runtimeLibraryPath;
bool cleanUp;
mlir::concretelang::CompilationFeedback compilationFeedback;
mlir::concretelang::ProgramCompilationFeedback compilationFeedback;
Message<concreteprotocol::ProgramInfo> programInfo;
public:
@@ -280,7 +274,7 @@ public:
CompilerEngine(std::shared_ptr<CompilationContext> compilationContext)
: overrideMaxEintPrecision(), overrideMaxMANP(), compilerOptions(),
generateProgramInfo(compilerOptions.mainFuncName.has_value()),
generateProgramInfo(true),
enablePass([](mlir::Pass *pass) { return true; }),
compilationContext(compilationContext) {}
@@ -325,10 +319,6 @@ public:
if (compilerOptions.v0FHEConstraints.has_value()) {
setFHEConstraints(*compilerOptions.v0FHEConstraints);
}
if (compilerOptions.mainFuncName.has_value()) {
setGenerateProgramInfo(true);
}
}
CompilationOptions &getCompilationOptions() { return compilerOptions; }

View File

@@ -35,11 +35,11 @@ namespace mlir {
namespace concretelang {
namespace encodings {
llvm::Expected<Message<concreteprotocol::CircuitEncodingInfo>>
getCircuitEncodings(llvm::StringRef functionName, mlir::ModuleOp module);
llvm::Expected<Message<concreteprotocol::ProgramEncodingInfo>>
getProgramEncoding(mlir::ModuleOp module);
void setCircuitEncodingModes(
Message<concreteprotocol::CircuitEncodingInfo> &info,
void setProgramEncodingModes(
Message<concreteprotocol::ProgramEncodingInfo> &info,
std::optional<
Message<concreteprotocol::IntegerCiphertextEncodingInfo::ChunkedMode>>
maybeChunk,

View File

@@ -77,7 +77,7 @@ normalizeTFHEKeys(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::LogicalResult
extractTFHEStatistics(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
CompilationFeedback &feedback);
ProgramCompilationFeedback &feedback);
mlir::LogicalResult
lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
@@ -86,7 +86,7 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::LogicalResult
computeMemoryUsage(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
CompilationFeedback &feedback);
ProgramCompilationFeedback &feedback);
mlir::LogicalResult
lowerConcreteLinalgToLoops(mlir::MLIRContext &context, mlir::ModuleOp &module,

View File

@@ -21,8 +21,8 @@ namespace concretelang {
llvm::Expected<Message<concreteprotocol::ProgramInfo>>
createProgramInfoFromTfheDialect(
mlir::ModuleOp module, llvm::StringRef functionName, int bitsOfSecurity,
Message<concreteprotocol::CircuitEncodingInfo> &encodings,
mlir::ModuleOp module, int bitsOfSecurity,
const Message<concreteprotocol::ProgramEncodingInfo> &encodings,
bool compressEvaluationKeys);
} // namespace concretelang

View File

@@ -151,10 +151,10 @@ typedef std::variant<V0Parameter, CircuitSolution> Solution;
} // namespace optimizer
struct CompilationFeedback;
struct ProgramCompilationFeedback;
llvm::Expected<optimizer::Solution>
getSolution(optimizer::Description &descr, CompilationFeedback &feedback,
getSolution(optimizer::Description &descr, ProgramCompilationFeedback &feedback,
optimizer::Config optimizerConfig);
// As for now the solution which contains a crt encoding is mono parameter only

View File

@@ -36,26 +36,26 @@ using concretelang::values::Value;
namespace concretelang {
namespace testlib {
class TestCircuit {
class TestProgram {
public:
TestCircuit(mlir::concretelang::CompilationOptions options)
TestProgram(mlir::concretelang::CompilationOptions options)
: artifactDirectory(createTempFolderIn(getSystemTempFolderPath())),
compiler(mlir::concretelang::CompilationContext::createShared()),
encryptionCsprng(std::make_shared<csprng::EncryptionCSPRNG>(0)) {
compiler.setCompilationOptions(options);
}
TestCircuit(TestCircuit &&tc)
TestProgram(TestProgram &&tc)
: artifactDirectory(tc.artifactDirectory), compiler(tc.compiler),
library(tc.library), keyset(tc.keyset),
encryptionCsprng(tc.encryptionCsprng) {
tc.artifactDirectory = "";
};
TestCircuit(TestCircuit &tc) = delete;
TestProgram(TestProgram &tc) = delete;
~TestCircuit() {
~TestProgram() {
auto d = getArtifactDirectory();
if (d.empty())
return;
@@ -93,16 +93,17 @@ public:
return outcome::success();
}
Result<std::vector<Value>> call(std::vector<Value> inputs) {
Result<std::vector<Value>> call(std::vector<Value> inputs,
std::string name = "main") {
// preprocess arguments
auto preparedArgs = std::vector<TransportValue>();
OUTCOME_TRY(auto clientCircuit, getClientCircuit());
OUTCOME_TRY(auto clientCircuit, getClientCircuit(name));
for (size_t i = 0; i < inputs.size(); i++) {
OUTCOME_TRY(auto preparedInput, clientCircuit.prepareInput(inputs[i], i));
preparedArgs.push_back(preparedInput);
}
// Call server
OUTCOME_TRY(auto returns, callServer(preparedArgs));
OUTCOME_TRY(auto returns, callServer(preparedArgs, name));
// postprocess arguments
std::vector<Value> processedOutputs(returns.size());
for (size_t i = 0; i < processedOutputs.size(); i++) {
@@ -113,17 +114,18 @@ public:
}
Result<std::vector<Value>> compose_n_times(std::vector<Value> inputs,
size_t n) {
size_t n,
std::string name = "main") {
// preprocess arguments
auto preparedArgs = std::vector<TransportValue>();
OUTCOME_TRY(auto clientCircuit, getClientCircuit());
OUTCOME_TRY(auto clientCircuit, getClientCircuit(name));
for (size_t i = 0; i < inputs.size(); i++) {
OUTCOME_TRY(auto preparedInput, clientCircuit.prepareInput(inputs[i], i));
preparedArgs.push_back(preparedInput);
}
// Call server multiple times in a row
for (size_t i = 0; i < n; i++) {
OUTCOME_TRY(preparedArgs, callServer(preparedArgs));
OUTCOME_TRY(preparedArgs, callServer(preparedArgs, name));
}
// postprocess arguments
std::vector<Value> processedOutputs(preparedArgs.size());
@@ -135,9 +137,9 @@ public:
}
Result<std::vector<TransportValue>>
callServer(std::vector<TransportValue> inputs) {
callServer(std::vector<TransportValue> inputs, std::string name = "main") {
std::vector<TransportValue> returns;
OUTCOME_TRY(auto serverCircuit, getServerCircuit());
OUTCOME_TRY(auto serverCircuit, getServerCircuit(name));
if (compiler.getCompilationOptions().simulate) {
OUTCOME_TRY(returns, serverCircuit.simulate(inputs));
} else {
@@ -146,29 +148,25 @@ public:
return returns;
}
Result<ClientCircuit> getClientCircuit() {
Result<ClientCircuit> getClientCircuit(std::string name = "main") {
OUTCOME_TRY(auto lib, getLibrary());
OUTCOME_TRY(auto ks, getKeyset());
auto programInfo = lib.getProgramInfo();
OUTCOME_TRY(auto clientProgram,
ClientProgram::create(programInfo, ks.client, encryptionCsprng,
isSimulation()));
OUTCOME_TRY(auto clientCircuit,
clientProgram.getClientCircuit(
programInfo.asReader().getCircuits()[0].getName()));
OUTCOME_TRY(auto clientCircuit, clientProgram.getClientCircuit(name));
return clientCircuit;
}
Result<ServerCircuit> getServerCircuit() {
Result<ServerCircuit> getServerCircuit(std::string name = "main") {
OUTCOME_TRY(auto lib, getLibrary());
auto programInfo = lib.getProgramInfo();
OUTCOME_TRY(auto serverProgram,
ServerProgram::load(programInfo,
lib.getSharedLibraryPath(artifactDirectory),
isSimulation()));
OUTCOME_TRY(auto serverCircuit,
serverProgram.getServerCircuit(
programInfo.asReader().getCircuits()[0].getName()));
OUTCOME_TRY(auto serverCircuit, serverProgram.getServerCircuit(name));
return serverCircuit;
}
@@ -177,14 +175,14 @@ private:
Result<mlir::concretelang::CompilerEngine::Library> getLibrary() {
if (!library.has_value()) {
return StringError("TestCircuit: compilation has not been done\n");
return StringError("TestProgram: compilation has not been done\n");
}
return *library;
}
Result<Keyset> getKeyset() {
if (!keyset.has_value()) {
return StringError("TestCircuit: keyset has not been generated\n");
return StringError("TestProgram: keyset has not been generated\n");
}
return *keyset;
}
@@ -206,10 +204,10 @@ private:
void deleteFolder(const std::string &folder) {
auto ec = std::error_code();
llvm::errs() << "TestCircuit: delete artifact directory(" << folder
llvm::errs() << "TestProgram: delete artifact directory(" << folder
<< ")\n";
if (!std::filesystem::remove_all(folder, ec)) {
llvm::errs() << "TestCircuit: fail to delete directory(" << folder
llvm::errs() << "TestProgram: fail to delete directory(" << folder
<< "), error(" << ec.message() << ")\n";
}
}
@@ -232,10 +230,10 @@ private:
for (size_t i = 0; i < 5; i++) {
auto pathString = new_path();
auto ec = std::error_code();
llvm::errs() << "TestCircuit: create temporary directory(" << pathString
llvm::errs() << "TestProgram: create temporary directory(" << pathString
<< ")\n";
if (!std::filesystem::create_directory(pathString, ec)) {
llvm::errs() << "TestCircuit: fail to create temporary directory("
llvm::errs() << "TestProgram: fail to create temporary directory("
<< pathString << "), ";
if (ec) {
llvm::errs() << "already exists";
@@ -243,7 +241,7 @@ private:
llvm::errs() << "error(" << ec.message() << ")";
}
} else {
llvm::errs() << "TestCircuit: directory(" << pathString
llvm::errs() << "TestProgram: directory(" << pathString
<< ") successfuly created\n";
return pathString;
}

View File

@@ -102,7 +102,8 @@ concretelang::clientlib::ClientParameters library_load_client_parameters(
return *clientParameters;
}
mlir::concretelang::CompilationFeedback library_load_compilation_feedback(
mlir::concretelang::ProgramCompilationFeedback
library_load_compilation_feedback(
LibrarySupport_Py support,
mlir::concretelang::LibraryCompilationResult &result) {
GET_OR_THROW_LLVM_EXPECTED(compilationFeedback,
@@ -113,9 +114,10 @@ mlir::concretelang::CompilationFeedback library_load_compilation_feedback(
concretelang::serverlib::ServerLambda
library_load_server_lambda(LibrarySupport_Py support,
mlir::concretelang::LibraryCompilationResult &result,
bool useSimulation) {
std::string circuitName, bool useSimulation) {
GET_OR_THROW_LLVM_EXPECTED(
serverLambda, support.support.loadServerLambda(result, useSimulation));
serverLambda,
support.support.loadServerLambda(result, circuitName, useSimulation));
return *serverLambda;
}
@@ -176,7 +178,8 @@ key_set(concretelang::clientlib::ClientParameters clientParameters,
std::unique_ptr<concretelang::clientlib::PublicArguments>
encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,
concretelang::clientlib::KeySet &keySet,
llvm::ArrayRef<mlir::concretelang::LambdaArgument *> args) {
llvm::ArrayRef<mlir::concretelang::LambdaArgument *> args,
const std::string &circuitName) {
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
clientParameters.programInfo.asReader(), keySet.keyset.client,
std::make_shared<::concretelang::csprng::EncryptionCSPRNG>(
@@ -185,15 +188,10 @@ encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,
if (maybeProgram.has_failure()) {
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
}
auto circuit = maybeProgram.value()
.getClientCircuit(clientParameters.programInfo.asReader()
.getCircuits()[0]
.getName())
.value();
auto circuit = maybeProgram.value().getClientCircuit(circuitName).value();
std::vector<TransportValue> output;
for (size_t i = 0; i < args.size(); i++) {
auto info =
clientParameters.programInfo.asReader().getCircuits()[0].getInputs()[i];
auto info = circuit.getCircuitInfo().asReader().getInputs()[i];
auto typeTransformer = getPythonTypeTransformer(info);
auto input = typeTransformer(args[i]->value);
auto maybePrepared = circuit.prepareInput(input, i);
@@ -211,7 +209,8 @@ encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,
std::vector<lambdaArgument>
decrypt_result(concretelang::clientlib::ClientParameters clientParameters,
concretelang::clientlib::KeySet &keySet,
concretelang::clientlib::PublicResult &publicResult) {
concretelang::clientlib::PublicResult &publicResult,
const std::string &circuitName) {
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
clientParameters.programInfo.asReader(), keySet.keyset.client,
std::make_shared<::concretelang::csprng::EncryptionCSPRNG>(
@@ -220,11 +219,7 @@ decrypt_result(concretelang::clientlib::ClientParameters clientParameters,
if (maybeProgram.has_failure()) {
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
}
auto circuit = maybeProgram.value()
.getClientCircuit(clientParameters.programInfo.asReader()
.getCircuits()[0]
.getName())
.value();
auto circuit = maybeProgram.value().getClientCircuit(circuitName).value();
std::vector<lambdaArgument> results;
for (auto e : llvm::enumerate(publicResult.values)) {
auto maybeProcessed = circuit.processOutput(e.value(), e.index());
@@ -370,9 +365,10 @@ valueSerialize(const concretelang::clientlib::SharedScalarOrTensorData &value) {
return maybeString.value();
}
concretelang::clientlib::ValueExporter createValueExporter(
concretelang::clientlib::KeySet &keySet,
concretelang::clientlib::ClientParameters &clientParameters) {
concretelang::clientlib::ValueExporter
createValueExporter(concretelang::clientlib::KeySet &keySet,
concretelang::clientlib::ClientParameters &clientParameters,
const std::string &circuitName) {
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
clientParameters.programInfo.asReader(), keySet.keyset.client,
std::make_shared<::concretelang::csprng::EncryptionCSPRNG>(
@@ -381,13 +377,13 @@ concretelang::clientlib::ValueExporter createValueExporter(
if (maybeProgram.has_failure()) {
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
}
auto maybeCircuit = maybeProgram.value().getClientCircuit(
clientParameters.programInfo.asReader().getCircuits()[0].getName());
auto maybeCircuit = maybeProgram.value().getClientCircuit(circuitName);
return ::concretelang::clientlib::ValueExporter{maybeCircuit.value()};
}
concretelang::clientlib::SimulatedValueExporter createSimulatedValueExporter(
concretelang::clientlib::ClientParameters &clientParameters) {
concretelang::clientlib::ClientParameters &clientParameters,
const std::string &circuitName) {
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
clientParameters.programInfo, ::concretelang::keysets::ClientKeyset(),
@@ -397,15 +393,15 @@ concretelang::clientlib::SimulatedValueExporter createSimulatedValueExporter(
if (maybeProgram.has_failure()) {
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
}
auto maybeCircuit = maybeProgram.value().getClientCircuit(
clientParameters.programInfo.asReader().getCircuits()[0].getName());
auto maybeCircuit = maybeProgram.value().getClientCircuit(circuitName);
return ::concretelang::clientlib::SimulatedValueExporter{
maybeCircuit.value()};
}
concretelang::clientlib::ValueDecrypter createValueDecrypter(
concretelang::clientlib::KeySet &keySet,
concretelang::clientlib::ClientParameters &clientParameters) {
concretelang::clientlib::ClientParameters &clientParameters,
const std::string &circuitName) {
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
clientParameters.programInfo.asReader(), keySet.keyset.client,
@@ -415,13 +411,13 @@ concretelang::clientlib::ValueDecrypter createValueDecrypter(
if (maybeProgram.has_failure()) {
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
}
auto maybeCircuit = maybeProgram.value().getClientCircuit(
clientParameters.programInfo.asReader().getCircuits()[0].getName());
auto maybeCircuit = maybeProgram.value().getClientCircuit(circuitName);
return ::concretelang::clientlib::ValueDecrypter{maybeCircuit.value()};
}
concretelang::clientlib::SimulatedValueDecrypter createSimulatedValueDecrypter(
concretelang::clientlib::ClientParameters &clientParameters) {
concretelang::clientlib::ClientParameters &clientParameters,
const std::string &circuitName) {
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
clientParameters.programInfo.asReader(),
@@ -432,8 +428,7 @@ concretelang::clientlib::SimulatedValueDecrypter createSimulatedValueDecrypter(
if (maybeProgram.has_failure()) {
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
}
auto maybeCircuit = maybeProgram.value().getClientCircuit(
clientParameters.programInfo.asReader().getCircuits()[0].getName());
auto maybeCircuit = maybeProgram.value().getClientCircuit(circuitName);
return ::concretelang::clientlib::SimulatedValueDecrypter{
maybeCircuit.value()};
}
@@ -699,14 +694,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.export_values();
pybind11::class_<CompilationOptions>(m, "CompilationOptions")
.def(pybind11::init(
[](std::string funcname, mlir::concretelang::Backend backend) {
return CompilationOptions(funcname, backend);
}))
.def("set_funcname",
[](CompilationOptions &options, std::string funcname) {
options.mainFuncName = funcname;
})
.def(pybind11::init([](mlir::concretelang::Backend backend) {
return CompilationOptions(backend);
}))
.def("set_verify_diagnostics",
[](CompilationOptions &options, bool b) {
options.verifyDiagnostics = b;
@@ -828,34 +818,46 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.def_readonly("keys", &mlir::concretelang::Statistic::keys)
.def_readonly("count", &mlir::concretelang::Statistic::count);
pybind11::class_<mlir::concretelang::CompilationFeedback>(
m, "CompilationFeedback")
pybind11::class_<mlir::concretelang::ProgramCompilationFeedback>(
m, "ProgramCompilationFeedback")
.def_readonly("complexity",
&mlir::concretelang::CompilationFeedback::complexity)
.def_readonly("p_error", &mlir::concretelang::CompilationFeedback::pError)
.def_readonly("global_p_error",
&mlir::concretelang::CompilationFeedback::globalPError)
&mlir::concretelang::ProgramCompilationFeedback::complexity)
.def_readonly("p_error",
&mlir::concretelang::ProgramCompilationFeedback::pError)
.def_readonly(
"global_p_error",
&mlir::concretelang::ProgramCompilationFeedback::globalPError)
.def_readonly(
"total_secret_keys_size",
&mlir::concretelang::CompilationFeedback::totalSecretKeysSize)
&mlir::concretelang::ProgramCompilationFeedback::totalSecretKeysSize)
.def_readonly("total_bootstrap_keys_size",
&mlir::concretelang::ProgramCompilationFeedback::
totalBootstrapKeysSize)
.def_readonly("total_keyswitch_keys_size",
&mlir::concretelang::ProgramCompilationFeedback::
totalKeyswitchKeysSize)
.def_readonly(
"total_bootstrap_keys_size",
&mlir::concretelang::CompilationFeedback::totalBootstrapKeysSize)
"circuit_feedbacks",
&mlir::concretelang::ProgramCompilationFeedback::circuitFeedbacks);
pybind11::class_<mlir::concretelang::CircuitCompilationFeedback>(
m, "CircuitCompilationFeedback")
.def_readonly("name",
&mlir::concretelang::CircuitCompilationFeedback::name)
.def_readonly(
"total_keyswitch_keys_size",
&mlir::concretelang::CompilationFeedback::totalKeyswitchKeysSize)
.def_readonly("total_inputs_size",
&mlir::concretelang::CompilationFeedback::totalInputsSize)
.def_readonly("total_output_size",
&mlir::concretelang::CompilationFeedback::totalOutputsSize)
"total_inputs_size",
&mlir::concretelang::CircuitCompilationFeedback::totalInputsSize)
.def_readonly(
"crt_decompositions_of_outputs",
&mlir::concretelang::CompilationFeedback::crtDecompositionsOfOutputs)
"total_output_size",
&mlir::concretelang::CircuitCompilationFeedback::totalOutputsSize)
.def_readonly("crt_decompositions_of_outputs",
&mlir::concretelang::CircuitCompilationFeedback::
crtDecompositionsOfOutputs)
.def_readonly("statistics",
&mlir::concretelang::CompilationFeedback::statistics)
&mlir::concretelang::CircuitCompilationFeedback::statistics)
.def_readonly(
"memory_usage_per_location",
&mlir::concretelang::CompilationFeedback::memoryUsagePerLoc);
&mlir::concretelang::CircuitCompilationFeedback::memoryUsagePerLoc);
pybind11::class_<mlir::concretelang::CompilationContext,
std::shared_ptr<mlir::concretelang::CompilationContext>>(
@@ -872,11 +874,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
pybind11::class_<mlir::concretelang::LibraryCompilationResult>(
m, "LibraryCompilationResult")
.def(pybind11::init([](std::string outputDirPath, std::string funcname) {
return mlir::concretelang::LibraryCompilationResult{
outputDirPath,
funcname,
};
.def(pybind11::init([](std::string outputDirPath) {
return mlir::concretelang::LibraryCompilationResult{outputDirPath};
}));
pybind11::class_<::concretelang::serverlib::ServerLambda>(m, "LibraryLambda");
pybind11::class_<LibrarySupport_Py>(m, "LibrarySupport")
@@ -920,8 +919,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
"load_server_lambda",
[](LibrarySupport_Py &support,
mlir::concretelang::LibraryCompilationResult &result,
bool useSimulation) {
return library_load_server_lambda(support, result, useSimulation);
std::string circuitName, bool useSimulation) {
return library_load_server_lambda(support, result, circuitName,
useSimulation);
},
pybind11::return_value_policy::reference)
.def("server_call",
@@ -974,19 +974,22 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
"encrypt_arguments",
[](::concretelang::clientlib::ClientParameters clientParameters,
::concretelang::clientlib::KeySet &keySet,
std::vector<lambdaArgument> args) {
std::vector<lambdaArgument> args, const std::string &circuitName) {
std::vector<mlir::concretelang::LambdaArgument *> argsRef;
for (auto i = 0u; i < args.size(); i++) {
argsRef.push_back(args[i].ptr.get());
}
return encrypt_arguments(clientParameters, keySet, argsRef);
return encrypt_arguments(clientParameters, keySet, argsRef,
circuitName);
})
.def_static(
"decrypt_result",
[](::concretelang::clientlib::ClientParameters clientParameters,
::concretelang::clientlib::KeySet &keySet,
::concretelang::clientlib::PublicResult &publicResult) {
return decrypt_result(clientParameters, keySet, publicResult);
::concretelang::clientlib::PublicResult &publicResult,
const std::string &circuitName) {
return decrypt_result(clientParameters, keySet, publicResult,
circuitName);
});
pybind11::class_<::concretelang::clientlib::KeySetCache>(m, "KeySetCache")
.def(pybind11::init<std::string &>());
@@ -1187,8 +1190,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.def_static(
"create",
[](::concretelang::clientlib::KeySet &keySet,
::concretelang::clientlib::ClientParameters &clientParameters) {
return createValueExporter(keySet, clientParameters);
::concretelang::clientlib::ClientParameters &clientParameters,
const std::string &circuitName) {
return createValueExporter(keySet, clientParameters, circuitName);
})
.def("export_scalar",
[](::concretelang::clientlib::ValueExporter &exporter,
@@ -1233,8 +1237,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
m, "SimulatedValueExporter")
.def_static(
"create",
[](::concretelang::clientlib::ClientParameters &clientParameters) {
return createSimulatedValueExporter(clientParameters);
[](::concretelang::clientlib::ClientParameters &clientParameters,
const std::string &circuitName) {
return createSimulatedValueExporter(clientParameters, circuitName);
})
.def("export_scalar",
[](::concretelang::clientlib::SimulatedValueExporter &exporter,
@@ -1279,8 +1284,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.def_static(
"create",
[](::concretelang::clientlib::KeySet &keySet,
::concretelang::clientlib::ClientParameters &clientParameters) {
return createValueDecrypter(keySet, clientParameters);
::concretelang::clientlib::ClientParameters &clientParameters,
const std::string &circuitName) {
return createValueDecrypter(keySet, clientParameters, circuitName);
})
.def("decrypt",
[](::concretelang::clientlib::ValueDecrypter &decrypter,
@@ -1303,8 +1309,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
m, "SimulatedValueDecrypter")
.def_static(
"create",
[](::concretelang::clientlib::ClientParameters &clientParameters) {
return createSimulatedValueDecrypter(clientParameters);
[](::concretelang::clientlib::ClientParameters &clientParameters,
const std::string &circuitName) {
return createSimulatedValueDecrypter(clientParameters, circuitName);
})
.def("decrypt",
[](::concretelang::clientlib::SimulatedValueDecrypter &decrypter,

View File

@@ -21,7 +21,7 @@ from .compilation_options import CompilationOptions, Encoding
from .compilation_context import CompilationContext
from .key_set_cache import KeySetCache
from .client_parameters import ClientParameters
from .compilation_feedback import CompilationFeedback
from .compilation_feedback import ProgramCompilationFeedback, CircuitCompilationFeedback
from .key_set import KeySet
from .public_result import PublicResult
from .public_arguments import PublicArguments

View File

@@ -116,6 +116,7 @@ class ClientSupport(WrapperCpp):
client_parameters: ClientParameters,
keyset: KeySet,
args: List[Union[int, np.ndarray]],
circuit_name: str = "main",
) -> PublicArguments:
"""Prepare arguments for encrypted computation.
@@ -126,10 +127,12 @@ class ClientSupport(WrapperCpp):
client_parameters (ClientParameters): client parameters specification
keyset (KeySet): keyset used to encrypt arguments that require encryption
args (List[Union[int, np.ndarray]]): list of scalar or tensor arguments
circuit_name(str): the name of the circuit for which to encrypt
Raises:
TypeError: if client_parameters is not of type ClientParameters
TypeError: if keyset is not of type KeySet
TypeError: if circuit_name is not of type str
Returns:
PublicArguments: public arguments for execution
@@ -140,6 +143,10 @@ class ClientSupport(WrapperCpp):
)
if not isinstance(keyset, KeySet):
raise TypeError(f"keyset must be of type KeySet, not {type(keyset)}")
if not isinstance(circuit_name, str):
raise TypeError(
f"circuit_name must be of type str, not {type(circuit_name)}"
)
signs = client_parameters.input_signs()
if len(signs) != len(args):
@@ -156,6 +163,7 @@ class ClientSupport(WrapperCpp):
client_parameters.cpp(),
keyset.cpp(),
[arg.cpp() for arg in lambda_arguments],
circuit_name,
)
)
@@ -164,17 +172,20 @@ class ClientSupport(WrapperCpp):
client_parameters: ClientParameters,
keyset: KeySet,
public_result: PublicResult,
circuit_name: str = "main",
) -> Union[int, np.ndarray]:
"""Decrypt a public result using the keyset.
Args:
client_parameters (ClientParameters): client parameters for decryption
keyset (KeySet): keyset used for decryption
public_result: public result to decrypt
public_result (PublicResult): public result to decrypt
circuit_name (str): name of the circuit for which to decrypt
Raises:
TypeError: if keyset is not of type KeySet
TypeError: if public_result is not of type PublicResult
TypeError: if circuit_name is not of type str
RuntimeError: if the result is of an unknown type
Returns:
@@ -186,8 +197,12 @@ class ClientSupport(WrapperCpp):
raise TypeError(
f"public_result must be of type PublicResult, not {type(public_result)}"
)
if not isinstance(circuit_name, str):
raise TypeError(
f"circuit_name must be of type str, not {type(circuit_name)}"
)
results = _ClientSupport.decrypt_result(
client_parameters.cpp(), keyset.cpp(), public_result.cpp()
client_parameters.cpp(), keyset.cpp(), public_result.cpp(), circuit_name
)
def process_result(result):

View File

@@ -8,7 +8,8 @@ from typing import Dict, Set
# pylint: disable=no-name-in-module,import-error,too-many-instance-attributes
from mlir._mlir_libs._concretelang._compiler import (
CompilationFeedback as _CompilationFeedback,
ProgramCompilationFeedback as _ProgramCompilationFeedback,
CircuitCompilationFeedback as _CircuitCompilationFeedback,
KeyType,
PrimitiveOperation,
)
@@ -39,38 +40,36 @@ def tag_from_location(location):
return tag
class CompilationFeedback(WrapperCpp):
"""CompilationFeedback is a set of hint computed by the compiler engine."""
class CircuitCompilationFeedback(WrapperCpp):
"""CircuitCompilationFeedback is a set of hint computed by the compiler engine for a circuit."""
def __init__(self, compilation_feedback: _CompilationFeedback):
def __init__(self, circuit_compilation_feedback: _CircuitCompilationFeedback):
"""Wrap the native Cpp object.
Args:
compilation_feeback (_CompilationFeedback): object to wrap
circuit_compilation_feeback (_CircuitCompilationFeedback): object to wrap
Raises:
TypeError: if compilation_feedback is not of type _CompilationFeedback
TypeError: if circuit_compilation_feedback is not of type _CircuitCompilationFeedback
"""
if not isinstance(compilation_feedback, _CompilationFeedback):
if not isinstance(circuit_compilation_feedback, _CircuitCompilationFeedback):
raise TypeError(
f"compilation_feedback must be of type _CompilationFeedback, not {type(compilation_feedback)}"
"circuit_compilation_feedback must be of type "
f"_CircuitCompilationFeedback, not {type(circuit_compilation_feedback)}"
)
self.complexity = compilation_feedback.complexity
self.p_error = compilation_feedback.p_error
self.global_p_error = compilation_feedback.global_p_error
self.total_secret_keys_size = compilation_feedback.total_secret_keys_size
self.total_bootstrap_keys_size = compilation_feedback.total_bootstrap_keys_size
self.total_keyswitch_keys_size = compilation_feedback.total_keyswitch_keys_size
self.total_inputs_size = compilation_feedback.total_inputs_size
self.total_output_size = compilation_feedback.total_output_size
self.name = circuit_compilation_feedback.name
self.total_inputs_size = circuit_compilation_feedback.total_inputs_size
self.total_output_size = circuit_compilation_feedback.total_output_size
self.crt_decompositions_of_outputs = (
compilation_feedback.crt_decompositions_of_outputs
circuit_compilation_feedback.crt_decompositions_of_outputs
)
self.statistics = circuit_compilation_feedback.statistics
self.memory_usage_per_location = (
circuit_compilation_feedback.memory_usage_per_location
)
self.statistics = compilation_feedback.statistics
self.memory_usage_per_location = compilation_feedback.memory_usage_per_location
super().__init__(compilation_feedback)
super().__init__(circuit_compilation_feedback)
def count(self, *, operations: Set[PrimitiveOperation]) -> int:
"""
@@ -216,3 +215,68 @@ class CompilationFeedback(WrapperCpp):
result[current_tag][parameter] += statistic.count
return result
class ProgramCompilationFeedback(WrapperCpp):
"""CompilationFeedback is a set of hint computed by the compiler engine."""
def __init__(self, program_compilation_feedback: _ProgramCompilationFeedback):
"""Wrap the native Cpp object.
Args:
compilation_feeback (_CompilationFeedback): object to wrap
Raises:
TypeError: if program_compilation_feedback is not of type _CompilationFeedback
"""
if not isinstance(program_compilation_feedback, _ProgramCompilationFeedback):
raise TypeError(
"program_compilation_feedback must be of type "
f"_CompilationFeedback, not {type(program_compilation_feedback)}"
)
self.complexity = program_compilation_feedback.complexity
self.p_error = program_compilation_feedback.p_error
self.global_p_error = program_compilation_feedback.global_p_error
self.total_secret_keys_size = (
program_compilation_feedback.total_secret_keys_size
)
self.total_bootstrap_keys_size = (
program_compilation_feedback.total_bootstrap_keys_size
)
self.total_keyswitch_keys_size = (
program_compilation_feedback.total_keyswitch_keys_size
)
self.circuit_feedbacks = [
CircuitCompilationFeedback(c)
for c in program_compilation_feedback.circuit_feedbacks
]
super().__init__(program_compilation_feedback)
def circuit(self, circuit_name: str) -> CircuitCompilationFeedback:
"""
Returns the feedback for the circuit circuit_name.
Args:
circuit_name (str):
the name of the circuit.
Returns:
CircuitCompilationFeedback:
the feedback for the circuit.
Raises:
TypeError: if the circuit_name is not a string
ValueError: if there is no circuit with name circuit_name
"""
if not isinstance(circuit_name, str):
raise TypeError(
f"circuit_name must be of type str, not {type(circuit_name)}"
)
for circuit in self.circuit_feedbacks:
if circuit.name == circuit_name:
return circuit
raise ValueError(f"no circuit with name {circuit_name} found")

View File

@@ -43,11 +43,11 @@ class CompilationOptions(WrapperCpp):
@staticmethod
# pylint: disable=arguments-differ
def new(function_name="main", backend=_Backend.CPU) -> "CompilationOptions":
def new(backend=_Backend.CPU) -> "CompilationOptions":
"""Build a CompilationOptions.
Args:
function_name (str, optional): name of the entrypoint function. Defaults to "main".
backend (_Backend): backend to use.
Raises:
TypeError: if function_name is not an str
@@ -55,15 +55,9 @@ class CompilationOptions(WrapperCpp):
Returns:
CompilationOptions
"""
if not isinstance(function_name, str):
raise TypeError(
f"function_name must be of type str not {type(function_name)}"
)
if not isinstance(backend, _Backend):
raise TypeError(
f"backend must be of type Backend not {type(function_name)}"
)
return CompilationOptions.wrap(_CompilationOptions(function_name, backend))
raise TypeError(f"backend must be of type Backend not {type(backend)}")
return CompilationOptions.wrap(_CompilationOptions(backend))
# pylint: enable=arguments-differ

View File

@@ -33,16 +33,14 @@ class LibraryCompilationResult(WrapperCpp):
@staticmethod
# pylint: disable=arguments-differ
def new(output_dir_path: str, func_name: str) -> "LibraryCompilationResult":
"""Build a LibraryCompilationResult at output_dir_path, with func_name as entrypoint.
def new(output_dir_path: str) -> "LibraryCompilationResult":
"""Build a LibraryCompilationResult at output_dir_path.
Args:
output_dir_path (str): path to the compilation artifacts
func_name (str): entrypoint function name
Raises:
TypeError: if output_dir_path is not of type str
TypeError: if func_name is not of type str
Returns:
LibraryCompilationResult
@@ -51,10 +49,6 @@ class LibraryCompilationResult(WrapperCpp):
raise TypeError(
f"output_dir_path must be of type str, not {type(output_dir_path)}"
)
if not isinstance(func_name, str):
raise TypeError(f"func_name must be of type str, not {type(func_name)}")
return LibraryCompilationResult.wrap(
_LibraryCompilationResult(output_dir_path, func_name)
)
return LibraryCompilationResult.wrap(_LibraryCompilationResult(output_dir_path))
# pylint: enable=arguments-differ

View File

@@ -23,7 +23,7 @@ from .public_arguments import PublicArguments
from .library_lambda import LibraryLambda
from .public_result import PublicResult
from .client_parameters import ClientParameters
from .compilation_feedback import CompilationFeedback
from .compilation_feedback import ProgramCompilationFeedback
from .wrapper import WrapperCpp
from .utils import lookup_runtime_lib
from .evaluation_keys import EvaluationKeys
@@ -132,7 +132,7 @@ class LibrarySupport(WrapperCpp):
def compile(
self,
mlir_program: Union[str, MlirModule],
options: CompilationOptions = CompilationOptions.new("main"),
options: CompilationOptions = CompilationOptions.new(),
compilation_context: Optional[CompilationContext] = None,
) -> LibraryCompilationResult:
"""Compile an MLIR program using Concrete dialects into a library.
@@ -178,18 +178,13 @@ class LibrarySupport(WrapperCpp):
self.cpp().compile(mlir_program, options.cpp())
)
def reload(self, func_name: str = "main") -> LibraryCompilationResult:
def reload(self) -> LibraryCompilationResult:
"""Reload the library compilation result from the output_dir_path.
Args:
func_name: entrypoint function name
Returns:
LibraryCompilationResult: loaded library
"""
if not isinstance(func_name, str):
raise TypeError(f"func_name must be of type str, not {type(func_name)}")
return LibraryCompilationResult.new(self.output_dir_path, func_name)
return LibraryCompilationResult.new(self.output_dir_path)
def load_client_parameters(
self, library_compilation_result: LibraryCompilationResult
@@ -217,7 +212,7 @@ class LibrarySupport(WrapperCpp):
def load_compilation_feedback(
self, compilation_result: LibraryCompilationResult
) -> CompilationFeedback:
) -> ProgramCompilationFeedback:
"""Load the compilation feedback from the compilation result.
Args:
@@ -227,13 +222,13 @@ class LibrarySupport(WrapperCpp):
TypeError: if compilation_result is not of type LibraryCompilationResult
Returns:
CompilationFeedback: the compilation feedback for the compiled program
ProgramCompilationFeedback: the compilation feedback for the compiled program
"""
if not isinstance(compilation_result, LibraryCompilationResult):
raise TypeError(
f"compilation_result must be of type LibraryCompilationResult, not {type(compilation_result)}"
)
return CompilationFeedback.wrap(
return ProgramCompilationFeedback.wrap(
self.cpp().load_compilation_feedback(compilation_result.cpp())
)
@@ -241,14 +236,18 @@ class LibrarySupport(WrapperCpp):
self,
library_compilation_result: LibraryCompilationResult,
simulation: bool,
circuit_name: str = "main",
) -> LibraryLambda:
"""Load the server lambda from the library compilation result.
"""Load the server lambda for a given circuit from the library compilation result.
Args:
library_compilation_result (LibraryCompilationResult): compilation result of the library
simulation (bool): use simulation for execution
circuit_name (str): name of the circuit to be loaded
Raises:
TypeError: if library_compilation_result is not of type LibraryCompilationResult
TypeError: if library_compilation_result is not of type LibraryCompilationResult, if
circuit_name is not of type str or
Returns:
LibraryLambda: executable reference to the library
@@ -258,8 +257,18 @@ class LibrarySupport(WrapperCpp):
f"library_compilation_result must be of type LibraryCompilationResult, not "
f"{type(library_compilation_result)}"
)
if not isinstance(circuit_name, str):
raise TypeError(
f"circuit_name must be of type str, not " f"{type(circuit_name)}"
)
if not isinstance(simulation, bool):
raise TypeError(
f"simulation must be of type bool, not " f"{type(simulation)}"
)
return LibraryLambda.wrap(
self.cpp().load_server_lambda(library_compilation_result.cpp(), simulation)
self.cpp().load_server_lambda(
library_compilation_result.cpp(), circuit_name, simulation
)
)
def server_call(

View File

@@ -41,12 +41,12 @@ class SimulatedValueDecrypter(WrapperCpp):
@staticmethod
# pylint: disable=arguments-differ
def new(client_parameters: ClientParameters):
def new(client_parameters: ClientParameters, circuit_name: str = "main"):
"""
Create a value decrypter.
"""
return SimulatedValueDecrypter(
_SimulatedValueDecrypter.create(client_parameters.cpp())
_SimulatedValueDecrypter.create(client_parameters.cpp(), circuit_name)
)
def decrypt(self, position: int, value: Value) -> Union[int, np.ndarray]:

View File

@@ -40,12 +40,14 @@ class SimulatedValueExporter(WrapperCpp):
@staticmethod
# pylint: disable=arguments-differ
def new(client_parameters: ClientParameters) -> "SimulatedValueExporter":
def new(
client_parameters: ClientParameters, circuitName: str = "main"
) -> "SimulatedValueExporter":
"""
Create a value exporter.
"""
return SimulatedValueExporter(
_SimulatedValueExporter.create(client_parameters.cpp())
_SimulatedValueExporter.create(client_parameters.cpp(), circuitName)
)
def export_scalar(self, position: int, value: int) -> Value:

View File

@@ -42,12 +42,14 @@ class ValueDecrypter(WrapperCpp):
@staticmethod
# pylint: disable=arguments-differ
def new(keyset: KeySet, client_parameters: ClientParameters):
def new(
keyset: KeySet, client_parameters: ClientParameters, circuit_name: str = "main"
):
"""
Create a value decrypter.
"""
return ValueDecrypter(
_ValueDecrypter.create(keyset.cpp(), client_parameters.cpp())
_ValueDecrypter.create(keyset.cpp(), client_parameters.cpp(), circuit_name)
)
def decrypt(self, position: int, value: Value) -> Union[int, np.ndarray]:

View File

@@ -41,12 +41,14 @@ class ValueExporter(WrapperCpp):
@staticmethod
# pylint: disable=arguments-differ
def new(keyset: KeySet, client_parameters: ClientParameters) -> "ValueExporter":
def new(
keyset: KeySet, client_parameters: ClientParameters, circuit_name: str = "main"
) -> "ValueExporter":
"""
Create a value exporter.
"""
return ValueExporter(
_ValueExporter.create(keyset.cpp(), client_parameters.cpp())
_ValueExporter.create(keyset.cpp(), client_parameters.cpp(), circuit_name)
)
def export_scalar(self, position: int, value: int) -> Value:

View File

@@ -3,6 +3,7 @@
#include <concretelang/Dialect/Concrete/IR/ConcreteOps.h>
#include <concretelang/Support/logging.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/IR/BuiltinOps.h>
@@ -66,34 +67,48 @@ namespace Concrete {
struct MemoryUsagePass
: public PassWrapper<MemoryUsagePass, OperationPass<ModuleOp>> {
CompilationFeedback &feedback;
ProgramCompilationFeedback &feedback;
CircuitCompilationFeedback *circuitFeedback;
MemoryUsagePass(CompilationFeedback &feedback) : feedback{feedback} {};
MemoryUsagePass(ProgramCompilationFeedback &feedback)
: feedback{feedback}, circuitFeedback{nullptr} {};
void runOnOperation() override {
WalkResult walk =
getOperation()->walk([&](Operation *op, const WalkStage &stage) {
if (stage.isBeforeAllRegions()) {
std::optional<StringError> error = this->enter(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
auto module = getOperation();
auto funcs = module.getOps<mlir::func::FuncOp>();
for (CircuitCompilationFeedback &circuitFeedback :
feedback.circuitFeedbacks) {
auto funcOp = llvm::find_if(funcs, [&](mlir::func::FuncOp op) {
return op.getName() == circuitFeedback.name;
});
assert(funcOp != funcs.end());
this->circuitFeedback = &circuitFeedback;
WalkResult walk =
getOperation()->walk([&](Operation *op, const WalkStage &stage) {
if (stage.isBeforeAllRegions()) {
std::optional<StringError> error = this->enter(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}
}
if (stage.isAfterAllRegions()) {
std::optional<StringError> error = this->exit(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
if (stage.isAfterAllRegions()) {
std::optional<StringError> error = this->exit(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}
}
return WalkResult::advance();
});
return WalkResult::advance();
});
if (walk.wasInterrupted()) {
signalPassFailure();
if (walk.wasInterrupted()) {
signalPassFailure();
return;
}
}
}
@@ -172,7 +187,7 @@ struct MemoryUsagePass
// element_size
auto memoryUsage = numberOfAlloc * maybeBufferSize.value();
pass.feedback.memoryUsagePerLoc[location] += memoryUsage;
pass.circuitFeedback->memoryUsagePerLoc[location] += memoryUsage;
return std::nullopt;
}
@@ -218,7 +233,7 @@ struct MemoryUsagePass
}
auto bufferSize = maybeBufferSize.value();
pass.feedback.memoryUsagePerLoc[location] += bufferSize;
pass.circuitFeedback->memoryUsagePerLoc[location] += bufferSize;
}
}
@@ -233,7 +248,7 @@ struct MemoryUsagePass
} // namespace Concrete
std::unique_ptr<OperationPass<ModuleOp>>
createMemoryUsagePass(CompilationFeedback &feedback) {
createMemoryUsagePass(ProgramCompilationFeedback &feedback) {
return std::make_unique<Concrete::MemoryUsagePass>(feedback);
}

View File

@@ -1,3 +1,4 @@
#include "concretelang/Support/CompilationFeedback.h"
#include <concretelang/Analysis/Utils.h>
#include <concretelang/Dialect/TFHE/Analysis/ExtractStatistics.h>
@@ -34,35 +35,48 @@ namespace TFHE {
struct ExtractTFHEStatisticsPass
: public PassWrapper<ExtractTFHEStatisticsPass, OperationPass<ModuleOp>> {
CompilationFeedback &feedback;
ProgramCompilationFeedback &feedback;
CircuitCompilationFeedback *circuitFeedback;
ExtractTFHEStatisticsPass(CompilationFeedback &feedback)
: feedback{feedback} {};
ExtractTFHEStatisticsPass(ProgramCompilationFeedback &feedback)
: feedback{feedback}, circuitFeedback{nullptr} {};
void runOnOperation() override {
WalkResult walk =
getOperation()->walk([&](Operation *op, const WalkStage &stage) {
if (stage.isBeforeAllRegions()) {
std::optional<StringError> error = this->enter(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
auto module = getOperation();
auto funcs = module.getOps<mlir::func::FuncOp>();
for (CircuitCompilationFeedback &circuitFeedback :
feedback.circuitFeedbacks) {
auto funcOp = llvm::find_if(funcs, [&](mlir::func::FuncOp op) {
return op.getName() == circuitFeedback.name;
});
assert(funcOp != funcs.end());
this->circuitFeedback = &circuitFeedback;
WalkResult walk =
(*funcOp)->walk([&](Operation *op, const WalkStage &stage) {
if (stage.isBeforeAllRegions()) {
std::optional<StringError> error = this->enter(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}
}
if (stage.isAfterAllRegions()) {
std::optional<StringError> error = this->exit(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
if (stage.isAfterAllRegions()) {
std::optional<StringError> error = this->exit(op);
if (error.has_value()) {
op->emitError() << error->mesg;
return WalkResult::interrupt();
}
}
}
return WalkResult::advance();
});
return WalkResult::advance();
});
if (walk.wasInterrupted()) {
signalPassFailure();
if (walk.wasInterrupted()) {
signalPassFailure();
return;
}
}
}
@@ -118,14 +132,14 @@ struct ExtractTFHEStatisticsPass
auto location = locationString(op.getLoc());
auto operation = PrimitiveOperation::ENCRYPTED_ADDITION;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto keys = std::vector<std::pair<KeyType, int64_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
std::pair<KeyType, int64_t> key =
std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index);
keys.push_back(key);
pass.feedback.statistics.push_back(concretelang::Statistic{
pass.circuitFeedback->statistics.push_back(concretelang::Statistic{
location,
operation,
keys,
@@ -145,14 +159,14 @@ struct ExtractTFHEStatisticsPass
auto location = locationString(op.getLoc());
auto operation = PrimitiveOperation::CLEAR_ADDITION;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto keys = std::vector<std::pair<KeyType, int64_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
std::pair<KeyType, int64_t> key =
std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index);
keys.push_back(key);
pass.feedback.statistics.push_back(concretelang::Statistic{
pass.circuitFeedback->statistics.push_back(concretelang::Statistic{
location,
operation,
keys,
@@ -172,14 +186,14 @@ struct ExtractTFHEStatisticsPass
auto location = locationString(op.getLoc());
auto operation = PrimitiveOperation::PBS;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto keys = std::vector<std::pair<KeyType, int64_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex());
std::pair<KeyType, int64_t> key =
std::make_pair(KeyType::BOOTSTRAP, (int64_t)bsk.getIndex());
keys.push_back(key);
pass.feedback.statistics.push_back(concretelang::Statistic{
pass.circuitFeedback->statistics.push_back(concretelang::Statistic{
location,
operation,
keys,
@@ -199,14 +213,14 @@ struct ExtractTFHEStatisticsPass
auto location = locationString(op.getLoc());
auto operation = PrimitiveOperation::KEY_SWITCH;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto keys = std::vector<std::pair<KeyType, int64_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex());
std::pair<KeyType, int64_t> key =
std::make_pair(KeyType::KEY_SWITCH, (int64_t)ksk.getIndex());
keys.push_back(key);
pass.feedback.statistics.push_back(concretelang::Statistic{
pass.circuitFeedback->statistics.push_back(concretelang::Statistic{
location,
operation,
keys,
@@ -226,14 +240,14 @@ struct ExtractTFHEStatisticsPass
auto location = locationString(op.getLoc());
auto operation = PrimitiveOperation::CLEAR_MULTIPLICATION;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto keys = std::vector<std::pair<KeyType, int64_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
std::pair<KeyType, int64_t> key =
std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index);
keys.push_back(key);
pass.feedback.statistics.push_back(concretelang::Statistic{
pass.circuitFeedback->statistics.push_back(concretelang::Statistic{
location,
operation,
keys,
@@ -253,14 +267,14 @@ struct ExtractTFHEStatisticsPass
auto location = locationString(op.getLoc());
auto operation = PrimitiveOperation::ENCRYPTED_NEGATION;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto keys = std::vector<std::pair<KeyType, int64_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
std::pair<KeyType, int64_t> key =
std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index);
keys.push_back(key);
pass.feedback.statistics.push_back(concretelang::Statistic{
pass.circuitFeedback->statistics.push_back(concretelang::Statistic{
location,
operation,
keys,
@@ -279,17 +293,17 @@ struct ExtractTFHEStatisticsPass
auto resultingKey = op.getType().getKey().getNormalized();
auto location = locationString(op.getLoc());
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto keys = std::vector<std::pair<KeyType, int64_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::SECRET, (size_t)resultingKey->index);
std::pair<KeyType, int64_t> key =
std::make_pair(KeyType::SECRET, (int64_t)resultingKey->index);
keys.push_back(key);
// clear - encrypted = clear + neg(encrypted)
auto operation = PrimitiveOperation::ENCRYPTED_NEGATION;
pass.feedback.statistics.push_back(concretelang::Statistic{
pass.circuitFeedback->statistics.push_back(concretelang::Statistic{
location,
operation,
keys,
@@ -297,7 +311,7 @@ struct ExtractTFHEStatisticsPass
});
operation = PrimitiveOperation::CLEAR_ADDITION;
pass.feedback.statistics.push_back(concretelang::Statistic{
pass.circuitFeedback->statistics.push_back(concretelang::Statistic{
location,
operation,
keys,
@@ -319,20 +333,20 @@ struct ExtractTFHEStatisticsPass
auto location = locationString(op.getLoc());
auto operation = PrimitiveOperation::WOP_PBS;
auto keys = std::vector<std::pair<KeyType, size_t>>();
auto keys = std::vector<std::pair<KeyType, int64_t>>();
auto count = pass.iterations;
std::pair<KeyType, size_t> key =
std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex());
std::pair<KeyType, int64_t> key =
std::make_pair(KeyType::BOOTSTRAP, (int64_t)bsk.getIndex());
keys.push_back(key);
key = std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex());
key = std::make_pair(KeyType::KEY_SWITCH, (int64_t)ksk.getIndex());
keys.push_back(key);
key = std::make_pair(KeyType::PACKING_KEY_SWITCH, (size_t)pksk.getIndex());
key = std::make_pair(KeyType::PACKING_KEY_SWITCH, (int64_t)pksk.getIndex());
keys.push_back(key);
pass.feedback.statistics.push_back(concretelang::Statistic{
pass.circuitFeedback->statistics.push_back(concretelang::Statistic{
location,
operation,
keys,
@@ -342,13 +356,13 @@ struct ExtractTFHEStatisticsPass
return std::nullopt;
}
size_t iterations = 1;
int64_t iterations = 1;
};
} // namespace TFHE
std::unique_ptr<OperationPass<ModuleOp>>
createStatisticExtractionPass(CompilationFeedback &feedback) {
createStatisticExtractionPass(ProgramCompilationFeedback &feedback) {
return std::make_unique<TFHE::ExtractTFHEStatisticsPass>(feedback);
}

View File

@@ -5,7 +5,9 @@
#include <cassert>
#include <fstream>
#include <vector>
#include "concrete-protocol.capnp.h"
#include "concretelang/Support/CompilationFeedback.h"
using concretelang::protocol::Message;
@@ -13,60 +15,11 @@ using concretelang::protocol::Message;
namespace mlir {
namespace concretelang {
void CompilationFeedback::fillFromProgramInfo(
const Message<concreteprotocol::ProgramInfo> &programInfo) {
auto params = programInfo.asReader();
// Compute the size of secret keys
totalSecretKeysSize = 0;
for (auto skInfo : params.getKeyset().getLweSecretKeys()) {
assert(skInfo.getParams().getIntegerPrecision() % 8 == 0);
auto byteSize = skInfo.getParams().getIntegerPrecision() / 8;
totalSecretKeysSize += skInfo.getParams().getLweDimension() * byteSize;
}
// Compute the boostrap keys size
totalBootstrapKeysSize = 0;
for (auto bskInfo : params.getKeyset().getLweBootstrapKeys()) {
assert(bskInfo.getInputId() <
(uint32_t)params.getKeyset().getLweSecretKeys().size());
auto inputKeyInfo =
params.getKeyset().getLweSecretKeys()[bskInfo.getInputId()];
assert(bskInfo.getOutputId() <
(uint32_t)params.getKeyset().getLweSecretKeys().size());
auto outputKeyInfo =
params.getKeyset().getLweSecretKeys()[bskInfo.getOutputId()];
assert(bskInfo.getParams().getIntegerPrecision() % 8 == 0);
auto byteSize = bskInfo.getParams().getIntegerPrecision() % 8;
auto inputLweSize = inputKeyInfo.getParams().getLweDimension() + 1;
auto outputLweSize = outputKeyInfo.getParams().getLweDimension() + 1;
auto level = bskInfo.getParams().getLevelCount();
auto glweDimension = bskInfo.getParams().getGlweDimension();
totalBootstrapKeysSize += inputLweSize * level * (glweDimension + 1) *
(glweDimension + 1) * outputLweSize * byteSize;
}
// Compute the keyswitch keys size
totalKeyswitchKeysSize = 0;
for (auto kskInfo : params.getKeyset().getLweKeyswitchKeys()) {
assert(kskInfo.getInputId() <
(uint32_t)params.getKeyset().getLweSecretKeys().size());
auto inputKeyInfo =
params.getKeyset().getLweSecretKeys()[kskInfo.getInputId()];
assert(kskInfo.getOutputId() <
(uint32_t)params.getKeyset().getLweSecretKeys().size());
auto outputKeyInfo =
params.getKeyset().getLweSecretKeys()[kskInfo.getOutputId()];
assert(kskInfo.getParams().getIntegerPrecision() % 8 == 0);
auto byteSize = kskInfo.getParams().getIntegerPrecision() % 8;
auto inputLweSize = inputKeyInfo.getParams().getLweDimension() + 1;
auto outputLweSize = outputKeyInfo.getParams().getLweDimension() + 1;
auto level = kskInfo.getParams().getLevelCount();
totalKeyswitchKeysSize += level * inputLweSize * outputLweSize * byteSize;
}
auto circuitInfo = params.getCircuits()[0];
void CircuitCompilationFeedback::fillFromCircuitInfo(
concreteprotocol::CircuitInfo::Reader circuitInfo) {
auto computeGateSize =
[&](const Message<concreteprotocol::GateInfo> &gateInfo) {
unsigned int nElements = 1;
// TODO: CHANGE THAT ITS WRONG
for (auto dimension :
gateInfo.asReader().getRawInfo().getShape().getDimensions()) {
nElements *= dimension;
@@ -104,17 +57,77 @@ void CompilationFeedback::fillFromProgramInfo(
}
}
}
// Sets name
name = circuitInfo.getName().cStr();
}
outcome::checked<CompilationFeedback, StringError>
CompilationFeedback::load(std::string jsonPath) {
void ProgramCompilationFeedback::fillFromProgramInfo(
const Message<concreteprotocol::ProgramInfo> &programInfo) {
auto params = programInfo.asReader();
// Compute the size of secret keys
totalSecretKeysSize = 0;
for (auto skInfo : params.getKeyset().getLweSecretKeys()) {
assert(skInfo.getParams().getIntegerPrecision() % 8 == 0);
auto byteSize = skInfo.getParams().getIntegerPrecision() / 8;
totalSecretKeysSize += skInfo.getParams().getLweDimension() * byteSize;
}
// Compute the boostrap keys size
totalBootstrapKeysSize = 0;
for (auto bskInfo : params.getKeyset().getLweBootstrapKeys()) {
assert(bskInfo.getInputId() <
(uint32_t)params.getKeyset().getLweSecretKeys().size());
auto inputKeyInfo =
params.getKeyset().getLweSecretKeys()[bskInfo.getInputId()];
assert(bskInfo.getOutputId() <
(uint32_t)params.getKeyset().getLweSecretKeys().size());
auto outputKeyInfo =
params.getKeyset().getLweSecretKeys()[bskInfo.getOutputId()];
assert(bskInfo.getParams().getIntegerPrecision() % 8 == 0);
auto byteSize = bskInfo.getParams().getIntegerPrecision() / 8;
auto inputLweSize = inputKeyInfo.getParams().getLweDimension() + 1;
auto outputLweSize = outputKeyInfo.getParams().getLweDimension() + 1;
auto level = bskInfo.getParams().getLevelCount();
auto glweDimension = bskInfo.getParams().getGlweDimension();
totalBootstrapKeysSize += inputLweSize * level * (glweDimension + 1) *
(glweDimension + 1) * outputLweSize * byteSize;
}
// Compute the keyswitch keys size
totalKeyswitchKeysSize = 0;
for (auto kskInfo : params.getKeyset().getLweKeyswitchKeys()) {
assert(kskInfo.getInputId() <
(uint32_t)params.getKeyset().getLweSecretKeys().size());
auto inputKeyInfo =
params.getKeyset().getLweSecretKeys()[kskInfo.getInputId()];
assert(kskInfo.getOutputId() <
(uint32_t)params.getKeyset().getLweSecretKeys().size());
auto outputKeyInfo =
params.getKeyset().getLweSecretKeys()[kskInfo.getOutputId()];
assert(kskInfo.getParams().getIntegerPrecision() % 8 == 0);
auto byteSize = kskInfo.getParams().getIntegerPrecision() / 8;
auto inputLweSize = inputKeyInfo.getParams().getLweDimension() + 1;
auto outputLweSize = outputKeyInfo.getParams().getLweDimension() + 1;
auto level = kskInfo.getParams().getLevelCount();
totalKeyswitchKeysSize += level * inputLweSize * outputLweSize * byteSize;
}
// Compute the circuit feedbacks
for (auto circuitInfo : params.getCircuits()) {
CircuitCompilationFeedback feedback;
feedback.fillFromCircuitInfo(circuitInfo);
circuitFeedbacks.push_back(feedback);
}
}
outcome::checked<ProgramCompilationFeedback, StringError>
ProgramCompilationFeedback::load(std::string jsonPath) {
std::ifstream file(jsonPath);
std::string content((std::istreambuf_iterator<char>(file)),
(std::istreambuf_iterator<char>()));
if (file.fail()) {
return StringError("Cannot read file: ") << jsonPath;
}
auto expectedCompFeedback = llvm::json::parse<CompilationFeedback>(content);
auto expectedCompFeedback =
llvm::json::parse<ProgramCompilationFeedback>(content);
if (auto err = expectedCompFeedback.takeError()) {
return StringError("Cannot open compilation feedback: ")
<< llvm::toString(std::move(err)) << "\n"
@@ -123,196 +136,228 @@ CompilationFeedback::load(std::string jsonPath) {
return expectedCompFeedback.get();
}
llvm::json::Value toJSON(const mlir::concretelang::CompilationFeedback &v) {
llvm::json::Object object{
{"complexity", v.complexity},
{"pError", v.pError},
{"globalPError", v.globalPError},
{"totalSecretKeysSize", v.totalSecretKeysSize},
{"totalBootstrapKeysSize", v.totalBootstrapKeysSize},
{"totalKeyswitchKeysSize", v.totalKeyswitchKeysSize},
{"totalInputsSize", v.totalInputsSize},
{"totalOutputsSize", v.totalOutputsSize},
{"crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs},
};
auto memoryUsageObject = llvm::json::Object();
for (auto key : v.memoryUsagePerLoc) {
memoryUsageObject.insert({key.first, key.second});
llvm::json::Object
memoryUsageToJson(const std::map<std::string, int64_t> &memoryUsagePerLoc) {
auto object = llvm::json::Object();
for (auto key : memoryUsagePerLoc) {
object.insert({key.first, key.second});
}
object.insert({"memoryUsagePerLoc", std::move(memoryUsageObject)});
auto statisticsJson = llvm::json::Array();
for (auto statistic : v.statistics) {
auto statisticJson = llvm::json::Object();
statisticJson.insert({"location", statistic.location});
switch (statistic.operation) {
case PrimitiveOperation::PBS:
statisticJson.insert({"operation", "PBS"});
break;
case PrimitiveOperation::WOP_PBS:
statisticJson.insert({"operation", "WOP_PBS"});
break;
case PrimitiveOperation::KEY_SWITCH:
statisticJson.insert({"operation", "KEY_SWITCH"});
break;
case PrimitiveOperation::CLEAR_ADDITION:
statisticJson.insert({"operation", "CLEAR_ADDITION"});
break;
case PrimitiveOperation::ENCRYPTED_ADDITION:
statisticJson.insert({"operation", "ENCRYPTED_ADDITION"});
break;
case PrimitiveOperation::CLEAR_MULTIPLICATION:
statisticJson.insert({"operation", "CLEAR_MULTIPLICATION"});
break;
case PrimitiveOperation::ENCRYPTED_NEGATION:
statisticJson.insert({"operation", "ENCRYPTED_NEGATION"});
break;
}
auto keysJson = llvm::json::Array();
for (auto &key : statistic.keys) {
KeyType type = key.first;
size_t index = key.second;
auto keyJson = llvm::json::Array();
switch (type) {
case KeyType::SECRET:
keyJson.push_back("SECRET");
break;
case KeyType::BOOTSTRAP:
keyJson.push_back("BOOTSTRAP");
break;
case KeyType::KEY_SWITCH:
keyJson.push_back("KEY_SWITCH");
break;
case KeyType::PACKING_KEY_SWITCH:
keyJson.push_back("PACKING_KEY_SWITCH");
break;
}
keyJson.push_back((int64_t)index);
keysJson.push_back(std::move(keyJson));
}
statisticJson.insert({"keys", std::move(keysJson)});
statisticJson.insert({"count", (int64_t)statistic.count});
statisticsJson.push_back(std::move(statisticJson));
}
object.insert({"statistics", std::move(statisticsJson)});
return object;
}
llvm::json::Object statisticToJson(const Statistic &statistic) {
auto object = llvm::json::Object();
object.insert({"location", statistic.location});
object.insert({"count", statistic.count});
switch (statistic.operation) {
case PrimitiveOperation::PBS:
object.insert({"operation", "PBS"});
break;
case PrimitiveOperation::WOP_PBS:
object.insert({"operation", "WOP_PBS"});
break;
case PrimitiveOperation::KEY_SWITCH:
object.insert({"operation", "KEY_SWITCH"});
break;
case PrimitiveOperation::CLEAR_ADDITION:
object.insert({"operation", "CLEAR_ADDITION"});
break;
case PrimitiveOperation::ENCRYPTED_ADDITION:
object.insert({"operation", "ENCRYPTED_ADDITION"});
break;
case PrimitiveOperation::CLEAR_MULTIPLICATION:
object.insert({"operation", "CLEAR_MULTIPLICATION"});
break;
case PrimitiveOperation::ENCRYPTED_NEGATION:
object.insert({"operation", "ENCRYPTED_NEGATION"});
break;
}
auto keysJson = llvm::json::Array();
for (auto &key : statistic.keys) {
KeyType type = key.first;
size_t index = key.second;
auto keyJson = llvm::json::Array();
switch (type) {
case KeyType::SECRET:
keyJson.push_back("SECRET");
break;
case KeyType::BOOTSTRAP:
keyJson.push_back("BOOTSTRAP");
break;
case KeyType::KEY_SWITCH:
keyJson.push_back("KEY_SWITCH");
break;
case KeyType::PACKING_KEY_SWITCH:
keyJson.push_back("PACKING_KEY_SWITCH");
break;
}
keyJson.push_back((int64_t)index);
keysJson.push_back(std::move(keyJson));
}
object.insert({"keys", std::move(keysJson)});
return object;
}
llvm::json::Array statisticsToJson(const std::vector<Statistic> &statistics) {
auto object = llvm::json::Array();
for (auto statistic : statistics) {
object.push_back(statisticToJson(statistic));
}
return object;
}
llvm::json::Array crtDecompositionToJson(
const std::vector<std::vector<int64_t>> &crtDecompositionsOfOutputs) {
auto object = llvm::json::Array();
for (auto crtDec : crtDecompositionsOfOutputs) {
auto inner = llvm::json::Array();
for (auto val : crtDec) {
inner.push_back(val);
}
object.push_back(std::move(inner));
}
return object;
}
llvm::json::Array circuitFeedbacksToJson(
const std::vector<CircuitCompilationFeedback> &circuitFeedbacks) {
auto object = llvm::json::Array();
for (auto circuit : circuitFeedbacks) {
llvm::json::Object circuitObject{
{"name", circuit.name},
{"totalInputsSize", circuit.totalInputsSize},
{"totalOutputsSize", circuit.totalOutputsSize},
{"crtDecompositionsOfOutputs",
crtDecompositionToJson(circuit.crtDecompositionsOfOutputs)},
{"statistics", statisticsToJson(circuit.statistics)},
{"memoryUsagePerLoc", memoryUsageToJson(circuit.memoryUsagePerLoc)},
};
object.push_back(std::move(circuitObject));
}
return object;
}
llvm::json::Value
toJSON(const mlir::concretelang::ProgramCompilationFeedback &program) {
llvm::json::Object programObject{
{"complexity", program.complexity},
{"pError", program.pError},
{"globalPError", program.globalPError},
{"totalSecretKeysSize", program.totalSecretKeysSize},
{"totalBootstrapKeysSize", program.totalBootstrapKeysSize},
{"totalKeyswitchKeysSize", program.totalKeyswitchKeysSize},
{"circuitFeedbacks", circuitFeedbacksToJson(program.circuitFeedbacks)}};
return programObject;
}
template <typename K, typename V>
bool fromJSON(const llvm::json::Value &j, std::pair<K, V> &v,
llvm::json::Path p) {
if (auto *array = j.getAsArray()) {
if (!fromJSON((*array)[0], v.first, p.index(0)))
return false;
if (!fromJSON((*array)[1], v.second, p.index(1)))
return false;
return true;
}
p.report("expected array");
return false;
}
bool fromJSON(const llvm::json::Value j,
mlir::concretelang::CompilationFeedback &v, llvm::json::Path p) {
mlir::concretelang::PrimitiveOperation &v, llvm::json::Path p) {
if (auto operationString = j.getAsString()) {
if (operationString == "PBS") {
v = PrimitiveOperation::PBS;
return true;
} else if (operationString == "KEY_SWITCH") {
v = PrimitiveOperation::KEY_SWITCH;
return true;
} else if (operationString == "WOP_PBS") {
v = PrimitiveOperation::WOP_PBS;
return true;
} else if (operationString == "CLEAR_ADDITION") {
v = PrimitiveOperation::CLEAR_ADDITION;
return true;
} else if (operationString == "ENCRYPTED_ADDITION") {
v = PrimitiveOperation::ENCRYPTED_ADDITION;
return true;
} else if (operationString == "CLEAR_MULTIPLICATION") {
v = PrimitiveOperation::CLEAR_MULTIPLICATION;
return true;
} else if (operationString == "ENCRYPTED_NEGATION") {
v = PrimitiveOperation::ENCRYPTED_NEGATION;
return true;
} else {
p.report("expected one of "
"(PBS|KEY_SWITCH|WOP_PBS|CLEAR_ADDITION|ENCRYPTED_ADDITION|"
"CLEAR_MULTIPLICATION|ENCRYPTED_NEGATION)");
return false;
}
}
p.report("expected string");
return false;
}
bool fromJSON(const llvm::json::Value j, mlir::concretelang::KeyType &v,
llvm::json::Path p) {
if (auto keyTypeString = j.getAsString()) {
if (keyTypeString == "SECRET") {
v = KeyType::SECRET;
return true;
} else if (keyTypeString == "BOOTSTRAP") {
v = KeyType::BOOTSTRAP;
return true;
} else if (keyTypeString == "KEY_SWITCH") {
v = KeyType::KEY_SWITCH;
return true;
} else if (keyTypeString == "PACKING_KEY_SWITCH") {
v = KeyType::PACKING_KEY_SWITCH;
return true;
} else {
p.report(
"expected one of (SECRET|BOOTSTRAP|KEY_SWITCH|PACKING_KEY_SWITCH)");
return false;
}
}
p.report("expected string");
return false;
}
bool fromJSON(const llvm::json::Value j, mlir::concretelang::Statistic &v,
llvm::json::Path p) {
llvm::json::ObjectMapper O(j, p);
bool is_success =
O && O.map("complexity", v.complexity) && O.map("pError", v.pError) &&
O.map("globalPError", v.globalPError) &&
O.map("totalSecretKeysSize", v.totalSecretKeysSize) &&
O.map("totalBootstrapKeysSize", v.totalBootstrapKeysSize) &&
O.map("totalKeyswitchKeysSize", v.totalKeyswitchKeysSize) &&
O.map("totalInputsSize", v.totalInputsSize) &&
O.map("totalOutputsSize", v.totalOutputsSize) &&
O.map("crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs);
return O && O.map("location", v.location) &&
O.map("operation", v.operation) && O.map("operation", v.operation) &&
O.map("keys", v.keys) && O.map("count", v.count);
}
if (!is_success) {
return false;
}
bool fromJSON(const llvm::json::Value j,
mlir::concretelang::CircuitCompilationFeedback &v,
llvm::json::Path p) {
llvm::json::ObjectMapper O(j, p);
return O && O.map("name", v.name) &&
O.map("totalInputsSize", v.totalInputsSize) &&
O.map("totalOutputsSize", v.totalOutputsSize) &&
O.map("crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs) &&
O.map("statistics", v.statistics) &&
O.map("memoryUsagePerLoc", v.memoryUsagePerLoc);
}
auto object = j.getAsObject();
if (!object) {
return false;
}
bool fromJSON(const llvm::json::Value j,
mlir::concretelang::ProgramCompilationFeedback &v,
llvm::json::Path p) {
llvm::json::ObjectMapper O(j, p);
auto memoryUsageObject = object->getObject("memoryUsagePerLoc");
if (!memoryUsageObject) {
return false;
}
for (auto entry : *memoryUsageObject) {
auto loc = entry.getFirst().str();
auto maybeUsage = entry.getSecond().getAsInteger();
if (!maybeUsage.has_value()) {
return false;
}
v.memoryUsagePerLoc[loc] = *maybeUsage;
}
auto statistics = object->getArray("statistics");
if (!statistics) {
return false;
}
for (auto statisticValue : *statistics) {
auto statistic = statisticValue.getAsObject();
if (!statistic) {
return false;
}
auto location = statistic->getString("location");
auto operationStr = statistic->getString("operation");
auto keysArray = statistic->getArray("keys");
auto count = statistic->getInteger("count");
if (!operationStr || !location || !keysArray || !count) {
return false;
}
PrimitiveOperation operation;
if (operationStr.value() == "PBS") {
operation = PrimitiveOperation::PBS;
} else if (operationStr.value() == "KEY_SWITCH") {
operation = PrimitiveOperation::KEY_SWITCH;
} else if (operationStr.value() == "WOP_PBS") {
operation = PrimitiveOperation::WOP_PBS;
} else if (operationStr.value() == "CLEAR_ADDITION") {
operation = PrimitiveOperation::CLEAR_ADDITION;
} else if (operationStr.value() == "ENCRYPTED_ADDITION") {
operation = PrimitiveOperation::ENCRYPTED_ADDITION;
} else if (operationStr.value() == "CLEAR_MULTIPLICATION") {
operation = PrimitiveOperation::CLEAR_MULTIPLICATION;
} else if (operationStr.value() == "ENCRYPTED_NEGATION") {
operation = PrimitiveOperation::ENCRYPTED_NEGATION;
} else {
return false;
}
auto keys = std::vector<std::pair<KeyType, size_t>>();
for (auto keyValue : *keysArray) {
llvm::json::Array *keyArray = keyValue.getAsArray();
if (!keyArray || keyArray->size() != 2) {
return false;
}
auto typeStr = keyArray->front().getAsString();
auto index = keyArray->back().getAsInteger();
if (!typeStr || !index) {
return false;
}
KeyType type;
if (typeStr.value() == "SECRET") {
type = KeyType::SECRET;
} else if (typeStr.value() == "BOOTSTRAP") {
type = KeyType::BOOTSTRAP;
} else if (typeStr.value() == "KEY_SWITCH") {
type = KeyType::KEY_SWITCH;
} else if (typeStr.value() == "PACKING_KEY_SWITCH") {
type = KeyType::PACKING_KEY_SWITCH;
} else {
return false;
}
keys.push_back(std::make_pair(type, (size_t)*index));
}
v.statistics.push_back(
Statistic{location->str(), operation, keys, (uint64_t)*count});
}
return true;
return O && O.map("complexity", v.complexity) && O.map("pError", v.pError) &&
O.map("globalPError", v.globalPError) &&
O.map("totalSecretKeysSize", v.totalSecretKeysSize) &&
O.map("totalBootstrapKeysSize", v.totalBootstrapKeysSize) &&
O.map("totalKeyswitchKeysSize", v.totalKeyswitchKeysSize) &&
O.map("circuitFeedbacks", v.circuitFeedbacks);
}
} // namespace concretelang

View File

@@ -3,6 +3,7 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang/Support/V0Parameters.h"
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
@@ -166,19 +167,33 @@ CompilerEngine::getConcreteOptimizerDescription(CompilationResult &res) {
if (descriptions->empty()) { // The pass has not been run
return std::nullopt;
}
if (this->compilerOptions.mainFuncName.has_value()) {
auto name = this->compilerOptions.mainFuncName.value();
auto description = descriptions->find(name);
if (description == descriptions->end()) {
std::string names;
return StreamStringError("Function not found, name='")
<< name << "', cannot get optimizer description";
}
return std::move(description->second);
if (descriptions->size() > 1 &&
config.strategy !=
mlir::concretelang::optimizer::V0) { // Multi circuits without V0
return StreamStringError(
"Multi-circuits is only supported for V0 optimization.");
}
if (descriptions->size() != 1) {
llvm::errs() << "Several crypto parameters exists: the function need to be "
"specified, taking the first one";
if (descriptions->size() > 1) {
auto iter = descriptions->begin();
auto desc = std::move(iter->second);
if (!desc.has_value()) {
return StreamStringError("Expected description.");
}
if (!desc.value().dag.has_value()) {
return StreamStringError("Expected dag in description.");
}
iter++;
while (iter != descriptions->end()) {
if (!iter->second.has_value()) {
return StreamStringError("Expected description.");
}
if (!iter->second.value().dag.has_value()) {
return StreamStringError("Expected dag in description.");
}
desc->dag.value()->concat(*iter->second.value().dag.value());
iter++;
}
return std::move(desc);
}
return std::move(descriptions->begin()->second);
}
@@ -199,7 +214,7 @@ llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) {
res.fheContext.emplace(
mlir::concretelang::V0FHEContext{constraint, v0Params});
CompilationFeedback feedback;
ProgramCompilationFeedback feedback;
res.feedback.emplace(feedback);
return llvm::Error::success();
@@ -213,7 +228,7 @@ llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) {
if (!descr.get().has_value()) {
return llvm::Error::success();
}
CompilationFeedback feedback;
ProgramCompilationFeedback feedback;
// Make sure to use the gpu constraint of the optimizer if we use gpu
// backend.
compilerOptions.optimizerConfig.use_gpu_constraints =
@@ -322,9 +337,8 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target,
// on the `FHE` dialect.
if ((this->generateProgramInfo || target == Target::LIBRARY) &&
!options.encodings) {
auto funcName = options.mainFuncName.value_or("main");
auto encodingInfosOrErr =
mlir::concretelang::encodings::getCircuitEncodings(funcName, module);
mlir::concretelang::encodings::getProgramEncoding(module);
if (!encodingInfosOrErr) {
return encodingInfosOrErr.takeError();
}
@@ -363,7 +377,7 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target,
chunkedMode.asBuilder().setWidth(options.chunkWidth);
maybeChunkInfo = chunkedMode;
}
mlir::concretelang::encodings::setCircuitEncodingModes(
mlir::concretelang::encodings::setProgramEncodingModes(
*options.encodings, maybeChunkInfo, res.fheContext);
}
@@ -457,41 +471,29 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target,
// Generate client parameters if requested
if (this->generateProgramInfo) {
if (!options.mainFuncName.has_value()) {
return StreamStringError(
"Generation of client parameters requested, but no function name "
"specified");
}
if (!res.fheContext.has_value()) {
return StreamStringError(
"Cannot generate client parameters, the fhe context is empty for " +
options.mainFuncName.value());
"Cannot generate client parameters, the fhe context is empty");
}
}
// Generate program info if requested
if (this->generateProgramInfo || target == Target::LIBRARY) {
auto funcName = options.mainFuncName.value_or("main");
if (!res.fheContext.has_value()) {
// Some tests involve call a to non encrypted functions
auto programInfo = Message<concreteprotocol::ProgramInfo>();
programInfo.asBuilder().initCircuits(1);
programInfo.asBuilder().getCircuits()[0].setName(std::string(funcName));
programInfo.asBuilder().getCircuits()[0].setName(std::string("main"));
res.programInfo = programInfo;
} else {
auto programInfoOrErr =
mlir::concretelang::createProgramInfoFromTfheDialect(
module, funcName, options.optimizerConfig.security,
module, options.optimizerConfig.security,
options.encodings.value(), options.compressEvaluationKeys);
if (!programInfoOrErr)
return programInfoOrErr.takeError();
res.programInfo = std::move(*programInfoOrErr);
// If more than one circuit, feedback can not be generated for now ..
if (res.programInfo->asReader().getCircuits().size() != 1) {
return StreamStringError(
"Cannot generate feedback for program with more than one circuit.");
}
res.feedback->fillFromProgramInfo(*res.programInfo);
}
}

View File

@@ -16,6 +16,7 @@
#include <memory>
#include <optional>
#include <variant>
#include <vector>
namespace FHE = mlir::concretelang::FHE;
using concretelang::protocol::Message;
@@ -67,20 +68,13 @@ encodingFromType(mlir::Type ty) {
}
llvm::Expected<Message<concreteprotocol::CircuitEncodingInfo>>
getCircuitEncodings(llvm::StringRef functionName, mlir::ModuleOp module) {
// Find the input function
auto rangeOps = module.getOps<mlir::func::FuncOp>();
auto funcOp = llvm::find_if(rangeOps, [&](mlir::func::FuncOp op) {
return op.getName() == functionName;
});
if (funcOp == rangeOps.end()) {
return StreamStringError("Function not found, name='")
<< functionName << "', cannot get circuit encodings";
}
auto funcType = (*funcOp).getFunctionType();
getCircuitEncodings(mlir::func::FuncOp funcOp) {
auto funcType = funcOp.getFunctionType();
// Retrieve input/output encodings
auto circuitEncodings = Message<concreteprotocol::CircuitEncodingInfo>();
circuitEncodings.asBuilder().setName(funcOp.getSymName().str());
auto inputsBuilder =
circuitEncodings.asBuilder().initInputs(funcType.getNumInputs());
for (size_t i = 0; i < funcType.getNumInputs(); i++) {
@@ -105,8 +99,32 @@ getCircuitEncodings(llvm::StringRef functionName, mlir::ModuleOp module) {
return std::move(circuitEncodings);
}
llvm::Expected<Message<concreteprotocol::ProgramEncodingInfo>>
getProgramEncoding(mlir::ModuleOp module) {
auto funcs = module.getOps<mlir::func::FuncOp>();
auto circuitEncodings =
std::vector<Message<concreteprotocol::CircuitEncodingInfo>>();
for (auto func : funcs) {
auto encodingInfosOrErr = getCircuitEncodings(func);
if (!encodingInfosOrErr) {
return encodingInfosOrErr.takeError();
}
circuitEncodings.push_back(*encodingInfosOrErr);
}
auto programEncoding = Message<concreteprotocol::ProgramEncodingInfo>();
auto circuitBuilder =
programEncoding.asBuilder().initCircuits(circuitEncodings.size());
for (size_t i = 0; i < circuitEncodings.size(); i++) {
circuitBuilder.setWithCaveats(i, circuitEncodings[i].asReader());
}
return std::move(programEncoding);
}
void setCircuitEncodingModes(
Message<concreteprotocol::CircuitEncodingInfo> &info,
concreteprotocol::CircuitEncodingInfo::Builder info,
std::optional<
Message<concreteprotocol::IntegerCiphertextEncodingInfo::ChunkedMode>>
maybeChunk,
@@ -164,13 +182,25 @@ void setCircuitEncodingModes(
// Got nothing particular. Setting encoding mode to native.
integerEncodingBuilder.getMode().initNative();
};
for (auto encInfoBuilder : info.asBuilder().getInputs()) {
for (auto encInfoBuilder : info.getInputs()) {
setMode(encInfoBuilder);
}
for (auto encInfoBuilder : info.asBuilder().getOutputs()) {
for (auto encInfoBuilder : info.getOutputs()) {
setMode(encInfoBuilder);
}
}
void setProgramEncodingModes(
Message<concreteprotocol::ProgramEncodingInfo> &info,
std::optional<
Message<concreteprotocol::IntegerCiphertextEncodingInfo::ChunkedMode>>
maybeChunk,
std::optional<V0FHEContext> maybeFheContext) {
for (auto circuitInfo : info.asBuilder().getCircuits()) {
setCircuitEncodingModes(circuitInfo, maybeChunk, maybeFheContext);
}
}
} // namespace encodings
} // namespace concretelang
} // namespace mlir

View File

@@ -5,6 +5,7 @@
#include "llvm/Support/TargetSelect.h"
#include "concretelang/Support/CompilationFeedback.h"
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
@@ -345,7 +346,7 @@ normalizeTFHEKeys(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::LogicalResult
extractTFHEStatistics(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
CompilationFeedback &feedback) {
ProgramCompilationFeedback &feedback) {
mlir::PassManager pm(&context);
pipelinePrinting("TFHEStatistics", pm, context);
@@ -371,7 +372,7 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::LogicalResult
computeMemoryUsage(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
CompilationFeedback &feedback) {
ProgramCompilationFeedback &feedback) {
mlir::PassManager pm(&context);
pipelinePrinting("Computing Memory Usage", pm, context);

View File

@@ -296,32 +296,22 @@ extractKeysetInfo(TFHE::TFHECircuitKeys circuitKeys,
}
llvm::Expected<Message<concreteprotocol::CircuitInfo>>
extractCircuitInfo(mlir::ModuleOp module, llvm::StringRef functionName,
Message<concreteprotocol::CircuitEncodingInfo> &encodings,
extractCircuitInfo(mlir::func::FuncOp funcOp,
concreteprotocol::CircuitEncodingInfo::Reader encodings,
concrete::SecurityCurve curve) {
auto output = Message<concreteprotocol::CircuitInfo>();
// Check that the specified function can be found
auto rangeOps = module.getOps<mlir::func::FuncOp>();
auto funcOp = llvm::find_if(rangeOps, [&](mlir::func::FuncOp op) {
return op.getName() == functionName;
});
if (funcOp == rangeOps.end()) {
return StreamStringError(
"cannot find the function for generate client parameters: ")
<< functionName;
}
// Create input and output circuit gate parameters
auto funcType = (*funcOp).getFunctionType();
auto funcType = funcOp.getFunctionType();
output.asBuilder().setName(functionName.str());
output.asBuilder().setName(encodings.getName().cStr());
output.asBuilder().initInputs(funcType.getNumInputs());
output.asBuilder().initOutputs(funcType.getNumResults());
for (unsigned int i = 0; i < funcType.getNumInputs(); i++) {
auto ty = funcType.getInput(i);
auto encoding = encodings.asReader().getInputs()[i];
auto encoding = encodings.getInputs()[i];
auto maybeGate = generateGate(ty, encoding, curve);
if (!maybeGate) {
return maybeGate.takeError();
@@ -330,7 +320,7 @@ extractCircuitInfo(mlir::ModuleOp module, llvm::StringRef functionName,
}
for (unsigned int i = 0; i < funcType.getNumResults(); i++) {
auto ty = funcType.getResult(i);
auto encoding = encodings.asReader().getOutputs()[i];
auto encoding = encodings.getOutputs()[i];
auto maybeGate = generateGate(ty, encoding, curve);
if (!maybeGate) {
return maybeGate.takeError();
@@ -341,10 +331,43 @@ extractCircuitInfo(mlir::ModuleOp module, llvm::StringRef functionName,
return output;
}
llvm::Expected<Message<concreteprotocol::ProgramInfo>> extractProgramInfo(
mlir::ModuleOp module,
const Message<concreteprotocol::ProgramEncodingInfo> &encodings,
concrete::SecurityCurve curve) {
auto output = Message<concreteprotocol::ProgramInfo>();
auto circuitsCount = encodings.asReader().getCircuits().size();
auto circuitsBuilder = output.asBuilder().initCircuits(circuitsCount);
auto rangeOps = module.getOps<mlir::func::FuncOp>();
for (size_t i = 0; i < circuitsCount; i++) {
auto circuitEncoding = encodings.asReader().getCircuits()[i];
auto functionName = circuitEncoding.getName();
auto funcOp = llvm::find_if(rangeOps, [&](mlir::func::FuncOp op) {
return op.getName() == functionName.cStr();
});
if (funcOp == rangeOps.end()) {
return StreamStringError("cannot find the following function to generate "
"program info: ")
<< functionName.cStr();
}
auto maybeCircuitInfo = extractCircuitInfo(*funcOp, circuitEncoding, curve);
if (!maybeCircuitInfo) {
return maybeCircuitInfo.takeError();
}
circuitsBuilder.setWithCaveats(i, (*maybeCircuitInfo).asReader());
}
return output;
}
llvm::Expected<Message<concreteprotocol::ProgramInfo>>
createProgramInfoFromTfheDialect(
mlir::ModuleOp module, llvm::StringRef functionName, int bitsOfSecurity,
Message<concreteprotocol::CircuitEncodingInfo> &encodings,
mlir::ModuleOp module, int bitsOfSecurity,
const Message<concreteprotocol::ProgramEncodingInfo> &encodings,
bool compressEvaluationKeys) {
// Check that security curves exist
@@ -354,23 +377,20 @@ createProgramInfoFromTfheDialect(
<< bitsOfSecurity << "bits";
}
// Create the output Program Info.
auto output = Message<concreteprotocol::ProgramInfo>();
// We generate the circuit infos from the module.
auto maybeProgramInfo = extractProgramInfo(module, encodings, *curve);
if (!maybeProgramInfo) {
return maybeProgramInfo.takeError();
}
// Extract the output Program Info.
Message<concreteprotocol::ProgramInfo> output = *maybeProgramInfo;
// We extract the keys of the circuit
auto keysetInfo = extractKeysetInfo(TFHE::extractCircuitKeys(module), *curve,
compressEvaluationKeys);
output.asBuilder().setKeyset(keysetInfo.asReader());
// We generate the gates for the inputs aud outputs
auto maybeCircuitInfo =
extractCircuitInfo(module, functionName, encodings, *curve);
if (!maybeCircuitInfo) {
return maybeCircuitInfo.takeError();
}
output.asBuilder().initCircuits(1).setWithCaveats(
0, maybeCircuitInfo->asReader());
return output;
}

View File

@@ -268,7 +268,7 @@ optimizer::Solution convertSolution(optimizer::CircuitSolution sol) {
/// Fill the compilation `feedback` from a `solution` returned by the optmizer.
template <typename Solution>
void fillFeedback(Solution solution, CompilationFeedback &feedback) {
void fillFeedback(Solution solution, ProgramCompilationFeedback &feedback) {
feedback.complexity = solution.complexity;
feedback.pError = solution.p_error;
feedback.globalPError =
@@ -314,7 +314,7 @@ llvm::Error checkPErrorSolution(Solution solution, optimizer::Config config) {
/// optimizer::Solution, and fill the `feedback`.
template <typename Solution>
llvm::Expected<optimizer::Solution>
toCompilerSolution(Solution solution, CompilationFeedback &feedback,
toCompilerSolution(Solution solution, ProgramCompilationFeedback &feedback,
optimizer::Config config) {
// display(descr, config, sol, naive_user, duration);
if (auto err = checkPErrorSolution(solution, config); err) {
@@ -334,9 +334,9 @@ optimizer::Solution emptySolution() {
return solution;
}
llvm::Expected<optimizer::Solution> getSolution(optimizer::Description &descr,
CompilationFeedback &feedback,
optimizer::Config config) {
llvm::Expected<optimizer::Solution>
getSolution(optimizer::Description &descr, ProgramCompilationFeedback &feedback,
optimizer::Config config) {
namespace chrono = std::chrono;
// auto start = chrono::high_resolution_clock::now();
auto naive_user =

View File

@@ -223,11 +223,6 @@ llvm::cl::opt<bool> dataflowParallelize(
llvm::cl::desc("Generate the program as a dataflow graph"),
llvm::cl::init(false));
llvm::cl::opt<std::string>
funcName("funcname",
llvm::cl::desc("Name of the function to compile, default 'main'"),
llvm::cl::init<std::string>(""));
llvm::cl::opt<bool>
chunkIntegers("chunk-integers",
llvm::cl::desc("Whether to decompose integer into chunks or "
@@ -379,12 +374,17 @@ llvm::cl::list<int64_t> largeIntegerCircuitBootstrap(
"(experimental) [level, baseLog]"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
llvm::cl::opt<std::string> circuitEncodings(
"circuit-encodings",
llvm::cl::desc("Specify the input and output encodings of the circuit, "
llvm::cl::opt<std::string> programEncoding(
"program-encoding",
llvm::cl::desc("Specify the encodings to use for the program, "
"using the JSON representation."),
llvm::cl::init(std::string{}));
llvm::cl::opt<bool> skipProgramInfo(
"skip-program-info",
llvm::cl::desc("Skip generating the program info artefacts."),
llvm::cl::init(false));
} // namespace cmdline
namespace llvm {
@@ -424,6 +424,7 @@ cmdlineCompilationOptions() {
options.chunkIntegers = cmdline::chunkIntegers;
options.chunkSize = cmdline::chunkSize;
options.chunkWidth = cmdline::chunkWidth;
options.skipProgramInfo = cmdline::skipProgramInfo;
if (!cmdline::v0Constraint.empty()) {
if (cmdline::v0Constraint.size() != 2) {
@@ -435,10 +436,6 @@ cmdlineCompilationOptions() {
cmdline::v0Constraint[1], cmdline::v0Constraint[0]};
}
if (!cmdline::funcName.empty()) {
options.mainFuncName = cmdline::funcName;
}
// Convert tile sizes to `Optional`
if (!cmdline::fhelinalgTileSizes.empty())
options.fhelinalgTileSizes.emplace(cmdline::fhelinalgTileSizes);
@@ -512,12 +509,12 @@ cmdlineCompilationOptions() {
llvm::inconvertibleErrorCode());
}
if (!cmdline::circuitEncodings.empty()) {
auto jsonString = cmdline::circuitEncodings.getValue();
auto encodings = Message<concreteprotocol::CircuitEncodingInfo>();
if (!cmdline::programEncoding.empty()) {
auto jsonString = cmdline::programEncoding.getValue();
auto encodings = Message<concreteprotocol::ProgramEncodingInfo>();
if (encodings.readJsonFromString(jsonString).has_failure()) {
return llvm::make_error<llvm::StringError>(
"Failed to parse the --circuit-encodings option",
"Failed to parse the --program-encoding option",
llvm::inconvertibleErrorCode());
}
options.encodings = encodings;
@@ -556,10 +553,9 @@ mlir::LogicalResult processInputBuffer(
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
mlir::concretelang::CompilationContext::createShared();
std::string funcName = options.mainFuncName.value_or("");
mlir::concretelang::CompilerEngine ce{ccx};
ce.setCompilationOptions(std::move(options));
ce.setGenerateProgramInfo(!options.skipProgramInfo);
if (cmdline::passes.size() != 0) {
ce.setEnablePass([](mlir::Pass *pass) {

View File

@@ -1,15 +1,15 @@
// RUN: concretecompiler --action=dump-llvm-ir %s
// RUN: concretecompiler --action=dump-llvm-ir --optimizer-strategy=V0 --skip-program-info %s
// Just ensure that compile
// https://github.com/zama-ai/concrete-compiler-internal/issues/785
func.func @main(%arg0: !FHE.eint<15>, %cst: tensor<32768xi64>) -> tensor<1x!FHE.eint<15>> {
%1 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<15>, tensor<32768xi64>) -> !FHE.eint<15>
%6 = tensor.from_elements %1 : tensor<1x!FHE.eint<15>> // ERROR HERE line 4
return %6 : tensor<1x!FHE.eint<15>>
func.func @main(%arg0: !FHE.eint<5>, %cst: tensor<32xi64>) -> tensor<1x!FHE.eint<5>> {
%1 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<5>, tensor<32xi64>) -> !FHE.eint<5>
%6 = tensor.from_elements %1 : tensor<1x!FHE.eint<5>> // ERROR HERE line 4
return %6 : tensor<1x!FHE.eint<5>>
}
// Ensures that tensors of multiple elements can be constructed as well.
func.func @main2(%arg0: !FHE.eint<15>, %cst: tensor<32768xi64>) -> tensor<2x!FHE.eint<15>> {
%1 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<15>, tensor<32768xi64>) -> !FHE.eint<15>
%6 = tensor.from_elements %1, %arg0 : tensor<2x!FHE.eint<15>> // ERROR HERE line 4
return %6 : tensor<2x!FHE.eint<15>>
func.func @main2(%arg0: !FHE.eint<5>, %cst: tensor<32xi64>) -> tensor<2x!FHE.eint<5>> {
%1 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<5>, tensor<32xi64>) -> !FHE.eint<5>
%6 = tensor.from_elements %1, %arg0 : tensor<2x!FHE.eint<5>> // ERROR HERE line 4
return %6 : tensor<2x!FHE.eint<5>>
}

View File

@@ -1,4 +1,4 @@
// RUN: concretecompiler --action=dump-tfhe --force-encoding crt %s
// RUN: concretecompiler --action=dump-tfhe --optimizer-strategy=V0 --force-encoding crt --skip-program-info %s
func.func @main(%arg0: tensor<32x!FHE.eint<8>>) -> tensor<16x!FHE.eint<8>>{
%0 = tensor.extract_slice %arg0[16] [16] [1] : tensor<32x!FHE.eint<8>> to tensor<16x!FHE.eint<8>>
return %0 : tensor<16x!FHE.eint<8>>

View File

@@ -1,12 +0,0 @@
// RUN: concretecompiler --action=dump-tfhe --force-encoding crt %s
func.func @main(%2: tensor<1x1x!FHE.eint<16>>) -> tensor<1x1x1x!FHE.eint<16>> {
%3 = tensor.expand_shape %2 [[0], [1, 2]] : tensor<1x1x!FHE.eint<16>> into tensor<1x1x1x!FHE.eint<16>>
return %3 : tensor<1x1x1x!FHE.eint<16>>
}
func.func @main2(%2: tensor<1x1x1x!FHE.eint<16>>) -> tensor<1x1x!FHE.eint<16>> {
%3 = tensor.collapse_shape %2 [[0], [1, 2]] : tensor<1x1x1x!FHE.eint<16>> into tensor<1x1x!FHE.eint<16>>
return %3 : tensor<1x1x!FHE.eint<16>>
}

View File

@@ -1,4 +1,4 @@
// RUN: concretecompiler --action=dump-llvm-dialect --emit-gpu-ops %s 2>&1| FileCheck %s
// RUN: concretecompiler --action=dump-llvm-dialect --emit-gpu-ops --skip-program-info %s 2>&1| FileCheck %s
//CHECK: llvm.call @memref_keyswitch_lwe_cuda_u64
//CHECK: llvm.call @memref_bootstrap_lwe_cuda_u64

View File

@@ -1,4 +1,4 @@
// RUN: concretecompiler --action=dump-parametrized-tfhe --optimizer-strategy=V0 --v0-parameter=2,10,750,1,23,3,4 --v0-constraint=4,0 %s 2>&1| FileCheck %s
// RUN: concretecompiler --action=dump-parametrized-tfhe --optimizer-strategy=V0 --v0-parameter=2,10,750,1,23,3,4 --v0-constraint=4,0 --skip-program-info %s 2>&1| FileCheck %s
//CHECK: func.func @main(%[[A0:.*]]: !TFHE.glwe<sk<0,1,2048>>) -> !TFHE.glwe<sk<0,1,2048>> {
//CHECK-NEXT: %cst = arith.constant dense<0> : tensor<1024xi64>

View File

@@ -1,4 +1,4 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
// CHECK-LABEL: func.func @add_glwe(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>) -> tensor<2049xi64>
func.func @add_glwe(%arg0: !TFHE.glwe<sk[1]<1,2048>>, %arg1: !TFHE.glwe<sk[1]<1,2048>>) -> !TFHE.glwe<sk[1]<1,2048>> {

View File

@@ -1,4 +1,4 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
//CHECK: func.func @add_glwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> {
//CHECK: %c1_i64 = arith.constant 1 : i64

View File

@@ -1,4 +1,4 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
// CHECK: func.func @apply_lookup_table(%arg0: tensor<4xi64>) -> tensor<1024xi64> {
// CHECK-NEXT: %0 = "Concrete.encode_expand_lut_for_bootstrap_tensor"(%arg0) {isSigned = true, outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64>

View File

@@ -1,4 +1,4 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
// CHECK: func.func @main(%arg0: tensor<4xi64>) -> tensor<5x8192xi64> {
// CHECK-NEXT: %0 = "Concrete.encode_lut_for_crt_woppbs_tensor"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32} : (tensor<4xi64>) -> tensor<5x8192xi64>

View File

@@ -1,4 +1,4 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
// CHECK: func.func @main(%arg0: i64) -> tensor<5xi64> {
// CHECK-NEXT: %0 = "Concrete.encode_plaintext_with_crt_tensor"(%arg0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>

View File

@@ -1,4 +1,4 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
// CHECK: func.func @keyswitch_glwe(%[[A0:.*]]: tensor<1025xi64>) -> tensor<568xi64> {
// CHECK-NEXT: %[[V0:.*]] = "Concrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 3 : i32, kskIndex = -1 : i32, level = 2 : i32, lwe_dim_in = 1024 : i32, lwe_dim_out = 567 : i32} : (tensor<1025xi64>) -> tensor<568xi64>

View File

@@ -1,4 +1,4 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
//CHECK: func.func @mul_glwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> {
//CHECK: %c1_i64 = arith.constant 1 : i64

View File

@@ -1,4 +1,4 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
// CHECK-LABEL: func.func @neg_glwe(%arg0: tensor<1025xi64>) -> tensor<1025xi64>
func.func @neg_glwe(%arg0: !TFHE.glwe<sk[1]<1,1024>>) -> !TFHE.glwe<sk[1]<1,1024>> {

View File

@@ -1,4 +1,4 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
//CHECK: func.func @sub_const_int_glwe(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> {
//CHECK: %c1_i64 = arith.constant 1 : i64

View File

@@ -1,4 +1,4 @@
// RUN: concretecompiler --split-input-file --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
// RUN: concretecompiler --split-input-file --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
//CHECK: func.func @tensor_collapse_shape(%[[A0:.*]]: tensor<2x3x4x5x6x1025xi64>) -> tensor<720x1025xi64> {
//CHECK: %[[V0:.*]] = tensor.collapse_shape %[[A0]] [[_:\[\[0, 1, 2, 3, 4\], \[5\]\]]] : tensor<2x3x4x5x6x1025xi64> into tensor<720x1025xi64>

View File

@@ -1,4 +1,4 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --split-input-file %s 2>&1| FileCheck %s
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --split-input-file --skip-program-info %s 2>&1| FileCheck %s
// CHECK: func.func @main(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>, %[[A2:.*]]: tensor<2049xi64>, %[[A3:.*]]: tensor<2049xi64>, %[[A4:.*]]: tensor<2049xi64>, %[[A5:.*]]: tensor<2049xi64>) -> tensor<6x2049xi64> {
// CHECK: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<6x2049xi64>

View File

@@ -1,4 +1,4 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete --skip-program-info %s 2>&1| FileCheck %s
// CHECK: func.func @tensor_identity(%arg0: tensor<2x3x4x1025xi64>) -> tensor<2x3x4x1025xi64> {
// CHECK-NEXT: return %arg0 : tensor<2x3x4x1025xi64>

View File

@@ -1,4 +1,4 @@
// RUN: concretecompiler --action=dump-fhe %s 2>&1| FileCheck %s
// RUN: concretecompiler --action=dump-fhe --optimizer-strategy=V0 --skip-program-info %s 2>&1| FileCheck %s
// CHECK-LABEL: func.func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2>
func.func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {

View File

@@ -1,4 +1,4 @@
// RUN: concretecompiler --action=dump-fhe %s 2>&1| FileCheck %s
// RUN: concretecompiler --action=dump-fhe --optimizer-strategy=V0 --skip-program-info %s 2>&1| FileCheck %s
// CHECK: func.func @add_eint_int_1D(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>>

View File

@@ -1,4 +1,4 @@
// RUN: concretecompiler --passes tfhe-optimization --optimize-tfhe=false --action=dump-tfhe %s 2>&1| FileCheck %s
// RUN: concretecompiler --passes tfhe-optimization --optimize-tfhe=false --action=dump-tfhe --skip-program-info %s 2>&1| FileCheck %s
//CHECK: func.func @mul_cleartext_glwe_ciphertext_0(%[[A0:.*]]: !TFHE.glwe<sk[1]<527,1>>) -> !TFHE.glwe<sk[1]<527,1>> {
//CHECK: %c0_i64 = arith.constant 0 : i64

View File

@@ -1,4 +1,4 @@
// RUN: concretecompiler --passes tfhe-optimization --action=dump-tfhe %s 2>&1| FileCheck %s
// RUN: concretecompiler --passes tfhe-optimization --action=dump-tfhe --skip-program-info %s 2>&1| FileCheck %s
// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext(%arg0: !TFHE.glwe<sk[1]<527,1>>, %arg1: i64) -> !TFHE.glwe<sk[1]<527,1>>

View File

@@ -1,4 +1,4 @@
// RUN: concretecompiler --split-input-file --action=dump-batched-tfhe --batch-tfhe-ops %s 2>&1| FileCheck %s
// RUN: concretecompiler --split-input-file --action=dump-batched-tfhe --batch-tfhe-ops --skip-program-info %s 2>&1| FileCheck %s
// CHECK-LABEL: func.func @batch_continuous_slice_keyswitch
// CHECK: (%arg0: tensor<2x3x4x!TFHE.glwe<sk{{\[}}[[SK_IN:.*]]{{\]}}<1,2048>>>) -> tensor<2x3x4x!TFHE.glwe<sk{{\[}}[[SK_OUT:.*]]{{\]}}<1,750>>> {

View File

@@ -1,4 +1,4 @@
// RUN: concretecompiler --split-input-file --action=dump-parametrized-tfhe --optimizer-strategy=dag-multi %s 2>&1| FileCheck %s
// RUN: concretecompiler --split-input-file --action=dump-parametrized-tfhe --optimizer-strategy=dag-multi --skip-program-info %s 2>&1| FileCheck %s
// CHECK: func.func @funconly_fwd(%arg0: !TFHE.glwe<sk[1]<12,1024>>) -> !TFHE.glwe<sk[1]<12,1024>> {
// CHECK-NEXT: return %arg0 : !TFHE.glwe<sk[1]<12,1024>>

View File

@@ -1,6 +1,6 @@
#include "../end_to_end_tests/end_to_end_test.h"
#include "concretelang/Common/Compat.h"
#include "concretelang/TestLib/TestCircuit.h"
#include "concretelang/TestLib/TestProgram.h"
#include <concretelang/Runtime/DFRuntime.hpp>
#include <benchmark/benchmark.h>
@@ -23,7 +23,7 @@ using namespace concretelang::testlib;
/// Benchmark time of the compilation
static void BM_Compile(benchmark::State &state, EndToEndDesc description,
mlir::concretelang::CompilationOptions options) {
TestCircuit tc(options);
TestProgram tc(options);
for (auto _ : state) {
assert(tc.compile(description.program));
}
@@ -32,7 +32,7 @@ static void BM_Compile(benchmark::State &state, EndToEndDesc description,
/// Benchmark time of the key generation
static void BM_KeyGen(benchmark::State &state, EndToEndDesc description,
mlir::concretelang::CompilationOptions options) {
TestCircuit tc(options);
TestProgram tc(options);
assert(tc.compile(description.program));
for (auto _ : state) {
@@ -44,7 +44,7 @@ static void BM_KeyGen(benchmark::State &state, EndToEndDesc description,
static void BM_ExportArguments(benchmark::State &state,
EndToEndDesc description,
mlir::concretelang::CompilationOptions options) {
TestCircuit tc(options);
TestProgram tc(options);
assert(tc.compile(description.program));
assert(tc.generateKeyset());
@@ -68,7 +68,7 @@ static void BM_ExportArguments(benchmark::State &state,
/// Benchmark time of the program evaluation
static void BM_Evaluate(benchmark::State &state, EndToEndDesc description,
mlir::concretelang::CompilationOptions options) {
TestCircuit tc(options);
TestProgram tc(options);
assert(tc.compile(description.program));
assert(tc.generateKeyset());
auto clientCircuit = tc.getClientCircuit().value();
@@ -111,7 +111,6 @@ void registerEndToEndBenchmark(std::string suiteName,
int num_iterations = 0) {
auto optionsName = getOptionsName(options);
for (auto description : descriptions) {
options.mainFuncName = "main";
if (description.p_error) {
assert(std::isnan(options.optimizerConfig.global_p_error));
options.optimizerConfig.p_error = description.p_error.value();

View File

@@ -1,5 +1,5 @@
#include "concretelang/Common/Compat.h"
#include "concretelang/TestLib/TestCircuit.h"
#include "concretelang/TestLib/TestProgram.h"
#include "end_to_end_fixture/EndToEndFixture.h"
#include <concretelang/Runtime/DFRuntime.hpp>
#define BENCHMARK_HAS_CXX11

View File

@@ -4,7 +4,7 @@
#include <tuple>
#include <type_traits>
#include "concretelang/TestLib/TestCircuit.h"
#include "concretelang/TestLib/TestProgram.h"
#include "end_to_end_jit_test.h"
#include "tests_tools/GtestEnvironment.h"
std::vector<uint64_t> distributed_results;

View File

@@ -5,7 +5,7 @@
#include <tuple>
#include <type_traits>
#include "concretelang/TestLib/TestCircuit.h"
#include "concretelang/TestLib/TestProgram.h"
#include "end_to_end_jit_test.h"
#include "tests_tools/GtestEnvironment.h"
///////////////////////////////////////////////////////////////////////////////

View File

@@ -1,6 +1,6 @@
#include <gtest/gtest.h>
#include "concretelang/TestLib/TestCircuit.h"
#include "concretelang/TestLib/TestProgram.h"
#include "end_to_end_jit_test.h"
#include "tests_tools/GtestEnvironment.h"

View File

@@ -4,7 +4,7 @@
#include <tuple>
#include <type_traits>
#include "concretelang/TestLib/TestCircuit.h"
#include "concretelang/TestLib/TestProgram.h"
#include "end_to_end_jit_test.h"
#include "tests_tools/GtestEnvironment.h"

View File

@@ -1,6 +1,6 @@
#include <gtest/gtest.h>
#include "concretelang/TestLib/TestCircuit.h"
#include "concretelang/TestLib/TestProgram.h"
#include "end_to_end_jit_test.h"
#include "tests_tools/GtestEnvironment.h"
#include "tests_tools/assert.h"

View File

@@ -4,7 +4,7 @@
#include <gtest/gtest.h>
#include <type_traits>
#include "concretelang/TestLib/TestCircuit.h"
#include "concretelang/TestLib/TestProgram.h"
#include "end_to_end_jit_test.h"
#include "tests_tools/GtestEnvironment.h"
@@ -395,10 +395,10 @@ func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
}
TEST(CompileNotComposable, not_composable_1) {
mlir::concretelang::CompilationOptions options("main");
mlir::concretelang::CompilationOptions options;
options.optimizerConfig.composable = true;
options.optimizerConfig.strategy = mlir::concretelang::optimizer::DAG_MULTI;
TestCircuit circuit(options);
TestProgram circuit(options);
auto err = circuit.compile(R"XXX(
func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
%cst_1 = arith.constant 1 : i4
@@ -411,11 +411,11 @@ func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
}
TEST(CompileNotComposable, not_composable_2) {
mlir::concretelang::CompilationOptions options("main");
mlir::concretelang::CompilationOptions options;
options.optimizerConfig.composable = true;
options.optimizerConfig.display = true;
options.optimizerConfig.strategy = mlir::concretelang::optimizer::DAG_MULTI;
TestCircuit circuit(options);
TestProgram circuit(options);
auto err = circuit.compile(R"XXX(
func.func @main(%arg0: !FHE.eint<3>) -> (!FHE.eint<3>, !FHE.eint<3>) {
%cst_1 = arith.constant 1 : i4
@@ -430,11 +430,11 @@ func.func @main(%arg0: !FHE.eint<3>) -> (!FHE.eint<3>, !FHE.eint<3>) {
}
TEST(CompileComposable, composable_supported_dag_mono) {
mlir::concretelang::CompilationOptions options("main");
mlir::concretelang::CompilationOptions options;
options.optimizerConfig.composable = true;
options.optimizerConfig.display = true;
options.optimizerConfig.strategy = mlir::concretelang::optimizer::DAG_MONO;
TestCircuit circuit(options);
TestProgram circuit(options);
auto err = circuit.compile(R"XXX(
func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
%cst_1 = arith.constant 1 : i4
@@ -446,11 +446,11 @@ func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
}
TEST(CompileComposable, composable_supported_v0) {
mlir::concretelang::CompilationOptions options("main");
mlir::concretelang::CompilationOptions options;
options.optimizerConfig.composable = true;
options.optimizerConfig.display = true;
options.optimizerConfig.strategy = mlir::concretelang::optimizer::V0;
TestCircuit circuit(options);
TestProgram circuit(options);
auto err = circuit.compile(R"XXX(
func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
%cst_1 = arith.constant 1 : i4
@@ -460,3 +460,39 @@ func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
)XXX");
assert(err.has_value());
}
TEST(CompileMultiFunctions, multi_functions_v0) {
mlir::concretelang::CompilationOptions options;
options.optimizerConfig.strategy = mlir::concretelang::optimizer::V0;
TestProgram circuit(options);
auto err = circuit.compile(R"XXX(
func.func @inc(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
%cst_1 = arith.constant 1 : i4
%1 = "FHE.add_eint_int"(%arg0, %cst_1) : (!FHE.eint<3>, i4) -> !FHE.eint<3>
return %1: !FHE.eint<3>
}
func.func @dec(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
%cst_1 = arith.constant 1 : i4
%1 = "FHE.sub_eint_int"(%arg0, %cst_1) : (!FHE.eint<3>, i4) -> !FHE.eint<3>
return %1: !FHE.eint<3>
}
)XXX");
assert(err.has_value());
assert(circuit.generateKeyset().has_value());
auto lambda_inc = [&](std::vector<concretelang::values::Value> args) {
return circuit.call(args, "inc")
.value()[0]
.template getTensor<uint64_t>()
.value()[0];
};
auto lambda_dec = [&](std::vector<concretelang::values::Value> args) {
return circuit.call(args, "dec")
.value()[0]
.template getTensor<uint64_t>()
.value()[0];
};
ASSERT_EQ(lambda_inc({Tensor<uint64_t>(1)}), (uint64_t)2);
ASSERT_EQ(lambda_inc({Tensor<uint64_t>(4)}), (uint64_t)5);
ASSERT_EQ(lambda_dec({Tensor<uint64_t>(1)}), (uint64_t)0);
ASSERT_EQ(lambda_dec({Tensor<uint64_t>(4)}), (uint64_t)3);
}

View File

@@ -5,7 +5,7 @@
#include "concretelang/Common/Error.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/V0Parameters.h"
#include "concretelang/TestLib/TestCircuit.h"
#include "concretelang/TestLib/TestProgram.h"
#include "cstdlib"
#include "end_to_end_test.h"
#include "globals.h"
@@ -14,7 +14,7 @@
using concretelang::error::Result;
using concretelang::error::StringError;
using concretelang::testlib::TestCircuit;
using concretelang::testlib::TestProgram;
llvm::StringRef DEFAULT_func = "main";
bool DEFAULT_useDefaultFHEConstraints = false;
@@ -30,7 +30,7 @@ bool DEFAULT_composable = false;
// Jit-compiles the function specified by `func` from `src` and
// returns the corresponding lambda. Any compilation errors are caught
// and reult in abnormal termination.
inline Result<TestCircuit> internalCheckedJit(
inline Result<TestProgram> internalCheckedJit(
llvm::StringRef src, llvm::StringRef func = DEFAULT_func,
bool useDefaultFHEConstraints = DEFAULT_useDefaultFHEConstraints,
bool dataflowParallelize = DEFAULT_dataflowParallelize,
@@ -42,8 +42,7 @@ inline Result<TestCircuit> internalCheckedJit(
unsigned int chunkWidth = DEFAULT_chunkWidth,
bool composable = DEFAULT_composable) {
auto options =
mlir::concretelang::CompilationOptions(std::string(func.data()));
auto options = mlir::concretelang::CompilationOptions();
options.optimizerConfig.global_p_error = global_p_error;
options.chunkIntegers = chunkedIntegers;
options.chunkSize = chunkSize;
@@ -70,10 +69,10 @@ inline Result<TestCircuit> internalCheckedJit(
}
std::vector<std::string> sources = {src.str()};
TestCircuit testCircuit(options);
OUTCOME_TRYV(testCircuit.compile({src.str()}));
OUTCOME_TRYV(testCircuit.generateKeyset());
return std::move(testCircuit);
TestProgram testProgram(options);
OUTCOME_TRYV(testProgram.compile({src.str()}));
OUTCOME_TRYV(testProgram.generateKeyset());
return std::move(testProgram);
}
// Wrapper around `internalCheckedJit` that causes

View File

@@ -7,14 +7,14 @@
#include "concretelang/Common/Values.h"
#include "concretelang/Support/CompilationFeedback.h"
#include "concretelang/TestLib/TestCircuit.h"
#include "concretelang/TestLib/TestProgram.h"
#include "end_to_end_fixture/EndToEndFixture.h"
#include "end_to_end_jit_test.h"
#include "tests_tools/GtestEnvironment.h"
#include "tests_tools/assert.h"
#include "tests_tools/keySetCache.h"
using concretelang::testlib::TestCircuit;
using concretelang::testlib::TestProgram;
using concretelang::values::Value;
/// @brief EndToEndTest is a template that allows testing for one program for a
@@ -35,7 +35,7 @@ public:
};
void SetUp() override {
TestCircuit tc(options.compilationOptions);
TestProgram tc(options.compilationOptions);
ASSERT_OUTCOME_HAS_VALUE(tc.compile({program}));
ASSERT_OUTCOME_HAS_VALUE(tc.generateKeyset());
testCircuit.emplace(std::move(tc));
@@ -122,7 +122,7 @@ private:
TestDescription desc;
std::optional<TestErrorRate> errorRate;
std::optional<mlir::concretelang::CompilerEngine::Library> library;
std::optional<TestCircuit> testCircuit;
std::optional<TestProgram> testCircuit;
EndToEndTestOptions options;
std::vector<Value> args;
};

View File

@@ -135,8 +135,7 @@ parseEndToEndCommandLine(int argc, char **argv) {
llvm::cl::ParseCommandLineOptions(argc, argv);
// Build compilation options
mlir::concretelang::CompilationOptions compilationOptions("main",
backend.getValue());
mlir::concretelang::CompilationOptions compilationOptions(backend.getValue());
if (loopParallelize.getValue() != -1)
compilationOptions.loopParallelize = loopParallelize.getValue();

View File

@@ -7,7 +7,8 @@ from concrete.compiler import (
LibrarySupport,
ClientSupport,
CompilationOptions,
CompilationFeedback,
ProgramCompilationFeedback,
CircuitCompilationFeedback,
)
@@ -23,31 +24,40 @@ def assert_result(result, expected_result):
assert np.all(result == expected_result)
def run(engine, args, compilation_result, keyset_cache):
def run(engine, args, compilation_result, keyset_cache, circuit_name="main"):
"""Execute engine on the given arguments.
Perform required loading, encryption, execution, and decryption."""
# Dev
compilation_feedback = engine.load_compilation_feedback(compilation_result)
assert isinstance(compilation_feedback, CompilationFeedback)
assert isinstance(compilation_feedback, ProgramCompilationFeedback)
assert isinstance(compilation_feedback.complexity, float)
assert isinstance(compilation_feedback.p_error, float)
assert isinstance(compilation_feedback.global_p_error, float)
assert isinstance(compilation_feedback.total_secret_keys_size, int)
assert isinstance(compilation_feedback.total_bootstrap_keys_size, int)
assert isinstance(compilation_feedback.total_inputs_size, int)
assert isinstance(compilation_feedback.total_output_size, int)
assert isinstance(compilation_feedback.circuit_feedbacks, list)
circuit_feedback = next(
filter(lambda x: x.name == circuit_name, compilation_feedback.circuit_feedbacks)
)
assert isinstance(circuit_feedback, CircuitCompilationFeedback)
assert isinstance(circuit_feedback.total_inputs_size, int)
assert isinstance(circuit_feedback.total_output_size, int)
# Client
client_parameters = engine.load_client_parameters(compilation_result)
key_set = ClientSupport.key_set(client_parameters, keyset_cache)
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
public_arguments = ClientSupport.encrypt_arguments(
client_parameters, key_set, args, circuit_name
)
# Server
server_lambda = engine.load_server_lambda(compilation_result, False)
server_lambda = engine.load_server_lambda(compilation_result, False, circuit_name)
evaluation_keys = key_set.get_evaluation_keys()
public_result = engine.server_call(server_lambda, public_arguments, evaluation_keys)
# Client
result = ClientSupport.decrypt_result(client_parameters, key_set, public_result)
result = ClientSupport.decrypt_result(
client_parameters, key_set, public_result, circuit_name
)
return result
@@ -57,11 +67,12 @@ def compile_run_assert(
args,
expected_result,
keyset_cache,
options=CompilationOptions.new("main"),
options=CompilationOptions.new(),
circuit_name="main",
):
"""Compile run and assert result."""
compilation_result = engine.compile(mlir_input, options)
result = run(engine, args, compilation_result, keyset_cache)
result = run(engine, args, compilation_result, keyset_cache, circuit_name)
assert_result(result, expected_result)
@@ -231,6 +242,33 @@ def test_lib_compilation_artifacts():
assert not os.path.exists(engine.get_shared_lib_path())
def test_multi_circuits(keyset_cache):
from mlir._mlir_libs._concretelang._compiler import OptimizerStrategy
mlir_str = """
func.func @add(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}
func.func @sub(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
%1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}
"""
args = (10, 3)
expected_add_result = 13
expected_sub_result = 7
engine = LibrarySupport.new("./py_test_multi_circuits")
options = CompilationOptions.new()
options.set_optimizer_strategy(OptimizerStrategy.V0)
compile_run_assert(
engine, mlir_str, args, expected_add_result, keyset_cache, options, "add"
)
compile_run_assert(
engine, mlir_str, args, expected_sub_result, keyset_cache, options, "sub"
)
def _test_lib_compile_and_run_with_options(keyset_cache, options):
mlir_input = """
func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
@@ -247,21 +285,21 @@ def _test_lib_compile_and_run_with_options(keyset_cache, options):
def test_lib_compile_and_run_p_error(keyset_cache):
options = CompilationOptions.new("main")
options = CompilationOptions.new()
options.set_p_error(0.00001)
options.set_display_optimizer_choice(True)
_test_lib_compile_and_run_with_options(keyset_cache, options)
def test_lib_compile_and_run_global_p_error(keyset_cache):
options = CompilationOptions.new("main")
options = CompilationOptions.new()
options.set_global_p_error(0.00001)
options.set_display_optimizer_choice(True)
_test_lib_compile_and_run_with_options(keyset_cache, options)
def test_lib_compile_and_run_security_level(keyset_cache):
options = CompilationOptions.new("main")
options = CompilationOptions.new()
options.set_security_level(80)
options.set_display_optimizer_choice(True)
_test_lib_compile_and_run_with_options(keyset_cache, options)
@@ -276,7 +314,7 @@ def test_compile_and_run_auto_parallelize(
):
artifact_dir = "./py_test_compile_and_run_auto_parallelize"
engine = LibrarySupport.new(artifact_dir)
options = CompilationOptions.new("main")
options = CompilationOptions.new()
options.set_auto_parallelize(True)
compile_run_assert(
engine, mlir_input, args, expected_result, keyset_cache, options=options
@@ -301,7 +339,7 @@ def test_compile_and_run_auto_parallelize(
# if no_parallel:
# artifact_dir = "./py_test_compile_dataflow_and_fail_run"
# engine = LibrarySupport.new(artifact_dir)
# options = CompilationOptions.new("main")
# options = CompilationOptions.new()
# options.set_auto_parallelize(True)
# with pytest.raises(
# RuntimeError,
@@ -334,7 +372,7 @@ def test_compile_and_run_loop_parallelize(
):
artifact_dir = "./py_test_compile_and_run_loop_parallelize"
engine = LibrarySupport.new(artifact_dir)
options = CompilationOptions.new("main")
options = CompilationOptions.new()
options.set_loop_parallelize(True)
compile_run_assert(
engine, mlir_input, args, expected_result, keyset_cache, options=options
@@ -365,29 +403,6 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args, keyset_cache):
compile_run_assert(engine, mlir_input, args, None, keyset_cache)
@pytest.mark.parametrize(
"mlir_input",
[
pytest.param(
"""
func.func @test(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7>
{
%ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
(tensor<4x!FHE.eint<7>>, tensor<4xi8>) -> !FHE.eint<7>
return %ret : !FHE.eint<7>
}
""",
id="not @main",
),
],
)
def test_compile_invalid(mlir_input):
artifact_dir = "./py_test_compile_invalid"
engine = LibrarySupport.new(artifact_dir)
with pytest.raises(RuntimeError, match=r"Function not found, name='main'"):
engine.compile(mlir_input)
def test_crt_decomposition_feedback():
mlir = """
@@ -401,11 +416,27 @@ func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
artifact_dir = "./py_test_crt_decomposition_feedback"
engine = LibrarySupport.new(artifact_dir)
compilation_result = engine.compile(mlir, options=CompilationOptions.new("main"))
compilation_result = engine.compile(mlir, options=CompilationOptions.new())
compilation_feedback = engine.load_compilation_feedback(compilation_result)
assert isinstance(compilation_feedback, CompilationFeedback)
assert compilation_feedback.crt_decompositions_of_outputs == [[7, 8, 9, 11, 13]]
assert isinstance(compilation_feedback, ProgramCompilationFeedback)
assert isinstance(compilation_feedback.complexity, float)
assert isinstance(compilation_feedback.p_error, float)
assert isinstance(compilation_feedback.global_p_error, float)
assert isinstance(compilation_feedback.total_secret_keys_size, int)
assert isinstance(compilation_feedback.total_bootstrap_keys_size, int)
assert isinstance(compilation_feedback.circuit_feedbacks, list)
assert isinstance(
compilation_feedback.circuit_feedbacks[0], CircuitCompilationFeedback
)
assert isinstance(compilation_feedback.circuit_feedbacks[0].total_inputs_size, int)
assert isinstance(compilation_feedback.circuit_feedbacks[0].total_output_size, int)
assert isinstance(
compilation_feedback.circuit_feedbacks[0].crt_decompositions_of_outputs, list
)
assert compilation_feedback.circuit_feedbacks[0].crt_decompositions_of_outputs == [
[7, 8, 9, 11, 13]
]
@pytest.mark.parametrize(
@@ -450,10 +481,11 @@ def test_memory_usage(mlir: str, expected_memory_usage_per_loc: dict):
engine = LibrarySupport.new(artifact_dir)
compilation_result = engine.compile(mlir)
compilation_feedback = engine.load_compilation_feedback(compilation_result)
assert isinstance(compilation_feedback, CompilationFeedback)
assert isinstance(compilation_feedback, ProgramCompilationFeedback)
assert (
expected_memory_usage_per_loc == compilation_feedback.memory_usage_per_location
expected_memory_usage_per_loc
== compilation_feedback.circuit_feedbacks[0].memory_usage_per_location
)
shutil.rmtree(artifact_dir)

View File

@@ -49,7 +49,7 @@ def compile_run_assert(
mlir_input,
args_and_shape,
expected_result,
options=CompilationOptions.new("main"),
options=CompilationOptions.new(),
):
# compile with simulation
options.simulation(True)

View File

@@ -32,7 +32,10 @@ module {
compilation_result = support.compile(mlir)
client_parameters = support.load_client_parameters(compilation_result)
compilation_feedback = support.load_compilation_feedback(compilation_result)
program_compilation_feedback = support.load_compilation_feedback(
compilation_result
)
compilation_feedback = program_compilation_feedback.circuit("main")
pbs_count = compilation_feedback.count(
operations={

View File

@@ -12,7 +12,7 @@
#include "concretelang/Common/Error.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Encodings.h"
#include "concretelang/TestLib/TestCircuit.h"
#include "concretelang/TestLib/TestProgram.h"
#include "tests_tools/GtestEnvironment.h"
#include "tests_tools/assert.h"
@@ -23,17 +23,18 @@ testing::Environment *const dfr_env =
using namespace concretelang::testlib;
namespace encodings = mlir::concretelang::encodings;
Result<TestCircuit> setupTestCircuit(std::string source,
Result<TestProgram> setupTestProgram(std::string source,
std::string funcname = FUNCNAME) {
std::vector<std::string> sources = {source};
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
mlir::concretelang::CompilationContext::createShared();
mlir::concretelang::CompilerEngine ce{ccx};
mlir::concretelang::CompilationOptions options(funcname);
mlir::concretelang::CompilationOptions options;
options.encodings = Message<concreteprotocol::CircuitEncodingInfo>();
auto inputs = options.encodings->asBuilder().initInputs(2);
auto outputs = options.encodings->asBuilder().initOutputs(1);
auto circuitEncoding = Message<concreteprotocol::CircuitEncodingInfo>();
auto inputs = circuitEncoding.asBuilder().initInputs(2);
auto outputs = circuitEncoding.asBuilder().initOutputs(1);
circuitEncoding.asBuilder().setName(funcname);
auto encodingInfo = Message<concreteprotocol::EncodingInfo>().asBuilder();
encodingInfo.initShape();
@@ -46,9 +47,12 @@ Result<TestCircuit> setupTestCircuit(std::string source,
inputs.setWithCaveats(1, encodingInfo);
outputs.setWithCaveats(0, encodingInfo);
options.encodings->asBuilder().setName("main");
options.encodings = Message<concreteprotocol::ProgramEncodingInfo>();
options.encodings->asBuilder().initCircuits(1).setWithCaveats(
0, circuitEncoding.asReader());
options.v0Parameter = {2, 10, 693, 4, 9, 7, 2, std::nullopt};
TestCircuit testCircuit(options);
TestProgram testCircuit(options);
OUTCOME_TRYV(testCircuit.compile({source}));
OUTCOME_TRYV(testCircuit.generateKeyset());
return std::move(testCircuit);
@@ -67,7 +71,7 @@ func.func @main(
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
uint64_t a = 5;
uint64_t b = 5;
auto res = circuit.call({Tensor<uint64_t>(a), Tensor<uint64_t>(b)});

View File

@@ -10,7 +10,7 @@
#include "concretelang/Common/Error.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/TestLib/TestCircuit.h"
#include "concretelang/TestLib/TestProgram.h"
#include "tests_tools/GtestEnvironment.h"
#include "tests_tools/assert.h"
@@ -20,15 +20,15 @@ using namespace concretelang::testlib;
testing::Environment *const dfr_env =
testing::AddGlobalTestEnvironment(new DFREnvironment);
Result<TestCircuit> setupTestCircuit(std::string source,
Result<TestProgram> setupTestProgram(std::string source,
std::string funcname = FUNCNAME) {
mlir::concretelang::CompilationOptions options(funcname);
mlir::concretelang::CompilationOptions options;
#ifdef CONCRETELANG_CUDA_SUPPORT
options.emitGPUOps = true;
options.emitSDFGOps = true;
#endif
options.batchTFHEOps = true;
TestCircuit testCircuit(options);
TestProgram testCircuit(options);
OUTCOME_TRYV(testCircuit.compile({source}));
OUTCOME_TRYV(testCircuit.generateKeyset());
return std::move(testCircuit);
@@ -41,7 +41,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
return %1: !FHE.eint<7>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
for (auto a : values_7bits())
for (auto b : values_7bits()) {
if (a > b) {
@@ -61,7 +61,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> {
return %1: !FHE.eint<7>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
for (auto a : values_7bits())
for (auto b : values_7bits()) {
if (a > b) {
@@ -81,7 +81,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> {
return %1: !FHE.eint<7>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
for (auto a : values_3bits())
for (auto b : values_3bits()) {
if (a > b) {
@@ -101,7 +101,7 @@ func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
return %1: !FHE.eint<7>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
for (auto a : values_7bits()) {
auto res = circuit.call({Tensor<uint64_t>(a)});
ASSERT_TRUE(res.has_value());
@@ -119,7 +119,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>, %arg2: !FHE.eint<7>, %
return %3: !FHE.eint<7>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
for (auto a : values_3bits()) {
for (auto b : values_3bits()) {
auto res = circuit.call({
@@ -143,7 +143,7 @@ func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
return %1: !FHE.eint<3>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
for (auto a : values_3bits()) {
auto res = circuit.call({Tensor<uint64_t>(a)});
ASSERT_TRUE(res.has_value());
@@ -165,7 +165,7 @@ func.func @main(%arg0: !FHE.eint<4>) -> !FHE.eint<4> {
return %6: !FHE.eint<4>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
for (auto a : values_3bits()) {
auto res = circuit.call({Tensor<uint64_t>(a)});
ASSERT_TRUE(res.has_value());
@@ -182,7 +182,7 @@ TEST(SDFG_unit_tests, tlu_batched) {
return %res : tensor<3x3x!FHE.eint<3>>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
auto t = Tensor<uint64_t>({0, 1, 2, 3, 0, 1, 2, 3, 0}, {3, 3});
auto expected = Tensor<uint64_t>({1, 3, 5, 7, 1, 3, 5, 7, 1}, {3, 3});
auto res = circuit.call({t});
@@ -201,7 +201,7 @@ TEST(SDFG_unit_tests, batched_tree) {
return %res : tensor<3x3x!FHE.eint<4>>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
auto t = Tensor<uint64_t>({0, 1, 2, 3, 0, 1, 2, 3, 0}, {3, 3});
auto a1 = Tensor<uint8_t>({0, 1, 0, 0, 1, 0, 0, 1, 0}, {3, 3});
auto a2 = Tensor<uint8_t>({1, 0, 1, 1, 0, 1, 1, 0, 1}, {3, 3});
@@ -226,7 +226,7 @@ TEST(SDFG_unit_tests, batched_tree_mapped_tlu) {
return %res : tensor<3x3x!FHE.eint<4>>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
auto t = Tensor<uint64_t>({0, 1, 2, 3, 0, 1, 2, 3, 0}, {3, 3});
auto a1 = Tensor<uint8_t>({0, 1, 0, 0, 1, 0, 0, 1, 0}, {3, 3});
auto a2 = Tensor<uint8_t>({1, 0, 1, 1, 0, 1, 1, 0, 1}, {3, 3});

View File

@@ -8,7 +8,7 @@
#include "concretelang/Common/Error.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/TestLib/TestCircuit.h"
#include "concretelang/TestLib/TestProgram.h"
#include "tests_tools/GtestEnvironment.h"
#include "tests_tools/assert.h"
@@ -18,17 +18,17 @@ using namespace concretelang::testlib;
testing::Environment *const dfr_env =
testing::AddGlobalTestEnvironment(new DFREnvironment);
Result<TestCircuit> setupTestCircuit(std::string source,
Result<TestProgram> setupTestProgram(std::string source,
std::string funcname = FUNCNAME) {
std::vector<std::string> sources = {source};
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
mlir::concretelang::CompilationContext::createShared();
mlir::concretelang::CompilerEngine ce{ccx};
mlir::concretelang::CompilationOptions options(funcname);
mlir::concretelang::CompilationOptions options;
#ifdef CONCRETELANG_DATAFLOW_TESTING_ENABLED
options.dataflowParallelize = true;
#endif
TestCircuit testCircuit(options);
TestProgram testCircuit(options);
OUTCOME_TRYV(testCircuit.compile({source}));
OUTCOME_TRYV(testCircuit.generateKeyset());
return std::move(testCircuit);
@@ -67,7 +67,7 @@ func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
return %arg0: !FHE.eint<7>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
for (auto a : values_7bits()) {
auto res = circuit.call({Tensor<uint64_t>(a)});
ASSERT_TRUE(res.has_value());
@@ -82,7 +82,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
return %arg0: !FHE.eint<7>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
for (auto a : values_7bits())
for (auto b : values_7bits()) {
if (a > b) {
@@ -102,7 +102,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
return %1: !FHE.eint<7>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
for (auto a : values_7bits())
for (auto b : values_7bits()) {
if (a > b) {
@@ -122,7 +122,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
return %1: !FHE.eint<7>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
auto res = circuit.call({Tensor<uint64_t>(1)});
ASSERT_FALSE(res.has_value());
}
@@ -134,7 +134,7 @@ func.func @main(%arg0: !FHE.eint<7>) -> tensor<1x!FHE.eint<7>> {
return %1: tensor<1x!FHE.eint<7>>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
for (auto a : values_7bits()) {
auto res = circuit.call({Tensor<uint64_t>(a)});
EXPECT_TRUE(res);
@@ -150,7 +150,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> tensor<2x!FHE.eint<
return %1: tensor<2x!FHE.eint<7>>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
for (auto a : values_7bits()) {
auto res = circuit.call({Tensor<uint64_t>(a), Tensor<uint64_t>(a + 1)});
EXPECT_TRUE(res);
@@ -168,7 +168,7 @@ func.func @main(%arg0: tensor<1x!FHE.eint<7>>) -> !FHE.eint<7> {
return %1: !FHE.eint<7>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
for (uint8_t a : values_7bits()) {
auto ta = Tensor<uint64_t>({a}, {1});
auto res = circuit.call({ta});
@@ -184,7 +184,7 @@ func.func @main(%arg0: tensor<3x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> {
return %arg0: tensor<3x!FHE.eint<7>>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
auto ta = Tensor<uint64_t>({1, 2, 3}, {3});
auto res = circuit.call({ta});
ASSERT_TRUE(res);
@@ -202,7 +202,7 @@ func.func @main(%arg0: tensor<3x!FHE.eint<7>>, %arg1: tensor<3x!FHE.eint<7>>) ->
return %3: !FHE.eint<7>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
auto ta = Tensor<uint64_t>({1, 2, 3}, {3});
auto tb = Tensor<uint64_t>({5, 7, 9}, {3});
auto res = circuit.call({ta, tb});
@@ -219,7 +219,7 @@ func.func @main(%arg0: tensor<2x3x!FHE.eint<7>>) -> tensor<2x3x!FHE.eint<7>> {
return %arg0: tensor<2x3x!FHE.eint<7>>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
auto ta = Tensor<uint64_t>({1, 2, 3, 4, 5, 6}, {2, 3});
auto res = circuit.call({ta});
ASSERT_TRUE(res);
@@ -233,7 +233,7 @@ func.func @main(%arg0: tensor<2x3x1x!FHE.eint<7>>) -> tensor<2x3x1x!FHE.eint<7>>
return %arg0: tensor<2x3x1x!FHE.eint<7>>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
auto ta = Tensor<uint64_t>({1, 2, 3, 4, 5, 6}, {2, 3, 1});
auto res = circuit.call({ta});
ASSERT_TRUE(res);
@@ -248,7 +248,7 @@ func.func @main(%arg0: tensor<2x3x1x!FHE.eint<7>>, %arg1: tensor<2x3x1x!FHE.eint
return %1: tensor<2x3x1x!FHE.eint<7>>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
auto ta = Tensor<uint64_t>({1, 2, 3, 4, 5, 6}, {2, 3, 1});
auto res = circuit.call({ta, ta});
ASSERT_TRUE(res);
@@ -312,7 +312,7 @@ func.func @main(%arg0: !FHE.eint<6>, %arg1: !FHE.eint<3>) -> !FHE.eint<6> {
return %a_plus_b: !FHE.eint<6>
}
)";
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestCircuit(source));
ASSERT_ASSIGN_OUTCOME_VALUE(circuit, setupTestProgram(source));
for (auto a : values_6bits())
for (auto b : values_3bits()) {
auto res = circuit.call({Tensor<uint64_t>(a), Tensor<uint64_t>(b)});

View File

@@ -599,6 +599,10 @@ impl OperationDag {
self.0.dump()
}
fn concat(&mut self, other: &Self) {
self.0.concat(&other.0);
}
fn tag_operator_as_output(&mut self, op: ffi::OperatorIndex) {
self.0.tag_operator_as_output(op.into());
}
@@ -741,6 +745,8 @@ mod ffi {
fn dump(self: &OperationDag) -> String;
fn concat(self: &mut OperationDag, other: &OperationDag);
#[namespace = "concrete_optimizer::dag"]
fn dump(self: &CircuitSolution) -> String;

View File

@@ -976,6 +976,7 @@ struct OperationDag final : public ::rust::Opaque {
::concrete_optimizer::dag::OperatorIndex add_unsafe_cast_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept;
::concrete_optimizer::dag::DagSolution optimize(::concrete_optimizer::Options options) const noexcept;
::rust::String dump() const noexcept;
void concat(::concrete_optimizer::OperationDag const &other) noexcept;
void tag_operator_as_output(::concrete_optimizer::dag::OperatorIndex op) noexcept;
::concrete_optimizer::dag::CircuitSolution optimize_multi(::concrete_optimizer::Options options) const noexcept;
~OperationDag() = delete;
@@ -1289,6 +1290,8 @@ extern "C" {
void concrete_optimizer$cxxbridge1$OperationDag$optimize(::concrete_optimizer::OperationDag const &self, ::concrete_optimizer::Options options, ::concrete_optimizer::dag::DagSolution *return$) noexcept;
void concrete_optimizer$cxxbridge1$OperationDag$dump(::concrete_optimizer::OperationDag const &self, ::rust::String *return$) noexcept;
void concrete_optimizer$cxxbridge1$OperationDag$concat(::concrete_optimizer::OperationDag &self, ::concrete_optimizer::OperationDag const &other) noexcept;
} // extern "C"
namespace dag {
@@ -1390,6 +1393,10 @@ namespace dag {
return ::std::move(return$.value);
}
void OperationDag::concat(::concrete_optimizer::OperationDag const &other) noexcept {
concrete_optimizer$cxxbridge1$OperationDag$concat(*this, other);
}
namespace dag {
::rust::String CircuitSolution::dump() const noexcept {
::rust::MaybeUninit<::rust::String> return$;

View File

@@ -957,6 +957,7 @@ struct OperationDag final : public ::rust::Opaque {
::concrete_optimizer::dag::OperatorIndex add_unsafe_cast_op(::concrete_optimizer::dag::OperatorIndex input, ::std::uint8_t rounded_precision) noexcept;
::concrete_optimizer::dag::DagSolution optimize(::concrete_optimizer::Options options) const noexcept;
::rust::String dump() const noexcept;
void concat(::concrete_optimizer::OperationDag const &other) noexcept;
void tag_operator_as_output(::concrete_optimizer::dag::OperatorIndex op) noexcept;
::concrete_optimizer::dag::CircuitSolution optimize_multi(::concrete_optimizer::Options options) const noexcept;
~OperationDag() = delete;

View File

@@ -252,6 +252,29 @@ impl OperationDag {
self.add_lut(rounded, table, out_precision)
}
/// Concatenates two dags into a single one (with two disconnected clusters).
pub fn concat(&mut self, other: &Self) {
let length = self.len();
self.operators.extend(other.operators.iter().cloned());
self.out_precisions.extend(other.out_precisions.iter());
self.out_shapes.extend(other.out_shapes.iter().cloned());
self.output_tags.extend(other.output_tags.iter());
self.operators[length..]
.iter_mut()
.for_each(|node| match node {
Operator::Lut { ref mut input, .. }
| Operator::UnsafeCast { ref mut input, .. }
| Operator::Round { ref mut input, .. } => {
input.i += length;
}
Operator::Dot { ref mut inputs, .. }
| Operator::LevelledOp { ref mut inputs, .. } => {
inputs.iter_mut().for_each(|inp| inp.i += length);
}
_ => (),
});
}
/// Returns an iterator over input nodes indices.
pub(crate) fn get_input_index_iter(&self) -> impl Iterator<Item = usize> + '_ {
self.operators
@@ -315,6 +338,7 @@ impl OperationDag {
DotKind::Broadcast { shape } => shape,
DotKind::Unsupported { .. } => {
let weights_shape = &weights.shape;
println!();
println!();
println!("Error diagnostic on dot operation:");
@@ -347,6 +371,33 @@ mod tests {
use super::*;
use crate::dag::operator::Shape;
#[test]
fn graph_concat() {
let mut graph1 = OperationDag::new();
let a = graph1.add_input(1, Shape::number());
let b = graph1.add_input(1, Shape::number());
let c = graph1.add_dot([a, b], [1, 1]);
let _d = graph1.add_lut(c, FunctionTable::UNKWOWN, 1);
let mut graph2 = OperationDag::new();
let a = graph2.add_input(2, Shape::number());
let b = graph2.add_input(2, Shape::number());
let c = graph2.add_dot([a, b], [2, 2]);
let _d = graph2.add_lut(c, FunctionTable::UNKWOWN, 2);
graph1.concat(&graph2);
let mut graph3 = OperationDag::new();
let a = graph3.add_input(1, Shape::number());
let b = graph3.add_input(1, Shape::number());
let c = graph3.add_dot([a, b], [1, 1]);
let _d = graph3.add_lut(c, FunctionTable::UNKWOWN, 1);
let a = graph3.add_input(2, Shape::number());
let b = graph3.add_input(2, Shape::number());
let c = graph3.add_dot([a, b], [2, 2]);
let _d = graph3.add_lut(c, FunctionTable::UNKWOWN, 2);
assert_eq!(graph1, graph3);
}
#[test]
fn graph_creation() {
let mut graph = OperationDag::new();