mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
When delocate is used to bundle dependencies into Python wheels, the runtime library gets an install id prefixed with /DLC, a path that doesn't exist, so generated libraries can't find it at load time. To solve this, we change the dependency in the generated library to be relative to the rpath, which should be set correctly during linking. This has no impact when /DLC/concrete/.dylibs/* isn't a dependency in the first place (when not using Python). Also set the rpath when linking to the RT lib.
728 lines
25 KiB
C++
728 lines
25 KiB
C++
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
|
// Exceptions. See
|
|
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
|
// for license information.
|
|
|
|
#include <fstream>
|
|
#include <iostream>
|
|
#include <regex>
|
|
#include <stdio.h>
|
|
#include <string>
|
|
|
|
#include <llvm/Support/Error.h>
|
|
#include <llvm/Support/Path.h>
|
|
#include <llvm/Support/SMLoc.h>
|
|
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
|
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
|
|
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
|
#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
|
|
#include <mlir/Dialect/SCF/SCF.h>
|
|
#include <mlir/Dialect/StandardOps/IR/Ops.h>
|
|
#include <mlir/ExecutionEngine/OptUtils.h>
|
|
#include <mlir/Parser.h>
|
|
|
|
#include <concretelang/ClientLib/ClientParameters.h>
|
|
#include <concretelang/Dialect/BConcrete/IR/BConcreteDialect.h>
|
|
#include <concretelang/Dialect/Concrete/IR/ConcreteDialect.h>
|
|
#include <concretelang/Dialect/FHE/IR/FHEDialect.h>
|
|
#include <concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.h>
|
|
#include <concretelang/Dialect/RT/IR/RTDialect.h>
|
|
#include <concretelang/Dialect/TFHE/IR/TFHEDialect.h>
|
|
#include <concretelang/Support/CompilerEngine.h>
|
|
#include <concretelang/Support/Error.h>
|
|
#include <concretelang/Support/Jit.h>
|
|
#include <concretelang/Support/LLVMEmitFile.h>
|
|
#include <concretelang/Support/Pipeline.h>
|
|
|
|
namespace mlir {
|
|
namespace concretelang {
|
|
|
|
// Creates a new compilation context that can be shared across
|
|
// compilation engines and results
|
|
std::shared_ptr<CompilationContext> CompilationContext::createShared() {
|
|
return std::make_shared<CompilationContext>();
|
|
}
|
|
|
|
CompilationContext::CompilationContext()
|
|
: mlirContext(nullptr), llvmContext(nullptr) {}
|
|
|
|
// Releases the contexts owned by this object: the MLIR context first, then
// the LLVM context. Deleting a context that was never created (null) is a
// no-op.
CompilationContext::~CompilationContext() {
  delete mlirContext;
  delete llvmContext;
}
|
|
|
|
// Returns the MLIR context for a compilation context. Creates and
|
|
// initializes a new MLIR context if necessary.
|
|
mlir::MLIRContext *CompilationContext::getMLIRContext() {
|
|
if (this->mlirContext == nullptr) {
|
|
this->mlirContext = new mlir::MLIRContext();
|
|
|
|
this->mlirContext->getOrLoadDialect<mlir::concretelang::RT::RTDialect>();
|
|
this->mlirContext->getOrLoadDialect<mlir::concretelang::FHE::FHEDialect>();
|
|
this->mlirContext
|
|
->getOrLoadDialect<mlir::concretelang::TFHE::TFHEDialect>();
|
|
this->mlirContext
|
|
->getOrLoadDialect<mlir::concretelang::FHELinalg::FHELinalgDialect>();
|
|
this->mlirContext
|
|
->getOrLoadDialect<mlir::concretelang::Concrete::ConcreteDialect>();
|
|
this->mlirContext
|
|
->getOrLoadDialect<mlir::concretelang::BConcrete::BConcreteDialect>();
|
|
this->mlirContext->getOrLoadDialect<mlir::StandardOpsDialect>();
|
|
this->mlirContext->getOrLoadDialect<mlir::memref::MemRefDialect>();
|
|
this->mlirContext->getOrLoadDialect<mlir::linalg::LinalgDialect>();
|
|
this->mlirContext->getOrLoadDialect<mlir::LLVM::LLVMDialect>();
|
|
this->mlirContext->getOrLoadDialect<mlir::scf::SCFDialect>();
|
|
this->mlirContext->getOrLoadDialect<mlir::omp::OpenMPDialect>();
|
|
}
|
|
|
|
return this->mlirContext;
|
|
}
|
|
|
|
// Returns the LLVM context for a compilation context. Creates and
|
|
// initializes a new LLVM context if necessary.
|
|
llvm::LLVMContext *CompilationContext::getLLVMContext() {
|
|
if (this->llvmContext == nullptr)
|
|
this->llvmContext = new llvm::LLVMContext();
|
|
|
|
return this->llvmContext;
|
|
}
|
|
|
|
// Sets the FHE constraints for the compilation. Overrides any
|
|
// automatically detected configuration and prevents the autodetection
|
|
// pass from running.
|
|
void CompilerEngine::setFHEConstraints(
|
|
const mlir::concretelang::V0FHEConstraint &c) {
|
|
this->overrideMaxEintPrecision = c.p;
|
|
this->overrideMaxMANP = c.norm2;
|
|
}
|
|
|
|
void CompilerEngine::setVerifyDiagnostics(bool v) {
|
|
this->verifyDiagnostics = v;
|
|
}
|
|
|
|
void CompilerEngine::setAutoParallelize(bool v) { this->autoParallelize = v; }
|
|
|
|
void CompilerEngine::setLoopParallelize(bool v) { this->loopParallelize = v; }
|
|
|
|
void CompilerEngine::setDataflowParallelize(bool v) {
|
|
this->dataflowParallelize = v;
|
|
}
|
|
|
|
void CompilerEngine::setGenerateClientParameters(bool v) {
|
|
this->generateClientParameters = v;
|
|
}
|
|
|
|
void CompilerEngine::setMaxEintPrecision(size_t v) {
|
|
this->overrideMaxEintPrecision = v;
|
|
}
|
|
|
|
void CompilerEngine::setMaxMANP(size_t v) { this->overrideMaxMANP = v; }
|
|
|
|
// Sets the name of the function client parameters are generated for
// (compile() falls back to "main" when none is set).
void CompilerEngine::setClientParametersFuncName(const llvm::StringRef &name) {
  std::string funcName = name.str();
  clientParametersFuncName = funcName;
}
|
|
|
|
// Sets the tile sizes used to mark FHELinalg operations for tiling; the
// sizes are copied.
void CompilerEngine::setFHELinalgTileSizes(llvm::ArrayRef<int64_t> sizes) {
  fhelinalgTileSizes = std::vector<int64_t>(sizes.begin(), sizes.end());
}
|
|
|
|
void CompilerEngine::setEnablePass(
|
|
std::function<bool(mlir::Pass *)> enablePass) {
|
|
this->enablePass = enablePass;
|
|
}
|
|
|
|
// Returns the overwritten V0FHEConstraint or try to compute them from FHE
|
|
llvm::Expected<llvm::Optional<mlir::concretelang::V0FHEConstraint>>
|
|
CompilerEngine::getV0FHEConstraint(CompilationResult &res) {
|
|
mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();
|
|
mlir::ModuleOp module = res.mlirModuleRef->get();
|
|
// If the values has been overwritten returns
|
|
if (this->overrideMaxEintPrecision.hasValue() &&
|
|
this->overrideMaxMANP.hasValue()) {
|
|
return mlir::concretelang::V0FHEConstraint{
|
|
this->overrideMaxMANP.getValue(),
|
|
this->overrideMaxEintPrecision.getValue()};
|
|
}
|
|
// Else compute constraint from FHE
|
|
llvm::Expected<llvm::Optional<mlir::concretelang::V0FHEConstraint>>
|
|
fheConstraintsOrErr =
|
|
mlir::concretelang::pipeline::getFHEConstraintsFromFHE(
|
|
mlirContext, module, enablePass);
|
|
|
|
if (auto err = fheConstraintsOrErr.takeError())
|
|
return std::move(err);
|
|
|
|
return fheConstraintsOrErr.get();
|
|
}
|
|
|
|
// set the fheContext field if the v0Constraint can be computed
|
|
llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) {
|
|
auto fheConstraintOrErr = getV0FHEConstraint(res);
|
|
if (auto err = fheConstraintOrErr.takeError())
|
|
return err;
|
|
if (!fheConstraintOrErr.get().hasValue()) {
|
|
return llvm::Error::success();
|
|
}
|
|
const mlir::concretelang::V0Parameter *fheParams =
|
|
getV0Parameter(fheConstraintOrErr.get().getValue());
|
|
|
|
if (!fheParams) {
|
|
return StreamStringError()
|
|
<< "Could not determine V0 parameters for 2-norm of "
|
|
<< (*fheConstraintOrErr)->norm2 << " and p of "
|
|
<< (*fheConstraintOrErr)->p;
|
|
}
|
|
res.fheContext.emplace(mlir::concretelang::V0FHEContext{
|
|
(*fheConstraintOrErr).getValue(), *fheParams});
|
|
|
|
return llvm::Error::success();
|
|
}
|
|
|
|
using OptionalLib = llvm::Optional<std::shared_ptr<CompilerEngine::Library>>;

// Compile the sources managed by the source manager `sm` down to `target`,
// stopping right after the stage that produces it. On success the result
// holds the MLIR module (`mlirModuleRef`) and, for LLVM IR targets and
// beyond, the LLVM module (`llvmModule`). For Target::LIBRARY, `lib` must be
// provided; the compiled object is added to it. Failures of pipeline stages
// carry the MLIR diagnostics collected while compiling.
llvm::Expected<CompilerEngine::CompilationResult>
CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
  std::string diagnosticsMsg;
  llvm::raw_string_ostream diagnosticsOS(diagnosticsMsg);
  // Error made of `prefixMsg` followed by every diagnostic emitted so far.
  auto errorDiag = [&](std::string prefixMsg)
      -> llvm::Expected<CompilerEngine::CompilationResult> {
    return StreamStringError(prefixMsg + "\n" + diagnosticsOS.str());
  };

  CompilationResult res(this->compilationContext);

  mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();

  // Collects MLIR diagnostics into `diagnosticsOS` (and verifies them when
  // verifyDiagnostics is set).
  mlir::SourceMgrDiagnosticVerifierHandler smHandler(sm, &mlirContext,
                                                     diagnosticsOS);
  mlirContext.printOpOnDiagnostic(false);

  mlir::OwningModuleRef mlirModuleRef =
      mlir::parseSourceFile<mlir::ModuleOp>(sm, &mlirContext);

  // In diagnostic-verification mode, no lowering is performed.
  if (this->verifyDiagnostics) {
    if (smHandler.verify().failed())
      return StreamStringError("Verification of diagnostics failed");
    else
      return std::move(res);
  }

  if (!mlirModuleRef) {
    return errorDiag("Could not parse source");
  }

  res.mlirModuleRef = std::move(mlirModuleRef);
  mlir::ModuleOp module = res.mlirModuleRef->get();

  if (target == Target::ROUND_TRIP)
    return std::move(res);

  // FHE High level pass to determine FHE parameters
  if (auto err = this->determineFHEParameters(res))
    return std::move(err);

  // FHELinalg tiling
  if (this->fhelinalgTileSizes) {
    if (mlir::concretelang::pipeline::markFHELinalgForTiling(
            mlirContext, module, *this->fhelinalgTileSizes, enablePass)
            .failed())
      return errorDiag("Marking of FHELinalg operations for tiling failed");
  }

  if (mlir::concretelang::pipeline::tileMarkedFHELinalg(mlirContext, module,
                                                        enablePass)
          .failed()) {
    return errorDiag("Tiling of FHELinalg operations failed");
  }

  // Dataflow parallelization
  if ((this->autoParallelize || this->dataflowParallelize) &&
      mlir::concretelang::pipeline::autopar(mlirContext, module, enablePass)
          .failed()) {
    // Reported with the collected diagnostics, like every other pass failure.
    return errorDiag("Dataflow parallelization failed");
  }

  if (target == Target::FHE)
    return std::move(res);

  // FHE -> TFHE
  if (mlir::concretelang::pipeline::lowerFHEToTFHE(mlirContext, module,
                                                   enablePass)
          .failed()) {
    return errorDiag("Lowering from FHE to TFHE failed");
  }
  if (target == Target::TFHE)
    return std::move(res);

  // TFHE -> Concrete
  if (mlir::concretelang::pipeline::lowerTFHEToConcrete(
          mlirContext, module, res.fheContext, this->enablePass)
          .failed()) {
    return errorDiag("Lowering from TFHE to Concrete failed");
  }
  if (target == Target::CONCRETE)
    return std::move(res);

  // Validate an explicit request for client parameter generation
  if (this->generateClientParameters) {
    if (!this->clientParametersFuncName.hasValue()) {
      return StreamStringError(
          "Generation of client parameters requested, but no function name "
          "specified");
    }
    if (!res.fheContext.hasValue()) {
      return StreamStringError(
          "Cannot generate client parameters, the fhe context is empty");
    }
  }

  // Generate client parameters if requested (always done for libraries)
  auto funcName = this->clientParametersFuncName.getValueOr("main");
  if (this->generateClientParameters || target == Target::LIBRARY) {
    if (!res.fheContext.hasValue()) {
      // Some tests involve call a to non encrypted functions
      ClientParameters emptyParams;
      emptyParams.functionName = funcName;
      res.clientParameters = emptyParams;
    } else {
      auto clientParametersOrErr =
          mlir::concretelang::createClientParametersForV0(*res.fheContext,
                                                          funcName, module);
      if (!clientParametersOrErr)
        return clientParametersOrErr.takeError();

      res.clientParameters = clientParametersOrErr.get();
    }
  }

  // Concrete -> BConcrete
  if (mlir::concretelang::pipeline::lowerConcreteToBConcrete(
          mlirContext, module, this->enablePass,
          this->loopParallelize || this->autoParallelize)
          .failed()) {
    // Reported with the collected diagnostics, like every other pass failure.
    return errorDiag("Lowering from Concrete to Bufferized Concrete failed");
  }

  if (target == Target::BCONCRETE) {
    return std::move(res);
  }

  // BConcrete -> Canonical dialects
  if (mlir::concretelang::pipeline::lowerBConcreteToStd(mlirContext, module,
                                                        enablePass)
          .failed()) {
    return errorDiag(
        "Lowering from Bufferized Concrete to canonical MLIR dialects failed");
  }
  if (target == Target::STD)
    return std::move(res);

  // MLIR canonical dialects -> LLVM Dialect
  if (mlir::concretelang::pipeline::lowerStdToLLVMDialect(
          mlirContext, module, enablePass,
          this->loopParallelize || this->autoParallelize)
          .failed()) {
    return errorDiag("Failed to lower to LLVM dialect");
  }

  if (target == Target::LLVM)
    return std::move(res);

  // Lowering to actual LLVM IR (i.e., not the LLVM dialect)
  llvm::LLVMContext &llvmContext = *this->compilationContext->getLLVMContext();

  res.llvmModule = mlir::concretelang::pipeline::lowerLLVMDialectToLLVMIR(
      mlirContext, llvmContext, module);

  if (!res.llvmModule)
    return StreamStringError("Failed to convert from LLVM dialect to LLVM IR");

  if (target == Target::LLVM_IR)
    return std::move(res);

  if (mlir::concretelang::pipeline::optimizeLLVMModule(llvmContext,
                                                       *res.llvmModule)
          .failed()) {
    return errorDiag("Failed to optimize LLVM IR");
  }

  if (target == Target::OPTIMIZED_LLVM_IR)
    return std::move(res);

  // Emit the object file and register it (with the client parameters) in
  // the library.
  if (target == Target::LIBRARY) {
    if (!lib) {
      return StreamStringError(
          "Internal Error: Please provide a library parameter");
    }
    auto objPath = lib.getValue()->addCompilation(res);
    if (!objPath) {
      return StreamStringError(llvm::toString(objPath.takeError()));
    }
    return std::move(res);
  }

  return std::move(res);
}
|
|
|
|
// Compile the source `s` to the target dialect `target`. If successful, the
|
|
// result can be retrieved using `getModule()` and `getLLVMModule()`,
|
|
// respectively depending on the target dialect.
|
|
llvm::Expected<CompilerEngine::CompilationResult>
|
|
CompilerEngine::compile(llvm::StringRef s, Target target, OptionalLib lib) {
|
|
std::unique_ptr<llvm::MemoryBuffer> mb = llvm::MemoryBuffer::getMemBuffer(s);
|
|
return this->compile(std::move(mb), target, lib);
|
|
}
|
|
|
|
// Compile the contained in `buffer` to the target dialect
|
|
// `target`. If successful, the result can be retrieved using
|
|
// `getModule()` and `getLLVMModule()`, respectively depending on the
|
|
// target dialect.
|
|
llvm::Expected<CompilerEngine::CompilationResult>
|
|
CompilerEngine::compile(std::unique_ptr<llvm::MemoryBuffer> buffer,
|
|
Target target, OptionalLib lib) {
|
|
llvm::SourceMgr sm;
|
|
sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
|
|
|
|
return this->compile(sm, target, lib);
|
|
}
|
|
|
|
// Compiles every source text in `inputs` into one library located at
// `libraryPath`, then emits all library artifacts. `runtimeLibraryPath` is
// forwarded to the Library, which uses it when linking the shared library.
llvm::Expected<CompilerEngine::Library>
CompilerEngine::compile(std::vector<std::string> inputs,
                        std::string libraryPath,
                        std::string runtimeLibraryPath) {
  using Library = mlir::concretelang::CompilerEngine::Library;
  auto outputLib = std::make_shared<Library>(libraryPath, runtimeLibraryPath);
  const auto target = CompilerEngine::Target::LIBRARY;

  for (const std::string &input : inputs) {
    auto compilation = compile(input, target, outputLib);
    if (!compilation)
      return StreamStringError("Can't compile: ")
             << llvm::toString(compilation.takeError());
  }

  if (auto err = outputLib->emitArtifacts())
    return StreamStringError("Can't emit artifacts: ")
           << llvm::toString(std::move(err));

  return *outputLib;
}
|
|
|
|
// Compiles the source held by `sm` into a library located at `libraryPath`,
// then emits all library artifacts. `runtimeLibraryPath` is forwarded to the
// Library, which uses it when linking the shared library.
llvm::Expected<CompilerEngine::Library>
CompilerEngine::compile(llvm::SourceMgr &sm, std::string libraryPath,
                        std::string runtimeLibraryPath) {
  using Library = mlir::concretelang::CompilerEngine::Library;
  auto outputLib = std::make_shared<Library>(libraryPath, runtimeLibraryPath);

  auto compilation = compile(sm, CompilerEngine::Target::LIBRARY, outputLib);
  if (!compilation)
    return StreamStringError("Can't compile: ")
           << llvm::toString(compilation.takeError());

  if (auto err = outputLib->emitArtifacts())
    return StreamStringError("Can't emit artifacts: ")
           << llvm::toString(std::move(err));

  return *outputLib;
}
|
|
|
|
/** Returns the path of the shared library */
|
|
std::string CompilerEngine::Library::getSharedLibraryPath(std::string path) {
|
|
return path + DOT_SHARED_LIB_EXT;
|
|
}
|
|
|
|
/** Returns the path of the static library */
|
|
std::string CompilerEngine::Library::getStaticLibraryPath(std::string path) {
|
|
return path + DOT_STATIC_LIB_EXT;
|
|
}
|
|
|
|
/** Returns the path of the static library */
|
|
std::string CompilerEngine::Library::getClientParametersPath(std::string path) {
|
|
return ClientParameters::getClientParametersPath(path);
|
|
}
|
|
|
|
// File extensions used for the library artifacts.
const std::string CompilerEngine::Library::OBJECT_EXT{".o"};
const std::string CompilerEngine::Library::CLIENT_PARAMETERS_EXT{
    ".concrete.params.json"};

// Linker used to produce the shared library.
const std::string CompilerEngine::Library::LINKER{"ld"};
#ifdef __APPLE__
// Some symbols are missing at link time but will be available at runtime
// (e.g. when JIT compiling, the JIT either links the runtime library providing
// them or has already loaded it), so the linker must allow undefined symbols
// to be resolved dynamically.
const std::string CompilerEngine::Library::LINKER_SHARED_OPT{
    " -dylib -undefined dynamic_lookup -o "};
const std::string CompilerEngine::Library::DOT_SHARED_LIB_EXT{".dylib"};
#else // Linux
const std::string CompilerEngine::Library::LINKER_SHARED_OPT{" --shared -o "};
const std::string CompilerEngine::Library::DOT_SHARED_LIB_EXT{".so"};
#endif

// Archiver used to produce the static library.
const std::string CompilerEngine::Library::AR{"ar"};
const std::string CompilerEngine::Library::AR_STATIC_OPT{" rcs "};
const std::string CompilerEngine::Library::DOT_STATIC_LIB_EXT{".a"};
|
|
|
|
// Registers an object file to be included when the library is linked or
// archived.
void CompilerEngine::Library::addExtraObjectFilePath(std::string path) {
  objectsPath.emplace_back(std::move(path));
}
|
|
|
|
// Writes the client parameters of every compilation added to this library
// as pretty-printed JSON (2-space indentation) next to the library, and
// returns the path of the written file.
llvm::Expected<std::string>
CompilerEngine::Library::emitClientParametersJSON() {
  auto clientParamsPath = getClientParametersPath(libraryPath);
  llvm::json::Value value(clientParametersList);
  std::error_code error;
  llvm::raw_fd_ostream out(clientParamsPath, error);

  if (error) {
    return StreamStringError("cannot emit client parameters, error: ")
           << error.message();
  }
  out << llvm::formatv("{0:2}", value);
  out.close();

  // Write/close failures are only recorded in the stream. Report them and
  // clear the flag; otherwise raw_fd_ostream's destructor aborts the process
  // with a fatal error.
  if (out.has_error()) {
    std::error_code writeError = out.error();
    out.clear_error();
    return StreamStringError("cannot emit client parameters, error: ")
           << writeError.message();
  }

  return clientParamsPath;
}
|
|
|
|
// Name of the client result type for a value of the given rank: rank 0 is a
// scalar, otherwise "tensor<rank>_out".
static std::string ccpResultType(size_t rank) {
  if (rank != 0)
    return "tensor" + std::to_string(rank) + "_out";
  return "scalar_out";
}
|
|
|
|
// Name of the client argument type for a value of the given rank: rank 0 is
// a scalar, otherwise "tensor<rank>_in".
static std::string ccpArgType(size_t rank) {
  if (rank != 0)
    return "tensor" + std::to_string(rank) + "_in";
  return "scalar_in";
}
|
|
|
|
static std::string cppArgsType(std::vector<CircuitGate> inputs) {
|
|
std::string args;
|
|
for (auto input : inputs) {
|
|
if (!args.empty()) {
|
|
args += ", ";
|
|
}
|
|
args += ccpArgType(input.shape.dimensions.size());
|
|
}
|
|
return args;
|
|
}
|
|
|
|
// Writes "<libraryName>-client.h" next to the library. For each compiled
// function it declares a `<name>_t` TypedClientLambda alias and a `load`
// helper, inside namespace `<libraryName>::client::<name>`. Returns the path
// of the written header.
llvm::Expected<std::string> CompilerEngine::Library::emitCppHeader() {
  auto libraryName = llvm::sys::path::filename(libraryPath).str();
  auto headerName = libraryName + "-client.h";

  // Replace the trailing library name with the header name. This is a plain
  // suffix replacement: a regex built from `libraryName` would mis-match or
  // throw std::regex_error on names containing '.', '+', '(' ...
  std::string headerPath = libraryPath;
  if (headerPath.size() >= libraryName.size() &&
      headerPath.compare(headerPath.size() - libraryName.size(),
                         libraryName.size(), libraryName) == 0) {
    headerPath.replace(headerPath.size() - libraryName.size(),
                       libraryName.size(), headerName);
  }

  std::error_code error;
  llvm::raw_fd_ostream out(headerPath, error);
  if (error) {
    // Previously the error was constructed but never returned.
    return StreamStringError("Cannot emit header: ")
           << headerPath << ", " << error.message() << "\n";
  }

  out << "#include \"boost/outcome.h\"\n";
  out << "#include \"concretelang/ClientLib/ClientLambda.h\"\n";
  out << "#include \"concretelang/ClientLib/KeySetCache.h\"\n";
  out << "#include \"concretelang/ClientLib/Types.h\"\n";
  out << "#include \"concretelang/Common/Error.h\"\n";
  out << "\n";
  out << "namespace " << libraryName << " {\n";
  out << "namespace client {\n";

  for (const auto &params : clientParametersList) {
    // The argument list depends on the inputs (the original tested `outputs`
    // here, a copy-paste slip); "void" stands for an empty list.
    std::string args = params.inputs.empty() ? std::string("void")
                                             : cppArgsType(params.inputs);
    // Only the first output determines the result type.
    std::string result =
        params.outputs.empty()
            ? std::string("void")
            : ccpResultType(params.outputs[0].shape.dimensions.size());

    out << "\n";
    out << "namespace " << params.functionName << " {\n";
    out << "  using namespace concretelang::clientlib;\n";
    out << "  using concretelang::error::StringError;\n";
    out << "  using " << params.functionName << "_t = TypedClientLambda<"
        << result << ", " << args << ">;\n";
    out << "  static const std::string name = \"" << params.functionName
        << "\";\n";
    out << "\n";
    out << "  static outcome::checked<" << params.functionName
        << "_t, StringError>\n";
    out << "  load(std::string outputLib)\n";
    out << "  { return " << params.functionName
        << "_t::load(name, outputLib); }\n";
    out << "} // namespace " << params.functionName << "\n";
  }
  out << "\n";
  out << "} // namespace client\n";
  out << "} // namespace " << libraryName << "\n";

  out.close();

  // Report write/close failures instead of letting raw_fd_ostream's
  // destructor abort on the pending error.
  if (out.has_error()) {
    std::error_code writeError = out.error();
    out.clear_error();
    return StreamStringError("Cannot emit header: ")
           << headerPath << ", " << writeError.message() << "\n";
  }

  return headerPath;
}
|
|
|
|
// Emits the LLVM module of `compilation` as an object file and registers it,
// together with its client parameters (if any), in this library. Returns the
// object file path. The object is named after the module's source file name;
// for an anonymous module, the name is derived from the library path and the
// number of objects already registered.
llvm::Expected<std::string>
CompilerEngine::Library::addCompilation(CompilationResult &compilation) {
  llvm::Module *module = compilation.llvmModule.get();
  // Results of targets before LLVM IR have no LLVM module; dereferencing it
  // would crash.
  if (module == nullptr) {
    return StreamStringError(
        "Cannot add compilation to library: no LLVM module was produced");
  }

  auto sourceName = module->getSourceFileName();
  if (sourceName.empty() || sourceName == "LLVMDialectModule") {
    sourceName = this->libraryPath + ".module-" +
                 std::to_string(objectsPath.size()) + ".mlir";
  }
  auto objectPath = sourceName + OBJECT_EXT;
  if (auto error = mlir::concretelang::emitObject(*module, objectPath)) {
    return std::move(error);
  }

  addExtraObjectFilePath(objectPath);
  if (compilation.clientParameters.hasValue()) {
    clientParametersList.push_back(compilation.clientParameters.getValue());
  }
  return objectPath;
}
|
|
|
|
// Returns true when `path` ends with `requiredExt` (always true for an empty
// suffix). A path shorter than the suffix simply does not match. The former
// `path.size() - requiredExt.size()` underflowed there and made substr throw
// std::out_of_range.
bool stringEndsWith(std::string path, std::string requiredExt) {
  return path.size() >= requiredExt.size() &&
         path.compare(path.size() - requiredExt.size(), requiredExt.size(),
                      requiredExt) == 0;
}
|
|
|
|
std::string removeDotExt(std::string path, std::string dotExt) {
|
|
return (stringEndsWith(path, dotExt))
|
|
? path.substr(0, path.size() - dotExt.size())
|
|
: path;
|
|
}
|
|
|
|
// Returns `path` with its library extension replaced by `dotExt`. A trailing
// static extension is stripped first, then a trailing shared one, and
// `dotExt` is appended.
std::string ensureLibDotExt(std::string path, std::string dotExt) {
  std::string stem = removeDotExt(
      removeDotExt(path, CompilerEngine::Library::DOT_STATIC_LIB_EXT),
      CompilerEngine::Library::DOT_SHARED_LIB_EXT);
  return stem + dotExt;
}
|
|
|
|
// Links or archives all registered object files with `linker` (plus optional
// `extraArgs`) into the library path carrying extension `dotExt`, and
// returns that path.
llvm::Expected<std::string> CompilerEngine::Library::emit(
    std::string dotExt, std::string linker,
    llvm::Optional<std::vector<std::string>> extraArgs) {
  std::string outputPath = ensureLibDotExt(libraryPath, dotExt);
  if (auto err = mlir::concretelang::emitLibrary(objectsPath, outputPath,
                                                 linker, extraArgs))
    return std::move(err);
  return outputPath;
}
|
|
|
|
// Links the registered object files into the shared library and records its
// path in `sharedLibraryPath`. When a runtime library path is set, the
// runtime library is linked in and its directory is added to the rpath.
llvm::Expected<std::string> CompilerEngine::Library::emitShared() {
  std::vector<std::string> extraArgs;
  // File name of the runtime library (e.g. "libFoo.dylib"); only used on
  // macOS for the post-link dependency fix-up.
  std::string fullRuntimeLibraryName = "";
#ifdef __APPLE__
  // Whether the install_name_tool fix-up below must run.
  bool fixRuntimeDep = false;
#endif
  if (!runtimeLibraryPath.empty()) {
    std::string rpath = "";
    std::string runtimeLibraryName = "";
    // Split "<dir>/<file>" at the last '/' (valid on Linux and macOS).
    const std::size_t lastSlash = runtimeLibraryPath.find_last_of("/");
    if (lastSlash != std::string::npos) {
      rpath = runtimeLibraryPath.substr(0, lastSlash);
      fullRuntimeLibraryName = runtimeLibraryPath.substr(lastSlash + 1);
      // Derive the `-l` name: "lib<name><shared ext>" -> "<name>".
      runtimeLibraryName =
          removeDotExt(fullRuntimeLibraryName, DOT_SHARED_LIB_EXT);
      const bool hasLibPrefix = runtimeLibraryName.rfind("lib", 0) == 0;
      if (hasLibPrefix)
        runtimeLibraryName.erase(0, 3);
    }
#ifdef __APPLE__
    if (!rpath.empty() && !runtimeLibraryName.empty()) {
      fixRuntimeDep = true;
      extraArgs.push_back("-l" + runtimeLibraryName);
      extraArgs.push_back("-L" + rpath);
      extraArgs.push_back("-rpath " + rpath);
    }
#else // Linux
    // Link the runtime library by its full path.
    extraArgs.push_back(runtimeLibraryPath);
    if (!rpath.empty()) {
      extraArgs.push_back("-rpath=" + rpath);
      // Use RPATH instead of RUNPATH for transitive dependencies
      extraArgs.push_back("--disable-new-dtags");
    }
#endif
  }

  auto path = emit(DOT_SHARED_LIB_EXT, LINKER + LINKER_SHARED_OPT, extraArgs);
  if (!path)
    return path;

  sharedLibraryPath = path.get();
#ifdef __APPLE__
  // When delocate bundles dependencies into Python wheels, the runtime
  // library gets an install id prefixed with /DLC, a path that doesn't exist,
  // so generated libraries can't find it at load time. Rewrite that
  // dependency to be relative to the rpath set during linking. This has no
  // effect when /DLC/concrete/.dylibs/* is not a dependency in the first
  // place (e.g. when not using Python).
  if (fixRuntimeDep) {
    std::string fixRuntimeDepCmd = "install_name_tool -change ";
    fixRuntimeDepCmd += "/DLC/concrete/.dylibs/" + fullRuntimeLibraryName;
    fixRuntimeDepCmd += " @rpath/" + fullRuntimeLibraryName;
    fixRuntimeDepCmd += " " + sharedLibraryPath;
    if (auto error = mlir::concretelang::callCmd(fixRuntimeDepCmd))
      return std::move(error);
  }
#endif
  return path;
}
|
|
|
|
// Archives the registered object files into the static library and, on
// success, records its path in `staticLibraryPath`.
llvm::Expected<std::string> CompilerEngine::Library::emitStatic() {
  auto archive = emit(DOT_STATIC_LIB_EXT, AR + AR_STATIC_OPT);
  if (!archive)
    return archive;
  staticLibraryPath = *archive;
  return archive;
}
|
|
|
|
// Emits every library artifact in order: shared library, static library,
// client parameters JSON, then the C++ client header. Stops at and returns
// the first error.
llvm::Error CompilerEngine::Library::emitArtifacts() {
  auto sharedLib = emitShared();
  if (!sharedLib)
    return sharedLib.takeError();

  auto staticLib = emitStatic();
  if (!staticLib)
    return staticLib.takeError();

  auto paramsJson = emitClientParametersJSON();
  if (!paramsJson)
    return paramsJson.takeError();

  auto header = emitCppHeader();
  if (!header)
    return header.takeError();

  return llvm::Error::success();
}
|
|
|
|
// When `cleanUp` is set, deletes the intermediate object files registered in
// this library. Removal failures are ignored.
CompilerEngine::Library::~Library() {
  if (!cleanUp)
    return;
  for (const std::string &objectPath : objectsPath)
    remove(objectPath.c_str());
}
|
|
|
|
} // namespace concretelang
|
|
} // namespace mlir
|