mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
feat(execution): run on both int and tensor args
This commit is contained in:
@@ -21,6 +21,7 @@
|
||||
|
||||
using namespace zamalang;
|
||||
using mlir::zamalang::CompilerEngine;
|
||||
using zamalang::python::ExecutionArgument;
|
||||
|
||||
/// Populate the compiler API python module.
|
||||
void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
|
||||
@@ -41,19 +42,62 @@ void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
|
||||
return os.str();
|
||||
});
|
||||
|
||||
pybind11::class_<ExecutionArgument, std::shared_ptr<ExecutionArgument>>(
|
||||
m, "ExecutionArgument")
|
||||
.def("create",
|
||||
pybind11::overload_cast<uint64_t>(&ExecutionArgument::create))
|
||||
.def("create", pybind11::overload_cast<std::vector<uint8_t>>(
|
||||
&ExecutionArgument::create))
|
||||
.def("is_tensor", &ExecutionArgument::isTensor)
|
||||
.def("is_int", &ExecutionArgument::isInt);
|
||||
|
||||
pybind11::class_<CompilerEngine>(m, "CompilerEngine")
|
||||
.def(pybind11::init())
|
||||
.def("run",
|
||||
[](CompilerEngine &engine, std::vector<uint64_t> args) {
|
||||
auto result = engine.run(args);
|
||||
if (!result) { // not an error
|
||||
llvm::errs()
|
||||
<< "Execution failed: " << result.takeError() << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed running, see previous logs for more info");
|
||||
}
|
||||
return result.get();
|
||||
})
|
||||
.def(
|
||||
"run",
|
||||
[](CompilerEngine &engine, std::vector<ExecutionArgument> args) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
if (auto err = maybeArgument.takeError()) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed building arguments, see previous logs for more info");
|
||||
}
|
||||
// Set the integer/tensor arguments
|
||||
auto arguments = std::move(maybeArgument.get());
|
||||
for (auto i = 0; i < args.size(); i++) {
|
||||
if (args[i].isInt()) { // integer argument
|
||||
if (auto err =
|
||||
arguments->setArg(i, args[i].getIntegerArgument())) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed pushing integer argument, see "
|
||||
"previous logs for more info");
|
||||
}
|
||||
} else { // tensor argument
|
||||
assert(args[i].isTensor() && "should be tensor argument");
|
||||
if (auto err = arguments->setArg(i, args[i].getTensorArgument(),
|
||||
args[i].getTensorSize())) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed pushing tensor argument, see "
|
||||
"previous logs for more info");
|
||||
}
|
||||
}
|
||||
}
|
||||
// Invoke the lambda
|
||||
if (auto err = engine.invoke(*arguments)) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed running, see previous logs for more info");
|
||||
}
|
||||
uint64_t result = 0;
|
||||
if (auto err = arguments->getResult(0, result)) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed getting result, see previous logs for more info");
|
||||
}
|
||||
return result;
|
||||
})
|
||||
.def("compile_fhe",
|
||||
[](CompilerEngine &engine, std::string mlir_input) {
|
||||
auto error = engine.compile(mlir_input);
|
||||
|
||||
@@ -6,6 +6,42 @@
|
||||
namespace zamalang {
|
||||
namespace python {
|
||||
|
||||
// Frontend object to abstract the different types of possible arguments,
|
||||
// namely, integers, and tensors.
|
||||
class ExecutionArgument {
|
||||
public:
|
||||
// There are two possible underlying types for the execution argument, either
|
||||
// and int, or a tensor
|
||||
bool isTensor() { return isTensorArg; }
|
||||
bool isInt() { return !isTensorArg; }
|
||||
|
||||
uint8_t *getTensorArgument() { return tensorArg.data(); }
|
||||
|
||||
size_t getTensorSize() { return tensorArg.size(); }
|
||||
|
||||
uint64_t getIntegerArgument() { return intArg; }
|
||||
|
||||
// Create an execution argument from an integer
|
||||
static std::shared_ptr<ExecutionArgument> create(uint64_t arg) {
|
||||
return std::shared_ptr<ExecutionArgument>(new ExecutionArgument(arg));
|
||||
}
|
||||
// Create an execution argument from a tensor
|
||||
static std::shared_ptr<ExecutionArgument> create(std::vector<uint8_t> arg) {
|
||||
return std::shared_ptr<ExecutionArgument>(new ExecutionArgument(arg));
|
||||
}
|
||||
|
||||
private:
|
||||
ExecutionArgument(int arg)
|
||||
: isTensorArg(false), intArg(arg) {}
|
||||
|
||||
ExecutionArgument(std::vector<uint8_t> tensor)
|
||||
: isTensorArg(true), tensorArg(tensor) {}
|
||||
|
||||
uint64_t intArg;
|
||||
std::vector<uint8_t> tensorArg;
|
||||
bool isTensorArg;
|
||||
};
|
||||
|
||||
void populateCompilerAPISubmodule(pybind11::module &m);
|
||||
|
||||
} // namespace python
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Compiler submodule"""
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
from _zamalang._compiler import CompilerEngine as _CompilerEngine
|
||||
from _zamalang._compiler import ExecutionArgument as _ExecutionArgument
|
||||
from _zamalang._compiler import round_trip as _round_trip
|
||||
|
||||
|
||||
@@ -20,6 +21,32 @@ def round_trip(mlir_str: str) -> str:
|
||||
raise TypeError("input must be an `str`")
|
||||
return _round_trip(mlir_str)
|
||||
|
||||
|
||||
def create_execution_argument(value: Union[int, List[int]]) -> "_ExecutionArgument":
|
||||
"""Create an execution argument holding either an int or tensor value.
|
||||
|
||||
Args:
|
||||
value (Union[int, List[int]]): value of the argument, either an int, or a list of int
|
||||
|
||||
Raises:
|
||||
TypeError: if the values aren't in the expected range, or using a wrong type
|
||||
|
||||
Returns:
|
||||
_ExecutionArgument: execution argument holding the appropriate value
|
||||
"""
|
||||
if not isinstance(value, (int, list)):
|
||||
raise TypeError("value of execution argument must be either int or list[int]")
|
||||
if isinstance(value, int):
|
||||
if not (0 <= value < (2 ** 64 - 1)):
|
||||
raise TypeError("single integer must be in the range [0, 2**64 - 1] (uint64)")
|
||||
else:
|
||||
assert isinstance(value, list)
|
||||
for elem in value:
|
||||
if not (0 <= elem < (2 ** 8 - 1)):
|
||||
raise TypeError("values of the list must be in the range [0, 255] (uint8)")
|
||||
return _ExecutionArgument.create(value)
|
||||
|
||||
|
||||
class CompilerEngine:
|
||||
def __init__(self, mlir_str: str = None):
|
||||
self._engine = _CompilerEngine()
|
||||
@@ -42,18 +69,20 @@ class CompilerEngine:
|
||||
raise TypeError("input must be an `str`")
|
||||
return self._engine.compile_fhe(mlir_str)
|
||||
|
||||
def run(self, *args: List[int]) -> int:
|
||||
def run(self, *args: List[Union[int, List[int]]]) -> int:
|
||||
"""Run the compiled code.
|
||||
|
||||
Args:
|
||||
*args: list of arguments for execution. Each argument can be an int, or a list of int
|
||||
|
||||
Raises:
|
||||
TypeError: if arguments aren't of type int
|
||||
TypeError: if execution arguments can't be constructed
|
||||
|
||||
Returns:
|
||||
int: result of execution.
|
||||
"""
|
||||
if not all(isinstance(arg, int) for arg in args):
|
||||
raise TypeError("arguments must be of type int")
|
||||
return self._engine.run(args)
|
||||
execution_arguments = [create_execution_argument(arg) for arg in args]
|
||||
return self._engine.run(execution_arguments)
|
||||
|
||||
def get_compiled_module(self) -> str:
|
||||
"""Compiled module in printable form.
|
||||
|
||||
@@ -15,6 +15,18 @@ from zamalang import CompilerEngine
|
||||
(5, 7),
|
||||
12,
|
||||
),
|
||||
(
|
||||
"""
|
||||
func @main(%arg0: tensor<4x!HLFHE.eint<7>>, %arg1: tensor<4xi8>) -> !HLFHE.eint<7>
|
||||
{
|
||||
%ret = "HLFHE.dot_eint_int"(%arg0, %arg1) :
|
||||
(tensor<4x!HLFHE.eint<7>>, tensor<4xi8>) -> !HLFHE.eint<7>
|
||||
return %ret : !HLFHE.eint<7>
|
||||
}
|
||||
""",
|
||||
([1, 2, 3, 4], [4, 3, 2, 1]),
|
||||
20,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run(mlir_input, args, expected_result):
|
||||
|
||||
Reference in New Issue
Block a user