feat(execution): run on both int and tensor args

This commit is contained in:
youben11
2021-09-09 16:02:17 +01:00
committed by Ayoub Benaissa
parent c37ac41c1a
commit 967fda07a0
4 changed files with 138 additions and 17 deletions

View File

@@ -21,6 +21,7 @@
using namespace zamalang;
using mlir::zamalang::CompilerEngine;
using zamalang::python::ExecutionArgument;
/// Populate the compiler API python module.
void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
@@ -41,19 +42,62 @@ void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
return os.str();
});
pybind11::class_<ExecutionArgument, std::shared_ptr<ExecutionArgument>>(
m, "ExecutionArgument")
.def("create",
pybind11::overload_cast<uint64_t>(&ExecutionArgument::create))
.def("create", pybind11::overload_cast<std::vector<uint8_t>>(
&ExecutionArgument::create))
.def("is_tensor", &ExecutionArgument::isTensor)
.def("is_int", &ExecutionArgument::isInt);
pybind11::class_<CompilerEngine>(m, "CompilerEngine")
.def(pybind11::init())
.def("run",
[](CompilerEngine &engine, std::vector<uint64_t> args) {
auto result = engine.run(args);
if (!result) { // not an error
llvm::errs()
<< "Execution failed: " << result.takeError() << "\n";
throw std::runtime_error(
"failed running, see previous logs for more info");
}
return result.get();
})
.def(
"run",
[](CompilerEngine &engine, std::vector<ExecutionArgument> args) {
auto maybeArgument = engine.buildArgument();
if (auto err = maybeArgument.takeError()) {
llvm::errs() << "Execution failed: " << err << "\n";
throw std::runtime_error(
"failed building arguments, see previous logs for more info");
}
// Set the integer/tensor arguments
auto arguments = std::move(maybeArgument.get());
for (auto i = 0; i < args.size(); i++) {
if (args[i].isInt()) { // integer argument
if (auto err =
arguments->setArg(i, args[i].getIntegerArgument())) {
llvm::errs() << "Execution failed: " << err << "\n";
throw std::runtime_error(
"failed pushing integer argument, see "
"previous logs for more info");
}
} else { // tensor argument
assert(args[i].isTensor() && "should be tensor argument");
if (auto err = arguments->setArg(i, args[i].getTensorArgument(),
args[i].getTensorSize())) {
llvm::errs() << "Execution failed: " << err << "\n";
throw std::runtime_error(
"failed pushing tensor argument, see "
"previous logs for more info");
}
}
}
// Invoke the lambda
if (auto err = engine.invoke(*arguments)) {
llvm::errs() << "Execution failed: " << err << "\n";
throw std::runtime_error(
"failed running, see previous logs for more info");
}
uint64_t result = 0;
if (auto err = arguments->getResult(0, result)) {
llvm::errs() << "Execution failed: " << err << "\n";
throw std::runtime_error(
"failed getting result, see previous logs for more info");
}
return result;
})
.def("compile_fhe",
[](CompilerEngine &engine, std::string mlir_input) {
auto error = engine.compile(mlir_input);