From 56e261d1407691f743bef940ad99ecf2010e0ecb Mon Sep 17 00:00:00 2001 From: youben11 Date: Fri, 5 Nov 2021 14:52:10 +0100 Subject: [PATCH] fix(compiler): pass dimensions at TensorLambdaArg creation --- .../zamalang/Support/ExecutionArgument.h | 47 ------------------- compiler/include/zamalang/Support/Jit.h | 4 ++ .../zamalang/Support/JitCompilerEngine.h | 7 ++- compiler/lib/Support/Jit.cpp | 15 ++++++ 4 files changed, 25 insertions(+), 48 deletions(-) delete mode 100644 compiler/include/zamalang/Support/ExecutionArgument.h diff --git a/compiler/include/zamalang/Support/ExecutionArgument.h b/compiler/include/zamalang/Support/ExecutionArgument.h deleted file mode 100644 index 44439cabc..000000000 --- a/compiler/include/zamalang/Support/ExecutionArgument.h +++ /dev/null @@ -1,47 +0,0 @@ -#ifndef ZAMALANG_SUPPORT_EXECUTION_ARGUMENT_H -#define ZAMALANG_SUPPORT_EXECUTION_ARGUMENT_H - -#include - -namespace mlir { -namespace zamalang { - -// Frontend object to abstract the different types of possible arguments, -// namely, integers, and tensors. -class ExecutionArgument { -public: - // There are two possible underlying types for the execution argument, either - // and int, or a tensor - bool isTensor() { return isTensorArg; } - bool isInt() { return !isTensorArg; } - - uint8_t *getTensorArgument() { return tensorArg.data(); } - - size_t getTensorSize() { return tensorArg.size(); } - - uint64_t getIntegerArgument() { return intArg; } - - // Create an execution argument from an integer - static std::shared_ptr create(uint64_t arg) { - return std::shared_ptr(new ExecutionArgument(arg)); - } - // Create an execution argument from a tensor - static std::shared_ptr create(std::vector arg) { - return std::shared_ptr(new ExecutionArgument(arg)); - } - -private: - ExecutionArgument(int arg) : isTensorArg(false), intArg(arg) {} - - ExecutionArgument(std::vector tensor) - : isTensorArg(true), tensorArg(tensor) {} - - uint64_t intArg; - std::vector tensorArg; - bool isTensorArg; -}; - -} // namespace zamalang -} // namespace mlir - -#endif diff --git a/compiler/include/zamalang/Support/Jit.h b/compiler/include/zamalang/Support/Jit.h index 6af5074e8..2382f5901 100644 --- a/compiler/include/zamalang/Support/Jit.h +++ b/compiler/include/zamalang/Support/Jit.h @@ -60,6 +60,10 @@ public: // `pos` or an error if the result is a scalar value llvm::Expected getResultVectorSize(size_t pos); + // Returns the dimensions of the result tensor at position `pos` or + // an error if the result is a scalar value + llvm::Expected> getResultDimensions(size_t pos); + private: llvm::Error setArg(size_t pos, size_t width, const void *data, llvm::ArrayRef shape); diff --git a/compiler/include/zamalang/Support/JitCompilerEngine.h b/compiler/include/zamalang/Support/JitCompilerEngine.h index 174d37878..ac2cef3ef 100644 --- a/compiler/include/zamalang/Support/JitCompilerEngine.h +++ b/compiler/include/zamalang/Support/JitCompilerEngine.h @@ -79,9 +79,14 @@ typedResult(JITLambda::Argument &arguments) { if (!tensorOrError) return std::move(tensorOrError.takeError()); + llvm::Expected> tensorDimOrError = + arguments.getResultDimensions(0); + if (!tensorDimOrError) + return std::move(tensorDimOrError.takeError()); + return std::move( std::make_unique>>( - *tensorOrError)); + *tensorOrError, *tensorDimOrError)); } } diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 2b4c65234..455ea7b96 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -345,6 +345,21 @@ llvm::Expected JITLambda::Argument::getResultVectorSize(size_t pos) { return info.shape.size; } +// Returns the dimensions of the result tensor at position `pos` or +// an error if the result is a scalar value +llvm::Expected> +JITLambda::Argument::getResultDimensions(size_t pos) { + auto gate = outputGates[pos]; + auto info = std::get<0>(gate); + + if (info.shape.size == 0) { + return llvm::createStringError(llvm::inconvertibleErrorCode(), + "Result at pos %zu is not a tensor", pos); + } + + return info.shape.dimensions; +} + llvm::Expected JITLambda::Argument::getResultType(size_t pos) { if (pos >= outputGates.size()) {