feat(python): support functions returning tensors

This commit is contained in:
youben11
2021-11-04 19:06:11 +01:00
committed by Ayoub Benaissa
parent badc8e44bf
commit b501e3d6c0
5 changed files with 133 additions and 9 deletions

View File

@@ -23,13 +23,27 @@ struct executionArguments {
};
typedef struct executionArguments exectuionArguments;
struct lambdaArgument {
std::unique_ptr<mlir::zamalang::LambdaArgument> ptr;
};
typedef struct lambdaArgument lambdaArgument;
MLIR_CAPI_EXPORTED mlir::zamalang::JitCompilerEngine::Lambda
buildLambda(const char *module, const char *funcName);
MLIR_CAPI_EXPORTED uint64_t invokeLambda(lambda l, executionArguments args);
MLIR_CAPI_EXPORTED lambdaArgument invokeLambda(lambda l,
executionArguments args);
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg);
MLIR_CAPI_EXPORTED std::vector<uint64_t>
lambdaArgumentGetTensorData(lambdaArgument &lambda_arg);
MLIR_CAPI_EXPORTED std::vector<int64_t>
lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg);
MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg);
MLIR_CAPI_EXPORTED uint64_t lambdaArgumentGetScalar(lambdaArgument &lambda_arg);
#ifdef __cplusplus
}
#endif

View File

@@ -17,6 +17,7 @@
using mlir::zamalang::ExecutionArgument;
using mlir::zamalang::JitCompilerEngine;
using mlir::zamalang::LambdaArgument;
/// Populate the compiler API python module.
void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
@@ -41,6 +42,27 @@ void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
return buildLambda(mlir_input.c_str(), func_name.c_str());
});
pybind11::class_<lambdaArgument>(m, "LambdaArgument")
.def("is_tensor",
[](lambdaArgument &lambda_arg) {
return lambdaArgumentIsTensor(lambda_arg);
})
.def("get_tensor_data",
[](lambdaArgument &lambda_arg) {
return lambdaArgumentGetTensorData(lambda_arg);
})
.def("get_tensor_shape",
[](lambdaArgument &lambda_arg) {
return lambdaArgumentGetTensorDimensions(lambda_arg);
})
.def("is_scalar",
[](lambdaArgument &lambda_arg) {
return lambdaArgumentIsScalar(lambda_arg);
})
.def("get_scalar", [](lambdaArgument &lambda_arg) {
return lambdaArgumentGetScalar(lambda_arg);
});
pybind11::class_<JitCompilerEngine::Lambda>(m, "Lambda")
.def("invoke", [](JitCompilerEngine::Lambda &py_lambda,
std::vector<ExecutionArgument> args) {

View File

@@ -3,6 +3,8 @@ from typing import List, Union
from mlir._mlir_libs._zamalang._compiler import JitCompilerEngine as _JitCompilerEngine
from mlir._mlir_libs._zamalang._compiler import ExecutionArgument as _ExecutionArgument
from mlir._mlir_libs._zamalang._compiler import round_trip as _round_trip
import numpy as np
def round_trip(mlir_str: str) -> str:
"""Parse the MLIR input, then return it back.
@@ -67,7 +69,7 @@ class CompilerEngine:
raise TypeError("input must be an `str`")
self._lambda = self._engine.build_lambda(mlir_str, func_name)
def run(self, *args: List[Union[int, List[int]]]) -> int:
def run(self, *args: List[Union[int, List[int]]]) -> Union[int, np.array]:
"""Run the compiled code.
Args:
@@ -76,11 +78,20 @@ class CompilerEngine:
Raises:
TypeError: if execution arguments can't be constructed
RuntimeError: if the engine has not compiled any code yet
RuntimeError: if the return type is unknown
Returns:
int: result of execution.
int or numpy.array: result of execution.
"""
if self._lambda is None:
raise RuntimeError("need to compile an MLIR code first")
execution_arguments = [create_execution_argument(arg) for arg in args]
return self._lambda.invoke(execution_arguments)
lambda_arg = self._lambda.invoke(execution_arguments)
if lambda_arg.is_scalar():
return lambda_arg.get_scalar()
elif lambda_arg.is_tensor():
shape = lambda_arg.get_tensor_shape()
tensor = np.array(lambda_arg.get_tensor_data()).reshape(shape)
return tensor
else:
raise RuntimeError("unknown return type")

View File

@@ -23,7 +23,7 @@ mlir::zamalang::JitCompilerEngine::Lambda buildLambda(const char *module,
return std::move(*lambdaOrErr);
}
uint64_t invokeLambda(lambda l, executionArguments args) {
lambdaArgument invokeLambda(lambda l, executionArguments args) {
mlir::zamalang::JitCompilerEngine::Lambda *lambda_ptr =
(mlir::zamalang::JitCompilerEngine::Lambda *)l.ptr;
@@ -45,8 +45,12 @@ uint64_t invokeLambda(lambda l, executionArguments args) {
}
}
// Run lambda
llvm::Expected<uint64_t> resOrError = (*lambda_ptr)(
llvm::ArrayRef<mlir::zamalang::LambdaArgument *>(lambdaArgumentsRef));
llvm::Expected<std::unique_ptr<mlir::zamalang::LambdaArgument>> resOrError =
(*lambda_ptr)
.
operator()<std::unique_ptr<mlir::zamalang::LambdaArgument>>(
llvm::ArrayRef<mlir::zamalang::LambdaArgument *>(
lambdaArgumentsRef));
// Free heap
for (size_t i = 0; i < lambdaArgumentsRef.size(); i++)
delete lambdaArgumentsRef[i];
@@ -58,7 +62,8 @@ uint64_t invokeLambda(lambda l, executionArguments args) {
<< llvm::toString(std::move(resOrError.takeError()));
throw std::runtime_error(os.str());
}
return *resOrError;
lambdaArgument result{std::move(*resOrError)};
return result;
}
std::string roundTrip(const char *module) {
@@ -80,3 +85,59 @@ std::string roundTrip(const char *module) {
retOrErr->mlirModuleRef->get().print(os);
return os.str();
}
bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg) {
return lambda_arg.ptr->isa<mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint64_t>>>();
}
std::vector<uint64_t> lambdaArgumentGetTensorData(lambdaArgument &lambda_arg) {
mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint64_t>> *arg =
lambda_arg.ptr->dyn_cast<mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint64_t>>>();
if (arg == nullptr) {
throw std::invalid_argument(
"LambdaArgument isn't a tensor, should "
"be a TensorLambdaArgument<IntLambdaArgument<uint64_t>>");
}
llvm::Expected<size_t> sizeOrErr = arg->getNumElements();
if (!sizeOrErr) {
std::string backingString;
llvm::raw_string_ostream os(backingString);
os << "Couldn't get size of tensor: "
<< llvm::toString(std::move(sizeOrErr.takeError()));
throw std::runtime_error(os.str());
}
std::vector<uint64_t> data(arg->getValue(), arg->getValue() + *sizeOrErr);
return data;
}
std::vector<int64_t>
lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg) {
mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint64_t>> *arg =
lambda_arg.ptr->dyn_cast<mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint64_t>>>();
if (arg == nullptr) {
throw std::invalid_argument(
"LambdaArgument isn't a tensor, should "
"be a TensorLambdaArgument<IntLambdaArgument<uint64_t>>");
}
return arg->getDimensions();
}
bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg) {
return lambda_arg.ptr->isa<mlir::zamalang::IntLambdaArgument<uint64_t>>();
}
uint64_t lambdaArgumentGetScalar(lambdaArgument &lambda_arg) {
mlir::zamalang::IntLambdaArgument<uint64_t> *arg =
lambda_arg.ptr->dyn_cast<mlir::zamalang::IntLambdaArgument<uint64_t>>();
if (arg == nullptr) {
throw std::invalid_argument("LambdaArgument isn't a scalar, should "
"be an IntLambdaArgument<uint64_t>");
}
return arg->getValue();
}

View File

@@ -29,12 +29,28 @@ from zamalang import CompilerEngine
20,
id="dot_eint_int"
),
pytest.param(
"""
func @main(%a0: tensor<4x!HLFHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>> {
%res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<6>>, tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>>
return %res : tensor<4x!HLFHE.eint<6>>
}
""",
([31, 6, 12, 9], [32, 9, 2, 3]),
[63, 15, 14, 12],
id="add_eint_int_1D"
),
],
)
def test_compile_and_run(mlir_input, args, expected_result):
engine = CompilerEngine()
engine.compile_fhe(mlir_input)
assert engine.run(*args) == expected_result
if isinstance(expected_result, int):
assert engine.run(*args) == expected_result
else:
# numpy array on the left
assert (engine.run(*args) == expected_result).all()