From 1ad3d57f66048737a9b3ecd13f571336172db7ed Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Wed, 3 Nov 2021 17:01:50 +0100 Subject: [PATCH] feat(compiler): Add method JITLambda::Argument::getResultType(size_t) Add a method `JITLambda::Argument::getResultType(size_t pos)` that returns the type of the result with the index `pos` of a `JITLambda::Argument`. --- compiler/include/zamalang/Support/Jit.h | 7 +++++++ compiler/lib/Support/Jit.cpp | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/compiler/include/zamalang/Support/Jit.h b/compiler/include/zamalang/Support/Jit.h index 0129739ff..a6949b694 100644 --- a/compiler/include/zamalang/Support/Jit.h +++ b/compiler/include/zamalang/Support/Jit.h @@ -41,6 +41,13 @@ public: // Get the result at the given pos as an uint64_t. llvm::Error getResult(size_t pos, uint64_t &res); + // Specifies the type of a result + enum ResultType { SCALAR, TENSOR }; + + // Returns the result type at position `pos`. If pos is invalid, + // an error is returned. + llvm::Expected getResultType(size_t pos); + // Get a result for tensors, fill the `res` buffer with the value of the // tensor result. // Returns an error: diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 81964e4d1..91ab0b826 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -344,6 +344,25 @@ llvm::Expected JITLambda::Argument::getResultVectorSize(size_t pos) { return info.shape.size; } +llvm::Expected +JITLambda::Argument::getResultType(size_t pos) { + if (pos >= outputGates.size()) { + return llvm::createStringError(llvm::inconvertibleErrorCode(), + "Requesting type for result at index %zu, " + "but lambda only generates %zu results", + pos, outputGates.size()); + } + + auto gate = outputGates[pos]; + auto info = std::get<0>(gate); + + if (info.shape.size == 0) { + return ResultType::SCALAR; + } else { + return ResultType::TENSOR; + } +} + llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t *res, size_t size) {