feat(rust): compile mlir into library

CAPI covering a wider API of the Support library.
Better error handling. Could also be improved by returning an error
message back from C to rust (left TODO).
This commit is contained in:
youben11
2022-11-24 10:42:18 +01:00
committed by Ayoub Benaissa
parent 8a557368f1
commit b00115f4ae
5 changed files with 344 additions and 20 deletions

View File

@@ -6,9 +6,58 @@
#include "concretelang-c/Support/CompilerEngine.h"
#include "concretelang/CAPI/Wrappers.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Error.h"
#include "mlir/IR/Diagnostics.h"
#include "llvm/Support/SourceMgr.h"
/// CompilerEngine CAPI
/// ********** CompilationOptions CAPI *****************************************
CompilationOptions
compilationOptionsCreate(MlirStringRef funcName, bool autoParallelize,
bool batchConcreteOps, bool dataflowParallelize,
bool emitGPUOps, bool loopParallelize,
bool optimizeConcrete, OptimizerConfig optimizerConfig,
bool verifyDiagnostics) {
std::string funcNameStr(funcName.data, funcName.length);
auto options = new mlir::concretelang::CompilationOptions(funcNameStr);
options->autoParallelize = autoParallelize;
options->batchConcreteOps = batchConcreteOps;
options->dataflowParallelize = dataflowParallelize;
options->emitGPUOps = emitGPUOps;
options->loopParallelize = loopParallelize;
options->optimizeConcrete = optimizeConcrete;
options->optimizerConfig = *unwrap(optimizerConfig);
options->verifyDiagnostics = verifyDiagnostics;
return wrap(options);
}
CompilationOptions compilationOptionsCreateDefault() {
return wrap(new mlir::concretelang::CompilationOptions("main"));
}
/// ********** OptimizerConfig CAPI ********************************************
OptimizerConfig optimizerConfigCreate(bool display,
double fallback_log_norm_woppbs,
double global_p_error, double p_error,
uint64_t security, bool strategy_v0,
bool use_gpu_constraints) {
auto config = new mlir::concretelang::optimizer::Config();
config->display = display;
config->fallback_log_norm_woppbs = fallback_log_norm_woppbs;
config->global_p_error = global_p_error;
config->p_error = p_error;
config->security = security;
config->strategy_v0 = strategy_v0;
config->use_gpu_constraints = use_gpu_constraints;
return wrap(config);
}
OptimizerConfig optimizerConfigCreateDefault() {
return wrap(new mlir::concretelang::optimizer::Config());
}
/// ********** CompilerEngine CAPI *********************************************
CompilerEngine compilerEngineCreate() {
auto *engine = new mlir::concretelang::CompilerEngine(
@@ -18,30 +67,61 @@ CompilerEngine compilerEngineCreate() {
void compilerEngineDestroy(CompilerEngine engine) { delete unwrap(engine); }
/// Map C compilationTarget to Cpp
llvm::Expected<mlir::concretelang::CompilerEngine::Target>
targetConvertToCppFromC(CompilationTarget target) {
switch (target) {
case ROUND_TRIP:
return mlir::concretelang::CompilerEngine::Target::ROUND_TRIP;
case FHE:
return mlir::concretelang::CompilerEngine::Target::FHE;
case TFHE:
return mlir::concretelang::CompilerEngine::Target::TFHE;
case CONCRETE:
return mlir::concretelang::CompilerEngine::Target::CONCRETE;
case CONCRETEWITHLOOPS:
return mlir::concretelang::CompilerEngine::Target::CONCRETEWITHLOOPS;
case BCONCRETE:
return mlir::concretelang::CompilerEngine::Target::BCONCRETE;
case STD:
return mlir::concretelang::CompilerEngine::Target::STD;
case LLVM:
return mlir::concretelang::CompilerEngine::Target::LLVM;
case LLVM_IR:
return mlir::concretelang::CompilerEngine::Target::LLVM_IR;
case OPTIMIZED_LLVM_IR:
return mlir::concretelang::CompilerEngine::Target::OPTIMIZED_LLVM_IR;
case LIBRARY:
return mlir::concretelang::CompilerEngine::Target::LIBRARY;
}
return mlir::concretelang::StreamStringError("invalid compilation target");
}
CompilationResult compilerEngineCompile(CompilerEngine engine,
MlirStringRef module,
CompilationTarget target) {
std::string module_str(module.data, module.length);
if (target == ROUND_TRIP) {
auto retOrError = unwrap(engine)->compile(
module_str, mlir::concretelang::CompilerEngine::Target::ROUND_TRIP);
if (!retOrError) {
// TODO: access the MlirContext
// mlir::emitError(mlir::UnknownLoc::get(unwrap(engine)) << "azeza";
return wrap(
(mlir::concretelang::CompilerEngine::CompilationResult *)nullptr);
}
return wrap(new mlir::concretelang::CompilerEngine::CompilationResult(
std::move(retOrError.get())));
auto targetCppOrError = targetConvertToCppFromC(target);
if (!targetCppOrError) { // invalid target
llvm::errs() << llvm::toString(targetCppOrError.takeError());
return wrap((mlir::concretelang::CompilerEngine::CompilationResult *)NULL);
}
return wrap((mlir::concretelang::CompilerEngine::CompilationResult *)nullptr);
auto retOrError = unwrap(engine)->compile(module_str, targetCppOrError.get());
if (!retOrError) { // compilation error
llvm::errs() << llvm::toString(retOrError.takeError());
return wrap((mlir::concretelang::CompilerEngine::CompilationResult *)NULL);
}
return wrap(new mlir::concretelang::CompilerEngine::CompilationResult(
std::move(retOrError.get())));
}
/// CompilationResult CAPI
void compilationResultDestroy(CompilationResult result) {
delete unwrap(result);
void compilerEngineCompileSetOptions(CompilerEngine engine,
CompilationOptions options) {
unwrap(engine)->setCompilationOptions(*unwrap(options));
}
/// ********** CompilationResult CAPI ******************************************
MlirStringRef compilationResultGetModuleString(CompilationResult result) {
// print the module into a string
std::string moduleString;
@@ -56,3 +136,55 @@ MlirStringRef compilationResultGetModuleString(CompilationResult result) {
void compilationResultDestroyModuleString(MlirStringRef str) {
delete str.data;
}
void compilationResultDestroy(CompilationResult result) {
delete unwrap(result);
}
/// ********** Library CAPI ****************************************************
Library libraryCreate(MlirStringRef outputDirPath,
MlirStringRef runtimeLibraryPath, bool cleanUp) {
std::string outputDirPathStr(outputDirPath.data, outputDirPath.length);
std::string runtimeLibraryPathStr(runtimeLibraryPath.data,
runtimeLibraryPath.length);
return wrap(new mlir::concretelang::CompilerEngine::Library(
outputDirPathStr, runtimeLibraryPathStr, cleanUp));
}
void libraryDestroy(Library lib) { delete unwrap(lib); }
/// ********** LibraryCompilationResult CAPI ***********************************
void libraryCompilationResultDestroy(LibraryCompilationResult result) {
delete unwrap(result);
}
/// ********** LibrarySupport CAPI *********************************************
LibrarySupport
librarySupportCreate(MlirStringRef outputDirPath,
MlirStringRef runtimeLibraryPath, bool generateSharedLib,
bool generateStaticLib, bool generateClientParameters,
bool generateCompilationFeedback, bool generateCppHeader) {
std::string outputDirPathStr(outputDirPath.data, outputDirPath.length);
std::string runtimeLibraryPathStr(runtimeLibraryPath.data,
runtimeLibraryPath.length);
return wrap(new mlir::concretelang::LibrarySupport(
outputDirPathStr, runtimeLibraryPathStr, generateSharedLib,
generateStaticLib, generateClientParameters, generateCompilationFeedback,
generateCppHeader));
}
LibraryCompilationResult librarySupportCompile(LibrarySupport support,
MlirStringRef module,
CompilationOptions options) {
std::string moduleStr(module.data, module.length);
auto retOrError = unwrap(support)->compile(moduleStr, *unwrap(options));
if (!retOrError) {
llvm::errs() << llvm::toString(retOrError.takeError());
return wrap((mlir::concretelang::LibraryCompilationResult *)NULL);
}
return wrap(new mlir::concretelang::LibraryCompilationResult(
*retOrError.get().get()));
}