From 967fda07a05b6a410fee2027514a7114bdf781e9 Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 9 Sep 2021 16:02:17 +0100 Subject: [PATCH] feat(execution): run on both int and tensor args --- compiler/python/CompilerAPIModule.cpp | 66 +++++++++++++++---- compiler/python/CompilerAPIModule.h | 36 ++++++++++ compiler/python/zamalang/compiler.py | 41 ++++++++++-- compiler/tests/python/test_compiler_engine.py | 12 ++++ 4 files changed, 138 insertions(+), 17 deletions(-) diff --git a/compiler/python/CompilerAPIModule.cpp b/compiler/python/CompilerAPIModule.cpp index b64e04040..887e79fe0 100644 --- a/compiler/python/CompilerAPIModule.cpp +++ b/compiler/python/CompilerAPIModule.cpp @@ -21,6 +21,7 @@ using namespace zamalang; using mlir::zamalang::CompilerEngine; +using zamalang::python::ExecutionArgument; /// Populate the compiler API python module. void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) { @@ -41,19 +42,62 @@ void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) { return os.str(); }); + pybind11::class_>( + m, "ExecutionArgument") + .def("create", + pybind11::overload_cast(&ExecutionArgument::create)) + .def("create", pybind11::overload_cast>( + &ExecutionArgument::create)) + .def("is_tensor", &ExecutionArgument::isTensor) + .def("is_int", &ExecutionArgument::isInt); + pybind11::class_(m, "CompilerEngine") .def(pybind11::init()) - .def("run", - [](CompilerEngine &engine, std::vector args) { - auto result = engine.run(args); - if (!result) { // not an error - llvm::errs() - << "Execution failed: " << result.takeError() << "\n"; - throw std::runtime_error( - "failed running, see previous logs for more info"); - } - return result.get(); - }) + .def( + "run", + [](CompilerEngine &engine, std::vector args) { + auto maybeArgument = engine.buildArgument(); + if (auto err = maybeArgument.takeError()) { + llvm::errs() << "Execution failed: " << err << "\n"; + throw std::runtime_error( + "failed building arguments, see previous logs for more info"); + } + // Set the integer/tensor arguments + auto arguments = std::move(maybeArgument.get()); + for (auto i = 0; i < args.size(); i++) { + if (args[i].isInt()) { // integer argument + if (auto err = + arguments->setArg(i, args[i].getIntegerArgument())) { + llvm::errs() << "Execution failed: " << err << "\n"; + throw std::runtime_error( + "failed pushing integer argument, see " + "previous logs for more info"); + } + } else { // tensor argument + assert(args[i].isTensor() && "should be tensor argument"); + if (auto err = arguments->setArg(i, args[i].getTensorArgument(), + args[i].getTensorSize())) { + llvm::errs() << "Execution failed: " << err << "\n"; + throw std::runtime_error( + "failed pushing tensor argument, see " + "previous logs for more info"); + } + } + } + // Invoke the lambda + if (auto err = engine.invoke(*arguments)) { + llvm::errs() << "Execution failed: " << err << "\n"; + throw std::runtime_error( + "failed running, see previous logs for more info"); + } + uint64_t result = 0; + if (auto err = arguments->getResult(0, result)) { + llvm::errs() << "Execution failed: " << err << "\n"; + throw std::runtime_error( + "failed getting result, see previous logs for more info"); + } + return result; + }) .def("compile_fhe", [](CompilerEngine &engine, std::string mlir_input) { auto error = engine.compile(mlir_input); diff --git a/compiler/python/CompilerAPIModule.h b/compiler/python/CompilerAPIModule.h index 5989938d0..a95c5417e 100644 --- a/compiler/python/CompilerAPIModule.h +++ b/compiler/python/CompilerAPIModule.h @@ -6,6 +6,42 @@ namespace zamalang { namespace python { +// Frontend object to abstract the different types of possible arguments, +// namely, integers, and tensors. +class ExecutionArgument { +public: + // There are two possible underlying types for the execution argument, either + // and int, or a tensor + bool isTensor() { return isTensorArg; } + bool isInt() { return !isTensorArg; } + + uint8_t *getTensorArgument() { return tensorArg.data(); } + + size_t getTensorSize() { return tensorArg.size(); } + + uint64_t getIntegerArgument() { return intArg; } + + // Create an execution argument from an integer + static std::shared_ptr create(uint64_t arg) { + return std::shared_ptr(new ExecutionArgument(arg)); + } + // Create an execution argument from a tensor + static std::shared_ptr create(std::vector arg) { + return std::shared_ptr(new ExecutionArgument(arg)); + } + +private: + ExecutionArgument(int arg) + : isTensorArg(false), intArg(arg) {} + + ExecutionArgument(std::vector tensor) + : isTensorArg(true), tensorArg(tensor) {} + + uint64_t intArg; + std::vector tensorArg; + bool isTensorArg; +}; + void populateCompilerAPISubmodule(pybind11::module &m); } // namespace python diff --git a/compiler/python/zamalang/compiler.py b/compiler/python/zamalang/compiler.py index 5e6e9ab2e..185e4a169 100644 --- a/compiler/python/zamalang/compiler.py +++ b/compiler/python/zamalang/compiler.py @@ -1,6 +1,7 @@ """Compiler submodule""" -from typing import List +from typing import List, Union from _zamalang._compiler import CompilerEngine as _CompilerEngine +from _zamalang._compiler import ExecutionArgument as _ExecutionArgument from _zamalang._compiler import round_trip as _round_trip @@ -20,6 +21,32 @@ def round_trip(mlir_str: str) -> str: raise TypeError("input must be an `str`") return _round_trip(mlir_str) + +def create_execution_argument(value: Union[int, List[int]]) -> "_ExecutionArgument": + """Create an execution argument holding either an int or tensor value. + + Args: + value (Union[int, List[int]]): value of the argument, either an int, or a list of int + + Raises: + TypeError: if the values aren't in the expected range, or using a wrong type + + Returns: + _ExecutionArgument: execution argument holding the appropriate value + """ + if not isinstance(value, (int, list)): + raise TypeError("value of execution argument must be either int or list[int]") + if isinstance(value, int): + if not (0 <= value < (2 ** 64 - 1)): + raise TypeError("single integer must be in the range [0, 2**64 - 1] (uint64)") + else: + assert isinstance(value, list) + for elem in value: + if not (0 <= elem < (2 ** 8 - 1)): + raise TypeError("values of the list must be in the range [0, 255] (uint8)") + return _ExecutionArgument.create(value) + + class CompilerEngine: def __init__(self, mlir_str: str = None): self._engine = _CompilerEngine() @@ -42,18 +69,20 @@ class CompilerEngine: raise TypeError("input must be an `str`") return self._engine.compile_fhe(mlir_str) - def run(self, *args: List[int]) -> int: + def run(self, *args: List[Union[int, List[int]]]) -> int: """Run the compiled code. + Args: + *args: list of arguments for execution. Each argument can be an int, or a list of int + Raises: - TypeError: if arguments aren't of type int + TypeError: if execution arguments can't be constructed Returns: int: result of execution. """ - if not all(isinstance(arg, int) for arg in args): - raise TypeError("arguments must be of type int") - return self._engine.run(args) + execution_arguments = [create_execution_argument(arg) for arg in args] + return self._engine.run(execution_arguments) def get_compiled_module(self) -> str: """Compiled module in printable form. diff --git a/compiler/tests/python/test_compiler_engine.py b/compiler/tests/python/test_compiler_engine.py index 52ba1b074..5f8eb8930 100644 --- a/compiler/tests/python/test_compiler_engine.py +++ b/compiler/tests/python/test_compiler_engine.py @@ -15,6 +15,18 @@ from zamalang import CompilerEngine (5, 7), 12, ), + ( + """ + func @main(%arg0: tensor<4x!HLFHE.eint<7>>, %arg1: tensor<4xi8>) -> !HLFHE.eint<7> + { + %ret = "HLFHE.dot_eint_int"(%arg0, %arg1) : + (tensor<4x!HLFHE.eint<7>>, tensor<4xi8>) -> !HLFHE.eint<7> + return %ret : !HLFHE.eint<7> + } + """, + ([1, 2, 3, 4], [4, 3, 2, 1]), + 20, + ), ], ) def test_compile_and_run(mlir_input, args, expected_result):