mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
This refactoring commit restructures the compilation pipeline of
`zamacompiler` so that it is possible to enter and exit the
pipeline at different points, effectively defining the level of
abstraction of the input and the required level of abstraction of the
output.
The entry point is specified using the `--entry-dialect`
argument. Valid choices are:
`--entry-dialect=hlfhe`: Source contains HLFHE operations
`--entry-dialect=midlfhe`: Source contains MidLFHE operations
`--entry-dialect=lowlfhe`: Source contains LowLFHE operations
`--entry-dialect=std`: Source does not contain any FHE operations
`--entry-dialect=llvm`: Source is in LLVM dialect
The exit point is defined by an action, specified using `--action`.
`--action=roundtrip`:
Parse the source file into an in-memory representation and immediately
dump it as text without any processing
`--action=dump-midlfhe`:
Lower source to MidLFHE and dump result as text
`--action=dump-lowlfhe`:
Lower source to LowLFHE and dump result as text
`--action=dump-std`:
Lower source to standard MLIR dialects only (i.e., all FHE
operations are lowered away)
`--action=dump-llvm-dialect`:
Lower source to MLIR's LLVM dialect (i.e., the LLVM dialect, not
LLVM IR)
`--action=dump-llvm-ir`:
Lower source to plain LLVM IR (i.e., not the LLVM dialect, but
actual LLVM IR)
`--action=dump-optimized-llvm-ir`:
Lower source to plain LLVM IR (i.e., not the LLVM dialect, but
actual LLVM IR), pass the result through the LLVM optimizer and
print the result.
`--action=dump-jit-invoke`:
Execute the full lowering pipeline to optimized LLVM IR, JIT
compile the result, invoke the function specified in
`--jit-funcname` with the parameters from `--jit-args` and print
the function's return value.
111 lines
4.6 KiB
C++
111 lines
4.6 KiB
C++
#include "CompilerAPIModule.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
#include "zamalang/Support/CompilerEngine.h"

#include <llvm/Support/Error.h>
#include <llvm/Support/raw_ostream.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Parser.h>

#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>

#include <stdexcept>
#include <string>
#include <vector>
|
|
|
|
using namespace zamalang;
|
|
using mlir::zamalang::CompilerEngine;
|
|
using zamalang::python::ExecutionArgument;
|
|
|
|
/// Populate the compiler API python module.
|
|
void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
|
|
m.doc() = "Zamalang compiler python API";
|
|
|
|
m.def("round_trip", [](std::string mlir_input) {
|
|
mlir::MLIRContext context;
|
|
context.getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
|
|
context.getOrLoadDialect<mlir::StandardOpsDialect>();
|
|
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
|
|
auto module_ref = mlir::parseSourceString(mlir_input, &context);
|
|
if (!module_ref) {
|
|
throw std::logic_error("mlir parsing failed");
|
|
}
|
|
std::string result;
|
|
llvm::raw_string_ostream os(result);
|
|
module_ref->print(os);
|
|
return os.str();
|
|
});
|
|
|
|
pybind11::class_<ExecutionArgument, std::shared_ptr<ExecutionArgument>>(
|
|
m, "ExecutionArgument")
|
|
.def("create",
|
|
pybind11::overload_cast<uint64_t>(&ExecutionArgument::create))
|
|
.def("create", pybind11::overload_cast<std::vector<uint8_t>>(
|
|
&ExecutionArgument::create))
|
|
.def("is_tensor", &ExecutionArgument::isTensor)
|
|
.def("is_int", &ExecutionArgument::isInt);
|
|
|
|
pybind11::class_<CompilerEngine>(m, "CompilerEngine")
|
|
.def(pybind11::init())
|
|
.def(
|
|
"run",
|
|
[](CompilerEngine &engine, std::vector<ExecutionArgument> args) {
|
|
auto maybeArgument = engine.buildArgument();
|
|
if (auto err = maybeArgument.takeError()) {
|
|
llvm::errs() << "Execution failed: " << err << "\n";
|
|
throw std::runtime_error(
|
|
"failed building arguments, see previous logs for more info");
|
|
}
|
|
// Set the integer/tensor arguments
|
|
auto arguments = std::move(maybeArgument.get());
|
|
for (auto i = 0; i < args.size(); i++) {
|
|
if (args[i].isInt()) { // integer argument
|
|
if (auto err =
|
|
arguments->setArg(i, args[i].getIntegerArgument())) {
|
|
llvm::errs() << "Execution failed: " << err << "\n";
|
|
throw std::runtime_error(
|
|
"failed pushing integer argument, see "
|
|
"previous logs for more info");
|
|
}
|
|
} else { // tensor argument
|
|
assert(args[i].isTensor() && "should be tensor argument");
|
|
if (auto err = arguments->setArg(i, args[i].getTensorArgument(),
|
|
args[i].getTensorSize())) {
|
|
llvm::errs() << "Execution failed: " << err << "\n";
|
|
throw std::runtime_error(
|
|
"failed pushing tensor argument, see "
|
|
"previous logs for more info");
|
|
}
|
|
}
|
|
}
|
|
// Invoke the lambda
|
|
if (auto err = engine.invoke(*arguments)) {
|
|
llvm::errs() << "Execution failed: " << err << "\n";
|
|
throw std::runtime_error(
|
|
"failed running, see previous logs for more info");
|
|
}
|
|
uint64_t result = 0;
|
|
if (auto err = arguments->getResult(0, result)) {
|
|
llvm::errs() << "Execution failed: " << err << "\n";
|
|
throw std::runtime_error(
|
|
"failed getting result, see previous logs for more info");
|
|
}
|
|
return result;
|
|
})
|
|
.def("compile_fhe",
|
|
[](CompilerEngine &engine, std::string mlir_input) {
|
|
auto error = engine.compile(mlir_input);
|
|
if (error) {
|
|
llvm::errs() << "Compilation failed: " << error << "\n";
|
|
throw std::runtime_error(
|
|
"failed compiling, see previous logs for more info");
|
|
}
|
|
})
|
|
.def("get_compiled_module", &CompilerEngine::getCompiledModule);
|
|
}
|