mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat(execution): run on both int and tensor args
This commit is contained in:
@@ -21,6 +21,7 @@
|
||||
|
||||
using namespace zamalang;
|
||||
using mlir::zamalang::CompilerEngine;
|
||||
using zamalang::python::ExecutionArgument;
|
||||
|
||||
/// Populate the compiler API python module.
|
||||
void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
|
||||
@@ -41,19 +42,62 @@ void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
|
||||
return os.str();
|
||||
});
|
||||
|
||||
pybind11::class_<ExecutionArgument, std::shared_ptr<ExecutionArgument>>(
|
||||
m, "ExecutionArgument")
|
||||
.def("create",
|
||||
pybind11::overload_cast<uint64_t>(&ExecutionArgument::create))
|
||||
.def("create", pybind11::overload_cast<std::vector<uint8_t>>(
|
||||
&ExecutionArgument::create))
|
||||
.def("is_tensor", &ExecutionArgument::isTensor)
|
||||
.def("is_int", &ExecutionArgument::isInt);
|
||||
|
||||
pybind11::class_<CompilerEngine>(m, "CompilerEngine")
|
||||
.def(pybind11::init())
|
||||
.def("run",
|
||||
[](CompilerEngine &engine, std::vector<uint64_t> args) {
|
||||
auto result = engine.run(args);
|
||||
if (!result) { // not an error
|
||||
llvm::errs()
|
||||
<< "Execution failed: " << result.takeError() << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed running, see previous logs for more info");
|
||||
}
|
||||
return result.get();
|
||||
})
|
||||
.def(
|
||||
"run",
|
||||
[](CompilerEngine &engine, std::vector<ExecutionArgument> args) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
if (auto err = maybeArgument.takeError()) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed building arguments, see previous logs for more info");
|
||||
}
|
||||
// Set the integer/tensor arguments
|
||||
auto arguments = std::move(maybeArgument.get());
|
||||
for (auto i = 0; i < args.size(); i++) {
|
||||
if (args[i].isInt()) { // integer argument
|
||||
if (auto err =
|
||||
arguments->setArg(i, args[i].getIntegerArgument())) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed pushing integer argument, see "
|
||||
"previous logs for more info");
|
||||
}
|
||||
} else { // tensor argument
|
||||
assert(args[i].isTensor() && "should be tensor argument");
|
||||
if (auto err = arguments->setArg(i, args[i].getTensorArgument(),
|
||||
args[i].getTensorSize())) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed pushing tensor argument, see "
|
||||
"previous logs for more info");
|
||||
}
|
||||
}
|
||||
}
|
||||
// Invoke the lambda
|
||||
if (auto err = engine.invoke(*arguments)) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed running, see previous logs for more info");
|
||||
}
|
||||
uint64_t result = 0;
|
||||
if (auto err = arguments->getResult(0, result)) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed getting result, see previous logs for more info");
|
||||
}
|
||||
return result;
|
||||
})
|
||||
.def("compile_fhe",
|
||||
[](CompilerEngine &engine, std::string mlir_input) {
|
||||
auto error = engine.compile(mlir_input);
|
||||
|
||||
Reference in New Issue
Block a user