feat(compiler): Output client parameters when compile to a library

close #198
This commit is contained in:
rudy
2021-12-29 11:34:54 +01:00
committed by Quentin Bourgerie
parent a4e8227692
commit b8bd38dd6c
26 changed files with 889 additions and 271 deletions

View File

@@ -3,7 +3,10 @@
// https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license
// information.
#include <fstream>
#include <iostream>
#include <stdio.h>
#include <string>
#include <llvm/Support/Error.h>
#include <llvm/Support/SMLoc.h>
@@ -271,15 +274,23 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
return StreamStringError(
"Cannot generate client parameters, the fhe context is empty");
}
}
// Generate client parameters if requested
auto funcName = this->clientParametersFuncName.getValueOr("main");
if (this->generateClientParameters || target == Target::LIBRARY) {
if (!res.fheContext.hasValue()) {
// Some tests can involves a usual function
res.clientParameters =
mlir::concretelang::emptyClientParametersForV0(funcName, module);
} else {
auto clientParametersOrErr =
mlir::concretelang::createClientParametersForV0(*res.fheContext,
funcName, module);
if (!clientParametersOrErr)
return clientParametersOrErr.takeError();
llvm::Expected<mlir::concretelang::ClientParameters> clientParametersOrErr =
mlir::concretelang::createClientParametersForV0(
*res.fheContext, *this->clientParametersFuncName, module);
if (llvm::Error err = clientParametersOrErr.takeError())
return std::move(err);
res.clientParameters = clientParametersOrErr.get();
res.clientParameters = clientParametersOrErr.get();
}
}
// MLIR canonical dialects -> LLVM Dialect
@@ -334,10 +345,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
llvm::Expected<CompilerEngine::CompilationResult>
CompilerEngine::compile(llvm::StringRef s, Target target, OptionalLib lib) {
std::unique_ptr<llvm::MemoryBuffer> mb = llvm::MemoryBuffer::getMemBuffer(s);
llvm::Expected<CompilationResult> res =
this->compile(std::move(mb), target, lib);
return std::move(res);
return this->compile(std::move(mb), target, lib);
}
// Compile the contained in `buffer` to the target dialect
@@ -351,9 +359,7 @@ CompilerEngine::compile(std::unique_ptr<llvm::MemoryBuffer> buffer,
sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
llvm::Expected<CompilationResult> res = this->compile(sm, target, lib);
return std::move(res);
return this->compile(sm, target, lib);
}
template <class T>
@@ -371,10 +377,9 @@ CompilerEngine::compile(std::vector<T> inputs, std::string libraryPath) {
}
}
auto libPath = outputLib->emitShared();
if (!libPath) {
return StreamStringError("Can't link: ")
<< llvm::toString(libPath.takeError());
if (auto err = outputLib->emitArtifacts()) {
return StreamStringError("Can't emit artifacts: ")
<< llvm::toString(std::move(err));
}
return *outputLib.get();
}
@@ -384,7 +389,24 @@ template llvm::Expected<CompilerEngine::Library>
CompilerEngine::compile(std::vector<std::string> inputs,
std::string libraryPath);
/** Returns the path of the shared library */
std::string CompilerEngine::Library::getSharedLibraryPath(std::string path) {
return path + DOT_SHARED_LIB_EXT;
}
/** Returns the path of the static library */
std::string CompilerEngine::Library::getStaticLibraryPath(std::string path) {
return path + DOT_STATIC_LIB_EXT;
}
/** Returns the path of the static library */
std::string CompilerEngine::Library::getClientParametersPath(std::string path) {
return path + CLIENT_PARAMETERS_EXT;
}
const std::string CompilerEngine::Library::OBJECT_EXT = ".o";
const std::string CompilerEngine::Library::CLIENT_PARAMETERS_EXT =
".concrete.params.json";
const std::string CompilerEngine::Library::LINKER = "ld";
const std::string CompilerEngine::Library::LINKER_SHARED_OPT = " --shared -o ";
const std::string CompilerEngine::Library::AR = "ar";
@@ -396,6 +418,23 @@ void CompilerEngine::Library::addExtraObjectFilePath(std::string path) {
objectsPath.push_back(path);
}
llvm::Expected<std::string>
CompilerEngine::Library::emitClientParametersJSON() {
auto clientParamsPath = getClientParametersPath(libraryPath);
llvm::json::Value value(clientParametersList);
std::error_code error;
llvm::raw_fd_ostream out(clientParamsPath, error);
if (error) {
return StreamStringError("cannot emit client parameters, error: ")
<< error.message();
}
out << llvm::formatv("{0:2}", value);
out.close();
return clientParamsPath;
}
llvm::Expected<std::string>
CompilerEngine::Library::addCompilation(CompilationResult &compilation) {
llvm::Module *module = compilation.llvmModule.get();
@@ -405,13 +444,14 @@ CompilerEngine::Library::addCompilation(CompilationResult &compilation) {
std::to_string(objectsPath.size()) + ".mlir";
}
auto objectPath = sourceName + OBJECT_EXT;
auto error = mlir::concretelang::emitObject(*module, objectPath);
if (error) {
if (auto error = mlir::concretelang::emitObject(*module, objectPath)) {
return std::move(error);
}
addExtraObjectFilePath(objectPath);
if (compilation.clientParameters.hasValue()) {
clientParametersList.push_back(compilation.clientParameters.getValue());
}
return objectPath;
}
@@ -437,9 +477,8 @@ llvm::Expected<std::string> CompilerEngine::Library::emit(std::string dotExt,
auto error = mlir::concretelang::emitLibrary(objectsPath, pathDotExt, linker);
if (error) {
return std::move(error);
} else {
return pathDotExt;
}
return pathDotExt;
}
llvm::Expected<std::string> CompilerEngine::Library::emitShared() {
@@ -458,6 +497,19 @@ llvm::Expected<std::string> CompilerEngine::Library::emitStatic() {
return path;
}
llvm::Error CompilerEngine::Library::emitArtifacts() {
if (auto err = emitShared().takeError()) {
return err;
}
if (auto err = emitStatic().takeError()) {
return err;
}
if (auto err = emitClientParametersJSON().takeError()) {
return err;
}
return llvm::Error::success();
}
CompilerEngine::Library::~Library() {
if (cleanUp) {
for (auto path : objectsPath) {