feat(execution): run on both int and tensor args

This commit is contained in:
youben11
2021-09-09 16:02:17 +01:00
committed by Ayoub Benaissa
parent c37ac41c1a
commit 967fda07a0
4 changed files with 138 additions and 17 deletions

View File

@@ -21,6 +21,7 @@
using namespace zamalang;
using mlir::zamalang::CompilerEngine;
using zamalang::python::ExecutionArgument;
/// Populate the compiler API python module.
void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
@@ -41,19 +42,62 @@ void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
return os.str();
});
pybind11::class_<ExecutionArgument, std::shared_ptr<ExecutionArgument>>(
m, "ExecutionArgument")
.def("create",
pybind11::overload_cast<uint64_t>(&ExecutionArgument::create))
.def("create", pybind11::overload_cast<std::vector<uint8_t>>(
&ExecutionArgument::create))
.def("is_tensor", &ExecutionArgument::isTensor)
.def("is_int", &ExecutionArgument::isInt);
pybind11::class_<CompilerEngine>(m, "CompilerEngine")
.def(pybind11::init())
.def("run",
[](CompilerEngine &engine, std::vector<uint64_t> args) {
auto result = engine.run(args);
if (!result) { // not an error
llvm::errs()
<< "Execution failed: " << result.takeError() << "\n";
throw std::runtime_error(
"failed running, see previous logs for more info");
}
return result.get();
})
.def(
"run",
[](CompilerEngine &engine, std::vector<ExecutionArgument> args) {
auto maybeArgument = engine.buildArgument();
if (auto err = maybeArgument.takeError()) {
llvm::errs() << "Execution failed: " << err << "\n";
throw std::runtime_error(
"failed building arguments, see previous logs for more info");
}
// Set the integer/tensor arguments
auto arguments = std::move(maybeArgument.get());
for (auto i = 0; i < args.size(); i++) {
if (args[i].isInt()) { // integer argument
if (auto err =
arguments->setArg(i, args[i].getIntegerArgument())) {
llvm::errs() << "Execution failed: " << err << "\n";
throw std::runtime_error(
"failed pushing integer argument, see "
"previous logs for more info");
}
} else { // tensor argument
assert(args[i].isTensor() && "should be tensor argument");
if (auto err = arguments->setArg(i, args[i].getTensorArgument(),
args[i].getTensorSize())) {
llvm::errs() << "Execution failed: " << err << "\n";
throw std::runtime_error(
"failed pushing tensor argument, see "
"previous logs for more info");
}
}
}
// Invoke the lambda
if (auto err = engine.invoke(*arguments)) {
llvm::errs() << "Execution failed: " << err << "\n";
throw std::runtime_error(
"failed running, see previous logs for more info");
}
uint64_t result = 0;
if (auto err = arguments->getResult(0, result)) {
llvm::errs() << "Execution failed: " << err << "\n";
throw std::runtime_error(
"failed getting result, see previous logs for more info");
}
return result;
})
.def("compile_fhe",
[](CompilerEngine &engine, std::string mlir_input) {
auto error = engine.compile(mlir_input);

View File

@@ -6,6 +6,42 @@
namespace zamalang {
namespace python {
// Frontend object to abstract the different types of possible arguments,
// namely, integers, and tensors.
class ExecutionArgument {
public:
// There are two possible underlying types for the execution argument, either
// and int, or a tensor
bool isTensor() { return isTensorArg; }
bool isInt() { return !isTensorArg; }
uint8_t *getTensorArgument() { return tensorArg.data(); }
size_t getTensorSize() { return tensorArg.size(); }
uint64_t getIntegerArgument() { return intArg; }
// Create an execution argument from an integer
static std::shared_ptr<ExecutionArgument> create(uint64_t arg) {
return std::shared_ptr<ExecutionArgument>(new ExecutionArgument(arg));
}
// Create an execution argument from a tensor
static std::shared_ptr<ExecutionArgument> create(std::vector<uint8_t> arg) {
return std::shared_ptr<ExecutionArgument>(new ExecutionArgument(arg));
}
private:
ExecutionArgument(int arg)
: isTensorArg(false), intArg(arg) {}
ExecutionArgument(std::vector<uint8_t> tensor)
: isTensorArg(true), tensorArg(tensor) {}
uint64_t intArg;
std::vector<uint8_t> tensorArg;
bool isTensorArg;
};
void populateCompilerAPISubmodule(pybind11::module &m);
} // namespace python

View File

@@ -1,6 +1,7 @@
"""Compiler submodule"""
from typing import List
from typing import List, Union
from _zamalang._compiler import CompilerEngine as _CompilerEngine
from _zamalang._compiler import ExecutionArgument as _ExecutionArgument
from _zamalang._compiler import round_trip as _round_trip
@@ -20,6 +21,32 @@ def round_trip(mlir_str: str) -> str:
raise TypeError("input must be an `str`")
return _round_trip(mlir_str)
def create_execution_argument(value: Union[int, List[int]]) -> "_ExecutionArgument":
"""Create an execution argument holding either an int or tensor value.
Args:
value (Union[int, List[int]]): value of the argument, either an int, or a list of int
Raises:
TypeError: if the values aren't in the expected range, or using a wrong type
Returns:
_ExecutionArgument: execution argument holding the appropriate value
"""
if not isinstance(value, (int, list)):
raise TypeError("value of execution argument must be either int or list[int]")
if isinstance(value, int):
if not (0 <= value < (2 ** 64 - 1)):
raise TypeError("single integer must be in the range [0, 2**64 - 1] (uint64)")
else:
assert isinstance(value, list)
for elem in value:
if not (0 <= elem < (2 ** 8 - 1)):
raise TypeError("values of the list must be in the range [0, 255] (uint8)")
return _ExecutionArgument.create(value)
class CompilerEngine:
def __init__(self, mlir_str: str = None):
self._engine = _CompilerEngine()
@@ -42,18 +69,20 @@ class CompilerEngine:
raise TypeError("input must be an `str`")
return self._engine.compile_fhe(mlir_str)
def run(self, *args: List[int]) -> int:
def run(self, *args: List[Union[int, List[int]]]) -> int:
"""Run the compiled code.
Args:
*args: list of arguments for execution. Each argument can be an int, or a list of int
Raises:
TypeError: if arguments aren't of type int
TypeError: if execution arguments can't be constructed
Returns:
int: result of execution.
"""
if not all(isinstance(arg, int) for arg in args):
raise TypeError("arguments must be of type int")
return self._engine.run(args)
execution_arguments = [create_execution_argument(arg) for arg in args]
return self._engine.run(execution_arguments)
def get_compiled_module(self) -> str:
"""Compiled module in printable form.