enhance(compiler): JitCompilerEngine: Add support for arbitrary int result tensors

Add support for result tensors composed of uint8_t, uint16_t, uint32_t
and uint64_t elements, replacing the current implementation, which
only supports uint64_t tensors.
This commit is contained in:
Andi Drebes
2021-11-09 14:23:52 +01:00
committed by Quentin Bourgerie
parent 3118983287
commit f54c0dd8d8

View File

@@ -32,18 +32,15 @@ inline llvm::Expected<uint64_t> typedResult(JITLambda::Argument &arguments) {
return res;
}
// Specialization of `typedResult()` for vector results, initializing
// an `std::vector` of the right size with the results and forwarding
// it to the caller with move semantics.
template <>
inline llvm::Expected<std::vector<uint64_t>>
typedResult(JITLambda::Argument &arguments) {
template <typename T>
inline llvm::Expected<std::vector<T>>
typedVectorResult(JITLambda::Argument &arguments) {
llvm::Expected<size_t> n = arguments.getResultVectorSize(0);
if (auto err = n.takeError())
return std::move(err);
std::vector<uint64_t> res(*n);
std::vector<T> res(*n);
if (auto err = arguments.getResult(0, res.data(), res.size()))
return StreamStringError() << "Cannot retrieve result:" << err;
@@ -51,6 +48,54 @@ typedResult(JITLambda::Argument &arguments) {
return std::move(res);
}
// Specializations of `typedResult()` for vector results, initializing
// an `std::vector` of the right size with the results and forwarding
// it to the caller with move semantics.
//
// Cannot factor out into a template template <typename T> inline
// llvm::Expected<std::vector<uint8_t>>
// typedResult(JITLambda::Argument &arguments); due to ambiguity with
// scalar template
template <>
inline llvm::Expected<std::vector<uint8_t>>
typedResult(JITLambda::Argument &arguments) {
return std::move(typedVectorResult<uint8_t>(arguments));
}
template <>
inline llvm::Expected<std::vector<uint16_t>>
typedResult(JITLambda::Argument &arguments) {
return std::move(typedVectorResult<uint16_t>(arguments));
}
template <>
inline llvm::Expected<std::vector<uint32_t>>
typedResult(JITLambda::Argument &arguments) {
return std::move(typedVectorResult<uint32_t>(arguments));
}
template <>
inline llvm::Expected<std::vector<uint64_t>>
typedResult(JITLambda::Argument &arguments) {
return std::move(typedVectorResult<uint64_t>(arguments));
}
template <typename T>
llvm::Expected<std::unique_ptr<LambdaArgument>>
buildTensorLambdaResult(JITLambda::Argument &arguments) {
llvm::Expected<std::vector<T>> tensorOrError =
typedResult<std::vector<T>>(arguments);
if (!tensorOrError)
return std::move(tensorOrError.takeError());
llvm::Expected<std::vector<int64_t>> tensorDimOrError =
arguments.getResultDimensions(0);
if (!tensorDimOrError)
return std::move(tensorDimOrError.takeError());
return std::move(std::make_unique<TensorLambdaArgument<IntLambdaArgument<T>>>(
*tensorOrError, *tensorDimOrError));
}
// Specialization of `typedResult()` for a single result wrapped into
// a `LambdaArgument`.
template <>
@@ -73,25 +118,26 @@ typedResult(JITLambda::Argument &arguments) {
}
case JITLambda::Argument::ResultType::TENSOR: {
llvm::Expected<std::vector<uint64_t>> tensorOrError =
typedResult<std::vector<uint64_t>>(arguments);
llvm::Expected<size_t> width = arguments.getResultWidth(0);
if (!tensorOrError)
return std::move(tensorOrError.takeError());
if (!width)
return std::move(width.takeError());
llvm::Expected<std::vector<int64_t>> tensorDimOrError =
arguments.getResultDimensions(0);
if (!tensorDimOrError)
return std::move(tensorDimOrError.takeError());
return std::move(
std::make_unique<TensorLambdaArgument<IntLambdaArgument<uint64_t>>>(
*tensorOrError, *tensorDimOrError));
if (*width > 64)
return StreamStringError("Cannot handle scalars with more than 64 bits");
if (*width > 32)
return buildTensorLambdaResult<uint64_t>(arguments);
else if (*width > 16)
return buildTensorLambdaResult<uint32_t>(arguments);
else if (*width > 8)
return buildTensorLambdaResult<uint16_t>(arguments);
else if (*width <= 8)
return buildTensorLambdaResult<uint8_t>(arguments);
}
}
return StreamStringError("Unknown result type");
}
} // namespace
// Adaptor class that adds arguments specified as instances of
// `LambdaArgument` to `JitLambda::Argument`.