feat(compiler): First draft or compilation feedback

This commit is contained in:
Quentin Bourgerie
2022-09-09 23:11:49 +02:00
parent 698bd28104
commit f4673e8276
17 changed files with 370 additions and 37 deletions

View File

@@ -18,6 +18,19 @@
#include <llvm/Support/JSON.h>
namespace concretelang {
inline size_t bitWidthAsWord(size_t exactBitWidth) {
if (exactBitWidth <= 8)
return 8;
if (exactBitWidth <= 16)
return 16;
if (exactBitWidth <= 32)
return 32;
if (exactBitWidth <= 64)
return 64;
assert(false && "Bit witdh > 64 not supported");
}
namespace clientlib {
using concretelang::error::StringError;
@@ -44,6 +57,7 @@ struct LweSecretKeyParam {
void hash(size_t &seed);
inline uint64_t lweDimension() { return dimension; }
inline uint64_t lweSize() { return dimension + 1; }
inline uint64_t byteSize() { return lweSize() * 8; }
};
static bool operator==(const LweSecretKeyParam &lhs,
const LweSecretKeyParam &rhs) {
@@ -60,6 +74,11 @@ struct BootstrapKeyParam {
Variance variance;
void hash(size_t &seed);
uint64_t byteSize(uint64_t inputLweSize, uint64_t outputLweSize) {
return inputLweSize * level * (glweDimension + 1) * (glweDimension + 1) *
outputLweSize * 8;
}
};
static inline bool operator==(const BootstrapKeyParam &lhs,
const BootstrapKeyParam &rhs) {
@@ -78,6 +97,10 @@ struct KeyswitchKeyParam {
Variance variance;
void hash(size_t &seed);
size_t byteSize(size_t inputLweSize, size_t outputLweSize) {
return level * inputLweSize * outputLweSize * 8;
}
};
static inline bool operator==(const KeyswitchKeyParam &lhs,
const KeyswitchKeyParam &rhs) {
@@ -125,6 +148,19 @@ struct CircuitGate {
CircuitGateShape shape;
bool isEncrypted() { return encryption.hasValue(); }
/// byteSize returns the size in bytes for this gate.
size_t byteSize(std::map<LweSecretKeyID, LweSecretKeyParam> secretKeys) {
auto width = shape.width;
auto numElts = shape.size == 0 ? 1 : shape.size;
if (isEncrypted()) {
auto skParam = secretKeys.find(encryption->secretKeyID);
assert(skParam != secretKeys.end());
return 8 * skParam->second.lweSize() * numElts;
}
width = bitWidthAsWord(width) / 8;
return width * numElts;
}
};
static inline bool operator==(const CircuitGate &lhs, const CircuitGate &rhs) {
return lhs.encryption == rhs.encryption && lhs.shape == rhs.shape;

View File

@@ -23,18 +23,6 @@ using concretelang::error::StringError;
class PublicArguments;
inline size_t bitWidthAsWord(size_t exactBitWidth) {
if (exactBitWidth <= 8)
return 8;
if (exactBitWidth <= 16)
return 16;
if (exactBitWidth <= 32)
return 32;
if (exactBitWidth <= 64)
return 64;
assert(false && "Bit witdh > 64 not supported");
}
/// Temporary object used to hold and encrypt parameters before calling a
/// ClientLambda. Use preferably TypeClientLambda and serializeCall(Args...).
/// Otherwise convert it to a PublicArguments and use

View File

@@ -0,0 +1,61 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_SUPPORT_COMPILATIONFEEDBACK_H_
#define CONCRETELANG_SUPPORT_COMPILATIONFEEDBACK_H_
#include <cstddef>
#include <vector>
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
#include "llvm/Support/Error.h"
namespace mlir {
namespace concretelang {
using StringError = ::concretelang::error::StringError;
struct CompilationFeedback {
double complexity;
/// @brief the total number of bytes of secret keys
size_t totalSecretKeysSize;
/// @brief the total number of bytes of bootstrap keys
size_t totalBootstrapKeysSize;
/// @brief the total number of bytes of keyswitch keys
size_t totalKeyswitchKeysSize;
/// @brief the total number of bytes of inputs
size_t totalInputsSize;
/// @brief the total number of bytes of outputs
size_t totalOutputsSize;
/// Fill the sizes from the client parameters.
void
fillFromClientParameters(::concretelang::clientlib::ClientParameters params);
/// Load the compilation feedback from a path
static outcome::checked<CompilationFeedback, StringError>
load(std::string path);
};
llvm::json::Value toJSON(const mlir::concretelang::CompilationFeedback &);
bool fromJSON(const llvm::json::Value,
mlir::concretelang::CompilationFeedback &, llvm::json::Path);
} // namespace concretelang
} // namespace mlir
static inline llvm::raw_ostream &
operator<<(llvm::raw_string_ostream &OS,
mlir::concretelang::CompilationFeedback cp) {
return OS << llvm::formatv("{0:2}", toJSON(cp));
}
#endif

View File

@@ -83,6 +83,7 @@ public:
llvm::Optional<mlir::OwningOpRef<mlir::ModuleOp>> mlirModuleRef;
llvm::Optional<mlir::concretelang::ClientParameters> clientParameters;
llvm::Optional<CompilationFeedback> feedback;
std::unique_ptr<llvm::Module> llvmModule;
llvm::Optional<mlir::concretelang::V0FHEContext> fheContext;
@@ -94,6 +95,8 @@ public:
std::string outputDirPath;
std::vector<std::string> objectsPath;
std::vector<mlir::concretelang::ClientParameters> clientParametersList;
std::vector<mlir::concretelang::CompilationFeedback>
compilationFeedbackList;
/// Path to the runtime library. Will be linked to the output library if set
std::string runtimeLibraryPath;
bool cleanUp;
@@ -110,7 +113,8 @@ public:
llvm::Expected<std::string> addCompilation(CompilationResult &compilation);
/// Emit the library artifacts with the previously added compilation result
llvm::Error emitArtifacts(bool sharedLib, bool staticLib,
bool clientParameters, bool cppHeader);
bool clientParameters, bool compilationFeedback,
bool cppHeader);
/// After a shared library has been emitted, its path is here
std::string sharedLibraryPath;
/// After a static library has been emitted, its path is here
@@ -122,9 +126,12 @@ public:
/// Returns the path of the static library
static std::string getStaticLibraryPath(std::string outputDirPath);
/// Returns the path of the static library
/// Returns the path of the client parameters
static std::string getClientParametersPath(std::string outputDirPath);
/// Returns the path of the compilation feedback
static std::string getCompilationFeedbackPath(std::string outputDirPath);
// For advanced use
const static std::string OBJECT_EXT, LINKER, LINKER_SHARED_OPT, AR,
AR_STATIC_OPT, DOT_STATIC_LIB_EXT, DOT_SHARED_LIB_EXT;
@@ -141,6 +148,8 @@ public:
llvm::Expected<std::string> emitShared();
/// Emit a json ClientParameters corresponding to library content
llvm::Expected<std::string> emitClientParametersJSON();
/// Emit a json CompilationFeedback corresponding to library content
llvm::Expected<std::string> emitCompilationFeedbackJSON();
/// Emit a client header file for this corresponding to library content
llvm::Expected<std::string> emitCppHeader();
};
@@ -211,6 +220,7 @@ public:
compile(std::vector<std::string> inputs, std::string outputDirPath,
std::string runtimeLibraryPath = "", bool generateSharedLib = true,
bool generateStaticLib = true, bool generateClientParameters = true,
bool generateCompilationFeedback = true,
bool generateCppHeader = true);
/// Compile and emit artifact to the given outputDirPath from an LLVM source
@@ -219,6 +229,7 @@ public:
compile(llvm::SourceMgr &sm, std::string outputDirPath,
std::string runtimeLibraryPath = "", bool generateSharedLib = true,
bool generateStaticLib = true, bool generateClientParameters = true,
bool generateCompilationFeedback = true,
bool generateCppHeader = true);
void setCompilationOptions(CompilationOptions &options) {

View File

@@ -25,6 +25,7 @@ namespace clientlib = ::concretelang::clientlib;
struct JitCompilationResult {
std::shared_ptr<concretelang::JITLambda> lambda;
clientlib::ClientParameters clientParameters;
CompilationFeedback feedback;
};
/// JITSupport is the instantiated LambdaSupport for the Jit Compilation.
@@ -49,6 +50,11 @@ public:
return result.clientParameters;
}
llvm::Expected<CompilationFeedback>
loadCompilationFeedback(JitCompilationResult &result) override {
return result.feedback;
}
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
serverCall(std::shared_ptr<concretelang::JITLambda> lambda,
clientlib::PublicArguments &args,

View File

@@ -273,6 +273,10 @@ public:
llvm::Expected<clientlib::ClientParameters> virtual loadClientParameters(
CompilationResult &result) = 0;
/// Load the compilation feedback from the compilation result.
llvm::Expected<CompilationFeedback> virtual loadCompilationFeedback(
CompilationResult &result) = 0;
/// Call the lambda with the public arguments.
llvm::Expected<std::unique_ptr<clientlib::PublicResult>> virtual serverCall(
Lambda lambda, clientlib::PublicArguments &args,

View File

@@ -101,6 +101,17 @@ public:
return *param;
}
llvm::Expected<CompilationFeedback>
loadCompilationFeedback(LibraryCompilationResult &result) override {
auto path = CompilerEngine::Library::getCompilationFeedbackPath(
result.outputDirPath);
auto feedback = CompilationFeedback::load(path);
if (feedback.has_error()) {
return StreamStringError(feedback.error().mesg);
}
return feedback.value();
}
/// Call the lambda with the public arguments.
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
serverCall(serverlib::ServerLambda lambda, clientlib::PublicArguments &args,

View File

@@ -10,6 +10,7 @@
#include "concrete-optimizer.hpp"
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
#include "concretelang/Support/CompilationFeedback.h"
namespace mlir {
namespace concretelang {
@@ -49,6 +50,7 @@ struct Description {
} // namespace optimizer
llvm::Expected<V0Parameter> getParameter(optimizer::Description &descr,
CompilationFeedback &feedback,
optimizer::Config optimizerConfig);
} // namespace concretelang
} // namespace mlir

View File

@@ -72,6 +72,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
options.optimizerConfig.global_p_error = global_p_error;
});
pybind11::class_<mlir::concretelang::CompilationFeedback>(
m, "CompilationFeedback");
pybind11::class_<mlir::concretelang::JitCompilationResult>(
m, "JITCompilationResult");
pybind11::class_<mlir::concretelang::JITLambda,
@@ -91,6 +94,11 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
mlir::concretelang::JitCompilationResult &result) {
return jit_load_client_parameters(support, result);
})
.def("load_compilation_feedback",
[](JITSupport_C &support,
mlir::concretelang::JitCompilationResult &result) {
return jit_load_client_parameters(support, result);
})
.def(
"load_server_lambda",
[](JITSupport_C &support,
@@ -135,6 +143,11 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
mlir::concretelang::LibraryCompilationResult &result) {
return library_load_client_parameters(support, result);
})
.def("load_compilation_feedback",
[](LibrarySupport_C &support,
mlir::concretelang::LibraryCompilationResult &result) {
return library_load_compilation_feedback(support, result);
})
.def(
"load_server_lambda",
[](LibrarySupport_C &support,

View File

@@ -1,6 +1,7 @@
add_mlir_library(ConcretelangSupport
Pipeline.cpp
Jit.cpp
CompilationFeedback.cpp
CompilerEngine.cpp
JITSupport.cpp
LambdaArgument.cpp

View File

@@ -0,0 +1,140 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <fstream>
#include "boost/outcome.h"
#include "concretelang/Support/CompilationFeedback.h"
namespace mlir {
namespace concretelang {
void CompilationFeedback::fillFromClientParameters(
::concretelang::clientlib::ClientParameters params) {
// Compute the size of secret keys
totalSecretKeysSize = 0;
for (auto sk : params.secretKeys) {
totalSecretKeysSize += sk.second.byteSize();
}
// Compute the boostrap keys size
totalBootstrapKeysSize = 0;
for (auto bsk : params.bootstrapKeys) {
auto bskParam = bsk.second;
auto inputKey = params.secretKeys.find(bskParam.inputSecretKeyID);
assert(inputKey != params.secretKeys.end());
auto outputKey = params.secretKeys.find(bskParam.outputSecretKeyID);
assert(outputKey != params.secretKeys.end());
totalBootstrapKeysSize += bskParam.byteSize(inputKey->second.lweSize(),
outputKey->second.lweSize());
}
// Compute the keyswitch keys size
totalKeyswitchKeysSize = 0;
for (auto ksk : params.keyswitchKeys) {
auto kskParam = ksk.second;
auto inputKey = params.secretKeys.find(kskParam.inputSecretKeyID);
assert(inputKey != params.secretKeys.end());
auto outputKey = params.secretKeys.find(kskParam.outputSecretKeyID);
assert(outputKey != params.secretKeys.end());
totalKeyswitchKeysSize += kskParam.byteSize(inputKey->second.lweSize(),
outputKey->second.lweSize());
}
// Compute the size of inputs
totalInputsSize = 0;
for (auto gate : params.inputs) {
totalInputsSize += gate.byteSize(params.secretKeys);
}
// Compute the size of outputs
totalOutputsSize = 0;
for (auto gate : params.outputs) {
totalOutputsSize += gate.byteSize(params.secretKeys);
}
}
outcome::checked<CompilationFeedback, StringError>
CompilationFeedback::load(std::string jsonPath) {
std::ifstream file(jsonPath);
std::string content((std::istreambuf_iterator<char>(file)),
(std::istreambuf_iterator<char>()));
if (file.fail()) {
return StringError("Cannot read file: ") << jsonPath;
}
auto expectedCompFeedback = llvm::json::parse<CompilationFeedback>(content);
if (auto err = expectedCompFeedback.takeError()) {
return StringError("Cannot open client parameters: ")
<< llvm::toString(std::move(err)) << "\n"
<< content << "\n";
}
return expectedCompFeedback.get();
}
llvm::json::Value toJSON(const mlir::concretelang::CompilationFeedback &v) {
llvm::json::Object object{
{"complexity", v.complexity},
{"totalSecretKeysSize", v.totalSecretKeysSize},
{"totalBootstrapKeysSize", v.totalBootstrapKeysSize},
{"totalKeyswitchKeysSize", v.totalKeyswitchKeysSize},
{"totalInputsSize", v.totalInputsSize},
{"totalOutputsSize", v.totalOutputsSize},
};
return object;
}
bool fromJSON(const llvm::json::Value j,
mlir::concretelang::CompilationFeedback &v, llvm::json::Path p) {
auto obj = j.getAsObject();
if (obj == nullptr) {
p.report("should be an object");
return false;
}
auto complexity = obj->getInteger("complexity");
if (!complexity.hasValue()) {
p.report("missing size field");
return false;
}
v.complexity = *complexity;
auto totalSecretKeysSize = obj->getInteger("totalSecretKeysSize");
if (!totalSecretKeysSize.hasValue()) {
p.report("missing totalSecretKeysSize field");
return false;
}
v.totalSecretKeysSize = *totalSecretKeysSize;
auto totalBootstrapKeysSize = obj->getInteger("totalBootstrapKeysSize");
if (!totalBootstrapKeysSize.hasValue()) {
p.report("missing totalBootstrapKeysSize field");
return false;
}
v.totalBootstrapKeysSize = *totalBootstrapKeysSize;
auto totalKeyswitchKeysSize = obj->getInteger("totalKeyswitchKeysSize");
if (!totalKeyswitchKeysSize.hasValue()) {
p.report("missing totalKeyswitchKeysSize field");
return false;
}
v.totalKeyswitchKeysSize = *totalKeyswitchKeysSize;
auto totalInputsSize = obj->getInteger("totalInputsSize");
if (!totalInputsSize.hasValue()) {
p.report("missing totalInputsSize field");
return false;
}
v.totalInputsSize = *totalInputsSize;
auto totalOutputsSize = obj->getInteger("totalOutputsSize");
if (!totalOutputsSize.hasValue()) {
p.report("missing totalOutputsSize field");
return false;
}
v.totalOutputsSize = *totalOutputsSize;
return true;
}
} // namespace concretelang
} // namespace mlir

View File

@@ -192,13 +192,15 @@ llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) {
if (!descr.get().hasValue()) {
return llvm::Error::success();
}
auto v0Params =
getParameter(descr.get().value(), compilerOptions.optimizerConfig);
CompilationFeedback feedback;
auto v0Params = getParameter(descr.get().value(), feedback,
compilerOptions.optimizerConfig);
if (auto err = v0Params.takeError()) {
return err;
}
res.fheContext.emplace(mlir::concretelang::V0FHEContext{
descr.get().value().constraint, v0Params.get()});
res.feedback.emplace(feedback);
}
return llvm::Error::success();
@@ -347,6 +349,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
return clientParametersOrErr.takeError();
res.clientParameters = clientParametersOrErr.get();
res.feedback->fillFromClientParameters(*res.clientParameters);
}
}
@@ -440,12 +443,11 @@ CompilerEngine::compile(std::unique_ptr<llvm::MemoryBuffer> buffer,
return this->compile(sm, target, lib);
}
llvm::Expected<CompilerEngine::Library>
CompilerEngine::compile(std::vector<std::string> inputs,
std::string outputDirPath,
std::string runtimeLibraryPath, bool generateSharedLib,
bool generateStaticLib, bool generateClientParameters,
bool generateCppHeader) {
llvm::Expected<CompilerEngine::Library> CompilerEngine::compile(
std::vector<std::string> inputs, std::string outputDirPath,
std::string runtimeLibraryPath, bool generateSharedLib,
bool generateStaticLib, bool generateClientParameters,
bool generateCompilationFeedback, bool generateCppHeader) {
using Library = mlir::concretelang::CompilerEngine::Library;
auto outputLib = std::make_shared<Library>(outputDirPath, runtimeLibraryPath);
auto target = CompilerEngine::Target::LIBRARY;
@@ -456,9 +458,9 @@ CompilerEngine::compile(std::vector<std::string> inputs,
<< llvm::toString(compilation.takeError());
}
}
if (auto err = outputLib->emitArtifacts(generateSharedLib, generateStaticLib,
generateClientParameters,
generateCppHeader)) {
if (auto err = outputLib->emitArtifacts(
generateSharedLib, generateStaticLib, generateClientParameters,
generateCompilationFeedback, generateCppHeader)) {
return StreamStringError("Can't emit artifacts: ")
<< llvm::toString(std::move(err));
}
@@ -469,6 +471,7 @@ llvm::Expected<CompilerEngine::Library>
CompilerEngine::compile(llvm::SourceMgr &sm, std::string outputDirPath,
std::string runtimeLibraryPath, bool generateSharedLib,
bool generateStaticLib, bool generateClientParameters,
bool generateCompilationFeedback,
bool generateCppHeader) {
using Library = mlir::concretelang::CompilerEngine::Library;
auto outputLib = std::make_shared<Library>(outputDirPath, runtimeLibraryPath);
@@ -480,9 +483,9 @@ CompilerEngine::compile(llvm::SourceMgr &sm, std::string outputDirPath,
<< llvm::toString(compilation.takeError());
}
if (auto err = outputLib->emitArtifacts(generateSharedLib, generateStaticLib,
generateClientParameters,
generateCppHeader)) {
if (auto err = outputLib->emitArtifacts(
generateSharedLib, generateStaticLib, generateClientParameters,
generateCompilationFeedback, generateCppHeader)) {
return StreamStringError("Can't emit artifacts: ")
<< llvm::toString(std::move(err));
}
@@ -505,7 +508,7 @@ CompilerEngine::Library::getStaticLibraryPath(std::string outputDirPath) {
return staticLibraryPath.str().str();
}
/// Returns the path of the static library
/// Returns the path of the client parameter
std::string
CompilerEngine::Library::getClientParametersPath(std::string outputDirPath) {
llvm::SmallString<0> clientParametersPath(outputDirPath);
@@ -515,6 +518,14 @@ CompilerEngine::Library::getClientParametersPath(std::string outputDirPath) {
return clientParametersPath.str().str();
}
/// Returns the path of the compiler feedback
std::string
CompilerEngine::Library::getCompilationFeedbackPath(std::string outputDirPath) {
llvm::SmallString<0> compilationFeedbackPath(outputDirPath);
llvm::sys::path::append(compilationFeedbackPath, "compilation_feedback.json");
return compilationFeedbackPath.str().str();
}
const std::string CompilerEngine::Library::OBJECT_EXT = ".o";
const std::string CompilerEngine::Library::LINKER = "ld";
#ifdef __APPLE__
@@ -558,6 +569,26 @@ CompilerEngine::Library::emitClientParametersJSON() {
return clientParamsPath;
}
llvm::Expected<std::string>
CompilerEngine::Library::emitCompilationFeedbackJSON() {
auto path = getCompilationFeedbackPath(outputDirPath);
if (compilationFeedbackList.size() != 1) {
return StreamStringError("multiple compilation feedback not supported");
}
llvm::json::Value value(compilationFeedbackList[0]);
std::error_code error;
llvm::raw_fd_ostream out(path, error);
if (error) {
return StreamStringError("cannot emit client parameters, error: ")
<< error.message();
}
out << llvm::formatv("{0:2}", value);
out.close();
return path;
}
static std::string ccpResultType(size_t rank) {
if (rank == 0) {
return "scalar_out";
@@ -663,6 +694,9 @@ CompilerEngine::Library::addCompilation(CompilationResult &compilation) {
if (compilation.clientParameters.hasValue()) {
clientParametersList.push_back(compilation.clientParameters.getValue());
}
if (compilation.feedback.hasValue()) {
compilationFeedbackList.push_back(compilation.feedback.getValue());
}
return objectPath;
}
@@ -776,6 +810,7 @@ llvm::Expected<std::string> CompilerEngine::Library::emitStatic() {
llvm::Error CompilerEngine::Library::emitArtifacts(bool sharedLib,
bool staticLib,
bool clientParameters,
bool compilationFeedback,
bool cppHeader) {
// Create output directory if doesn't exist
llvm::sys::fs::create_directory(outputDirPath);
@@ -794,6 +829,11 @@ llvm::Error CompilerEngine::Library::emitArtifacts(bool sharedLib,
return err;
}
}
if (compilationFeedback) {
if (auto err = emitCompilationFeedbackJSON().takeError()) {
return err;
}
}
if (cppHeader) {
if (auto err = emitCppHeader().takeError()) {
return err;

View File

@@ -49,11 +49,13 @@ JITSupport::compile(llvm::SourceMgr &program, CompilationOptions options) {
// Mark the lambda as compiled using DF parallelization
result->lambda->setUseDataflow(options.dataflowParallelize ||
options.autoParallelize);
if (!mlir::concretelang::dfr::_dfr_is_root_node())
if (!mlir::concretelang::dfr::_dfr_is_root_node()) {
result->clientParameters = clientlib::ClientParameters();
else
} else {
result->clientParameters =
compilationResult.get().clientParameters.getValue();
result->feedback = compilationResult.get().feedback.getValue();
}
return std::move(result);
}

View File

@@ -135,6 +135,7 @@ static void display(optimizer::Description &descr,
}
llvm::Expected<V0Parameter> getParameter(optimizer::Description &descr,
CompilationFeedback &feedback,
optimizer::Config config) {
namespace chrono = std::chrono;
auto start = chrono::high_resolution_clock::now();
@@ -206,6 +207,8 @@ llvm::Expected<V0Parameter> getParameter(optimizer::Description &descr,
params.largeInteger = lParams;
}
feedback.complexity = sol.complexity;
return params;
}

View File

@@ -585,7 +585,8 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
if (cmdline::action == Action::COMPILE) {
auto err = outputLib->emitArtifacts(
/*sharedLib=*/true, /*staticLib=*/true,
/*clientParameters=*/true, /*cppHeader=*/true);
/*clientParameters=*/true, /*compilationFeedback=*/true,
/*cppHeader=*/true);
if (err) {
return mlir::failure();
}

View File

@@ -3,6 +3,7 @@
#include <gtest/gtest.h>
#include <type_traits>
#include "concretelang/Support/CompilationFeedback.h"
#include "concretelang/Support/JITSupport.h"
#include "concretelang/Support/LibrarySupport.h"
#include "end_to_end_fixture/EndToEndFixture.h"
@@ -70,6 +71,10 @@ void compile_and_run_for_config(EndToEndDesc desc, LambdaSupport support,
auto serverLambda = support.loadServerLambda(**compilationResult);
ASSERT_EXPECTED_SUCCESS(serverLambda);
// Just test that we can load the compilation feedback
auto feedback = support.loadCompilationFeedback(**compilationResult);
ASSERT_EXPECTED_SUCCESS(feedback);
assert_all_test_entries(desc, test_error_rate, support, keySet,
evaluationKeys, clientParameters, serverLambda);
}

View File

@@ -26,14 +26,20 @@ def run(engine, args, compilation_result, keyset_cache):
"""Execute engine on the given arguments.
Perform required loading, encryption, execution, and decryption."""
# Dev
compilation_feedback = engine.load_compilation_feedback(
compilation_result)
assert(compilation_feedback is not None)
# Client
client_parameters = engine.load_client_parameters(compilation_result)
key_set = ClientSupport.key_set(client_parameters, keyset_cache)
public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args)
public_arguments = ClientSupport.encrypt_arguments(
client_parameters, key_set, args)
# Server
server_lambda = engine.load_server_lambda(compilation_result)
evaluation_keys = key_set.get_evaluation_keys()
public_result = engine.server_call(server_lambda, public_arguments, evaluation_keys)
public_result = engine.server_call(
server_lambda, public_arguments, evaluation_keys)
# Client
result = ClientSupport.decrypt_result(key_set, public_result)
return result
@@ -135,8 +141,10 @@ end_to_end_parallel_fixture = [
}
""",
(
np.array([[1, 2, 3, 4], [4, 2, 1, 0], [2, 3, 1, 5]], dtype=np.uint8),
np.array([[1, 2, 3, 4], [4, 2, 1, 1], [2, 3, 1, 5]], dtype=np.uint8),
np.array([[1, 2, 3, 4], [4, 2, 1, 0], [
2, 3, 1, 5]], dtype=np.uint8),
np.array([[1, 2, 3, 4], [4, 2, 1, 1], [
2, 3, 1, 5]], dtype=np.uint8),
),
np.array([[52, 36], [31, 34], [42, 52]]),
id="matmul_eint_int_uint8",
@@ -220,7 +228,8 @@ def test_lib_compile_and_run_p_error(keyset_cache):
options = CompilationOptions.new("main")
options.set_p_error(0.00001)
options.set_display_optimizer_choice(True)
compile_run_assert(engine, mlir_input, args, expected_result, keyset_cache, options)
compile_run_assert(engine, mlir_input, args,
expected_result, keyset_cache, options)
def test_lib_compile_and_run_p_error(keyset_cache):