From b12be451433a1687d2343275e7bb88cc2fc15931 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Mon, 18 Oct 2021 11:04:48 +0200 Subject: [PATCH] feat(compiler): Add method getResultVectorSize to JITLambda::Argument Add method `JITLambda::Argument::getResultVectorSize` that returns the number of elements of the result if the result is a vector. --- compiler/include/zamalang/Support/Jit.h | 4 ++++ compiler/lib/Support/Jit.cpp | 14 ++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/compiler/include/zamalang/Support/Jit.h b/compiler/include/zamalang/Support/Jit.h index c7358ae77..34212a8e7 100644 --- a/compiler/include/zamalang/Support/Jit.h +++ b/compiler/include/zamalang/Support/Jit.h @@ -57,6 +57,10 @@ public: // Fill the result. llvm::Error getResult(size_t pos, uint64_t *res, size_t size); + // Returns the number of elements of the result vector at position + // `pos` or an error if the result is a scalar value + llvm::Expected getResultVectorSize(size_t pos); + private: llvm::Error setArg(size_t pos, size_t width, void *data, size_t size); diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 95be53411..85ab88aaf 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -344,6 +344,20 @@ llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) { return llvm::Error::success(); } +// Returns the number of elements of the result vector at position +// `pos` or an error if the result is a scalar value +llvm::Expected JITLambda::Argument::getResultVectorSize(size_t pos) { + auto gate = outputGates[pos]; + auto info = std::get<0>(gate); + + if (info.shape.size == 0) { + return llvm::createStringError(llvm::inconvertibleErrorCode(), + "Result at pos %zu is not a tensor", pos); + } + + return info.shape.size; +} + llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t *res, size_t size) { auto gate = outputGates[pos];