diff --git a/compiler/include/zamalang/Support/Jit.h b/compiler/include/zamalang/Support/Jit.h index c7358ae77..34212a8e7 100644 --- a/compiler/include/zamalang/Support/Jit.h +++ b/compiler/include/zamalang/Support/Jit.h @@ -57,6 +57,10 @@ public: // Fill the result. llvm::Error getResult(size_t pos, uint64_t *res, size_t size); + // Returns the number of elements of the result vector at position + // `pos` or an error if the result is a scalar value + llvm::Expected getResultVectorSize(size_t pos); + private: llvm::Error setArg(size_t pos, size_t width, void *data, size_t size); diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 95be53411..85ab88aaf 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -344,6 +344,20 @@ llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) { return llvm::Error::success(); } +// Returns the number of elements of the result vector at position +// `pos` or an error if the result is a scalar value +llvm::Expected JITLambda::Argument::getResultVectorSize(size_t pos) { + auto gate = outputGates[pos]; + auto info = std::get<0>(gate); + + if (info.shape.size == 0) { + return llvm::createStringError(llvm::inconvertibleErrorCode(), + "Result at pos %zu is not a tensor", pos); + } + + return info.shape.size; +} + llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t *res, size_t size) { auto gate = outputGates[pos];