feat(compiler): Add generic overload for result of JitCompilerEngine::Lambda

This adds a new overload for `JitCompilerEngine::Lambda::operator()`,
returning a result wrapped in a `std::unique_ptr<LambdaArgument>` with
meta information about the result. This allows for generic invocations
of JitCompilerEngine::Lambda::operator(), where the result type is
unknown before the invocation.
This commit is contained in:
Andi Drebes
2021-11-03 17:03:16 +01:00
committed by Ayoub Benaissa
parent 1ad3d57f66
commit 9040e5ab00

View File

@@ -51,6 +51,43 @@ typedResult(JITLambda::Argument &arguments) {
return std::move(res);
}
// Specialization of `typedResult()` for a single result wrapped into
// a `LambdaArgument`.
template <>
inline llvm::Expected<std::unique_ptr<LambdaArgument>>
typedResult(JITLambda::Argument &arguments) {
llvm::Expected<enum JITLambda::Argument::ResultType> resTy =
arguments.getResultType(0);
if (!resTy)
return std::move(resTy.takeError());
switch (*resTy) {
case JITLambda::Argument::ResultType::SCALAR: {
uint64_t res;
if (llvm::Error err = arguments.getResult(0, res))
return std::move(err);
return std::move(std::make_unique<IntLambdaArgument<uint64_t>>(res));
}
case JITLambda::Argument::ResultType::TENSOR: {
llvm::Expected<std::vector<uint64_t>> tensorOrError =
typedResult<std::vector<uint64_t>>(arguments);
if (!tensorOrError)
return std::move(tensorOrError.takeError());
return std::move(
std::make_unique<TensorLambdaArgument<IntLambdaArgument<uint64_t>>>(
*tensorOrError));
}
}
return StreamStringError("Unknown result type");
}
// Adaptor class that adds arguments specified as instances of
// `LambdaArgument` to `JitLambda::Argument`.
class JITLambdaArgumentAdaptor {