feat(compiler): Add method getResultVectorSize to JITLambda::Argument

Add method `JITLambda::Argument::getResultVectorSize` that returns the
number of elements of the result if the result is a vector.
This commit is contained in:
Andi Drebes
2021-10-18 11:04:48 +02:00
parent d4b4839d6e
commit b12be45143
2 changed files with 18 additions and 0 deletions

View File

@@ -57,6 +57,10 @@ public:
// Fill the result.
llvm::Error getResult(size_t pos, uint64_t *res, size_t size);
// Returns the number of elements of the result vector at position
// `pos` or an error if the result is a scalar value
llvm::Expected<size_t> getResultVectorSize(size_t pos);
private:
llvm::Error setArg(size_t pos, size_t width, void *data, size_t size);

View File

@@ -344,6 +344,20 @@ llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) {
return llvm::Error::success();
}
// Returns the number of elements of the result vector at position
// `pos` or an error if the result is a scalar value
llvm::Expected<size_t> JITLambda::Argument::getResultVectorSize(size_t pos) {
auto gate = outputGates[pos];
auto info = std::get<0>(gate);
if (info.shape.size == 0) {
return llvm::createStringError(llvm::inconvertibleErrorCode(),
"Result at pos %zu is not a tensor", pos);
}
return info.shape.size;
}
llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t *res,
size_t size) {
auto gate = outputGates[pos];