diff --git a/compiler/include/zamalang-c/Support/CompilerEngine.h b/compiler/include/zamalang-c/Support/CompilerEngine.h index 834b30c50..4ec8265db 100644 --- a/compiler/include/zamalang-c/Support/CompilerEngine.h +++ b/compiler/include/zamalang-c/Support/CompilerEngine.h @@ -23,13 +23,27 @@ struct executionArguments { }; typedef struct executionArguments exectuionArguments; +struct lambdaArgument { + std::unique_ptr ptr; +}; +typedef struct lambdaArgument lambdaArgument; + MLIR_CAPI_EXPORTED mlir::zamalang::JitCompilerEngine::Lambda buildLambda(const char *module, const char *funcName); -MLIR_CAPI_EXPORTED uint64_t invokeLambda(lambda l, executionArguments args); +MLIR_CAPI_EXPORTED lambdaArgument invokeLambda(lambda l, + executionArguments args); MLIR_CAPI_EXPORTED std::string roundTrip(const char *module); +MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg); +MLIR_CAPI_EXPORTED std::vector +lambdaArgumentGetTensorData(lambdaArgument &lambda_arg); +MLIR_CAPI_EXPORTED std::vector +lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg); +MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg); +MLIR_CAPI_EXPORTED uint64_t lambdaArgumentGetScalar(lambdaArgument &lambda_arg); + #ifdef __cplusplus } #endif diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 65e100e68..7c9bb7c99 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -17,6 +17,7 @@ using mlir::zamalang::ExecutionArgument; using mlir::zamalang::JitCompilerEngine; +using mlir::zamalang::LambdaArgument; /// Populate the compiler API python module. void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) { @@ -41,6 +42,27 @@ void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) { return buildLambda(mlir_input.c_str(), func_name.c_str()); }); + pybind11::class_(m, "LambdaArgument") + .def("is_tensor", + [](lambdaArgument &lambda_arg) { + return lambdaArgumentIsTensor(lambda_arg); + }) + .def("get_tensor_data", + [](lambdaArgument &lambda_arg) { + return lambdaArgumentGetTensorData(lambda_arg); + }) + .def("get_tensor_shape", + [](lambdaArgument &lambda_arg) { + return lambdaArgumentGetTensorDimensions(lambda_arg); + }) + .def("is_scalar", + [](lambdaArgument &lambda_arg) { + return lambdaArgumentIsScalar(lambda_arg); + }) + .def("get_scalar", [](lambdaArgument &lambda_arg) { + return lambdaArgumentGetScalar(lambda_arg); + }); + pybind11::class_(m, "Lambda") .def("invoke", [](JitCompilerEngine::Lambda &py_lambda, std::vector args) { diff --git a/compiler/lib/Bindings/Python/zamalang/compiler.py b/compiler/lib/Bindings/Python/zamalang/compiler.py index 130f4275e..c334b24db 100644 --- a/compiler/lib/Bindings/Python/zamalang/compiler.py +++ b/compiler/lib/Bindings/Python/zamalang/compiler.py @@ -3,6 +3,8 @@ from typing import List, Union from mlir._mlir_libs._zamalang._compiler import JitCompilerEngine as _JitCompilerEngine from mlir._mlir_libs._zamalang._compiler import ExecutionArgument as _ExecutionArgument from mlir._mlir_libs._zamalang._compiler import round_trip as _round_trip +import numpy as np + def round_trip(mlir_str: str) -> str: """Parse the MLIR input, then return it back. @@ -67,7 +69,7 @@ class CompilerEngine: raise TypeError("input must be an `str`") self._lambda = self._engine.build_lambda(mlir_str, func_name) - def run(self, *args: List[Union[int, List[int]]]) -> int: + def run(self, *args: List[Union[int, List[int]]]) -> Union[int, np.array]: """Run the compiled code. Args: @@ -76,11 +78,20 @@ class CompilerEngine: Raises: TypeError: if execution arguments can't be constructed RuntimeError: if the engine has not compiled any code yet + RuntimeError: if the return type is unknown Returns: - int: result of execution. + int or numpy.array: result of execution. """ if self._lambda is None: raise RuntimeError("need to compile an MLIR code first") execution_arguments = [create_execution_argument(arg) for arg in args] - return self._lambda.invoke(execution_arguments) + lambda_arg = self._lambda.invoke(execution_arguments) + if lambda_arg.is_scalar(): + return lambda_arg.get_scalar() + elif lambda_arg.is_tensor(): + shape = lambda_arg.get_tensor_shape() + tensor = np.array(lambda_arg.get_tensor_data()).reshape(shape) + return tensor + else: + raise RuntimeError("unknown return type") diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index 5426f5489..851ff15c7 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -23,7 +23,7 @@ mlir::zamalang::JitCompilerEngine::Lambda buildLambda(const char *module, return std::move(*lambdaOrErr); } -uint64_t invokeLambda(lambda l, executionArguments args) { +lambdaArgument invokeLambda(lambda l, executionArguments args) { mlir::zamalang::JitCompilerEngine::Lambda *lambda_ptr = (mlir::zamalang::JitCompilerEngine::Lambda *)l.ptr; @@ -45,8 +45,12 @@ uint64_t invokeLambda(lambda l, executionArguments args) { } } // Run lambda - llvm::Expected resOrError = (*lambda_ptr)( - llvm::ArrayRef(lambdaArgumentsRef)); + llvm::Expected> resOrError = + (*lambda_ptr) + . + operator()>( + llvm::ArrayRef( + lambdaArgumentsRef)); // Free heap for (size_t i = 0; i < lambdaArgumentsRef.size(); i++) delete lambdaArgumentsRef[i]; @@ -58,7 +62,8 @@ uint64_t invokeLambda(lambda l, executionArguments args) { << llvm::toString(std::move(resOrError.takeError())); throw std::runtime_error(os.str()); } - return *resOrError; + lambdaArgument result{std::move(*resOrError)}; + return result; } std::string roundTrip(const char *module) { @@ -80,3 +85,59 @@ std::string roundTrip(const char *module) { retOrErr->mlirModuleRef->get().print(os); return os.str(); } + +bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg) { + return lambda_arg.ptr->isa>>(); +} + +std::vector lambdaArgumentGetTensorData(lambdaArgument &lambda_arg) { + mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument> *arg = + lambda_arg.ptr->dyn_cast>>(); + if (arg == nullptr) { + throw std::invalid_argument( + "LambdaArgument isn't a tensor, should " + "be a TensorLambdaArgument>"); + } + + llvm::Expected sizeOrErr = arg->getNumElements(); + if (!sizeOrErr) { + std::string backingString; + llvm::raw_string_ostream os(backingString); + os << "Couldn't get size of tensor: " + << llvm::toString(std::move(sizeOrErr.takeError())); + throw std::runtime_error(os.str()); + } + std::vector data(arg->getValue(), arg->getValue() + *sizeOrErr); + return data; +} + +std::vector +lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg) { + mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument> *arg = + lambda_arg.ptr->dyn_cast>>(); + if (arg == nullptr) { + throw std::invalid_argument( + "LambdaArgument isn't a tensor, should " + "be a TensorLambdaArgument>"); + } + return arg->getDimensions(); +} + +bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg) { + return lambda_arg.ptr->isa>(); +} + +uint64_t lambdaArgumentGetScalar(lambdaArgument &lambda_arg) { + mlir::zamalang::IntLambdaArgument *arg = + lambda_arg.ptr->dyn_cast>(); + if (arg == nullptr) { + throw std::invalid_argument("LambdaArgument isn't a scalar, should " + "be an IntLambdaArgument"); + } + return arg->getValue(); +} diff --git a/compiler/tests/python/test_compiler_engine.py b/compiler/tests/python/test_compiler_engine.py index 40600289a..ee8d917b3 100644 --- a/compiler/tests/python/test_compiler_engine.py +++ b/compiler/tests/python/test_compiler_engine.py @@ -29,12 +29,28 @@ from zamalang import CompilerEngine 20, id="dot_eint_int" ), + pytest.param( + """ + func @main(%a0: tensor<4x!HLFHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>> { + %res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<6>>, tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>> + return %res : tensor<4x!HLFHE.eint<6>> + } + """, + ([31, 6, 12, 9], [32, 9, 2, 3]), + [63, 15, 14, 12], + id="add_eint_int_1D" + ), ], ) def test_compile_and_run(mlir_input, args, expected_result): engine = CompilerEngine() engine.compile_fhe(mlir_input) - assert engine.run(*args) == expected_result + if isinstance(expected_result, int): + assert engine.run(*args) == expected_result + else: + # numpy array on the left + assert (engine.run(*args) == expected_result).all() +