feat(compiler): Add method JITLambda::Argument::getResultType(size_t)

Add a method `JITLambda::Argument::getResultType(size_t pos)` that
returns the type of the result with the index `pos` of a
`JITLambda::Argument`.
This commit is contained in:
Andi Drebes
2021-11-03 17:01:50 +01:00
committed by Ayoub Benaissa
parent 97ee8134ed
commit 1ad3d57f66
2 changed files with 26 additions and 0 deletions

View File

@@ -41,6 +41,13 @@ public:
// Get the result at the given pos as an uint64_t.
llvm::Error getResult(size_t pos, uint64_t &res);
// Specifies the type of a result
enum ResultType { SCALAR, TENSOR };
// Returns the result type at position `pos`. If pos is invalid,
// an error is returned.
llvm::Expected<enum ResultType> getResultType(size_t pos);
// Get a result for tensors, fill the `res` buffer with the value of the
// tensor result.
// Returns an error:

View File

@@ -344,6 +344,25 @@ llvm::Expected<size_t> JITLambda::Argument::getResultVectorSize(size_t pos) {
return info.shape.size;
}
llvm::Expected<enum JITLambda::Argument::ResultType>
JITLambda::Argument::getResultType(size_t pos) {
if (pos >= outputGates.size()) {
return llvm::createStringError(llvm::inconvertibleErrorCode(),
"Requesting type for result at index %zu, "
"but lambda only generates %zu results",
pos, outputGates.size());
}
auto gate = outputGates[pos];
auto info = std::get<0>(gate);
if (info.shape.size == 0) {
return ResultType::SCALAR;
} else {
return ResultType::TENSOR;
}
}
llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t *res,
size_t size) {